Coverage for openhcs/processing/backends/pos_gen/mist_processor_cupy.py: 18.3%
59 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
"""
MIST (Microscopy Image Stitching Tool) GPU implementation.

This module provides a GPU-accelerated phase-correlation routine built on CuPy
and re-exports ``mist_compute_tile_positions`` for backward compatibility. All
legacy functions have been moved to the modular implementation in the ``mist/``
subfolder.
"""
7from __future__ import annotations
9import logging
10from typing import TYPE_CHECKING, Tuple
12from openhcs.constants.constants import DEFAULT_PATCH_SIZE, DEFAULT_SEARCH_RADIUS
13from openhcs.core.utils import optional_import
15# For type checking only
16if TYPE_CHECKING: 16 ↛ 17line 16 didn't jump to line 17 because the condition on line 16 was never true
17 import cupy as cp
# Module-level logger, namespaced to this module.
logger = logging.getLogger(__name__)

# CuPy is an optional dependency: it is resolved through optional_import rather
# than a hard ``import cupy`` so this module can be imported without it.
# NOTE(review): assumes optional_import tolerates a missing package (e.g. by
# returning None) - confirm in openhcs.core.utils.
cp = optional_import("cupy")
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise if ``array`` is not a CuPy ``ndarray``.

    Args:
        array: Object to check.
        name: Label used in the error message (e.g. ``"image1"``).

    Raises:
        ImportError: If the module-level ``cp`` is ``None`` (CuPy unavailable).
        TypeError: If ``array`` is not an instance of ``cp.ndarray``.
    """
    # NOTE(review): assumes optional_import returns None when CuPy is missing;
    # without this guard ``cp.ndarray`` would surface as an opaque AttributeError.
    if cp is None:
        raise ImportError("CuPy is required for GPU phase correlation but could not be imported")
    if not isinstance(array, cp.ndarray):
        raise TypeError(f"{name} must be a CuPy array, got {type(array)}")


def phase_correlation_gpu_only(
    image1: "cp.ndarray",  # type: ignore
    image2: "cp.ndarray",  # type: ignore
    *,
    window: bool = True,
    subpixel: bool = True,
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0,
) -> Tuple[float, float]:
    """
    Estimate the translation between two images by phase correlation on GPU.

    All array work (windowing, FFTs, normalization, peak search, centroid)
    runs on the device. Only the flat peak index and a few scalar sums are
    transferred to the host.

    The shift follows the ``fft(image1) * conj(fft(image2))`` convention: for a
    pure circular shift, ``image1 ~= roll(image2, (dy, dx), axis=(0, 1))``.

    Args:
        image1: First image, a 2D CuPy array.
        image2: Second image, a 2D CuPy array with the same shape as ``image1``.
        window: Apply a separable Hann window before the FFT.
        subpixel: Refine the integer peak with a centre-of-mass estimate.
        subpixel_radius: Half-width of the centroid neighbourhood around the
            peak. Values ``<= 0`` disable refinement.
        regularization_eps_multiplier: Multiplier on float32 machine epsilon.
            The result is added to the cross-power magnitude so that
            near-zero bins do not divide by zero.

    Returns:
        ``(dy, dx)`` as Python floats. The integer peak is mapped to the signed
        range ``(-n/2, n/2]`` per axis (n = axis length). Subpixel refinement
        adds a fractional offset no larger than ``subpixel_radius``.

    Raises:
        ImportError: If CuPy is unavailable.
        TypeError: If either input is not a CuPy array.
        ValueError: If the shapes differ, the images are not 2D, or they are empty.
    """
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")
    if image1.ndim != 2:
        raise ValueError(f"Images must be 2D, got shape {image1.shape}")
    if image1.size == 0:
        raise ValueError(f"Images must be non-empty, got shape {image1.shape}")

    h, w = image1.shape

    # Work in float32 and remove the DC component so the mean intensity does
    # not dominate the spectrum.
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)
    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    if window:
        # Separable Hann window: tapers the image borders, which the FFT
        # otherwise treats as periodic discontinuities.
        window_2d = cp.hanning(h).reshape(-1, 1) * cp.hanning(w).reshape(1, -1)
        img1 = img1 * window_2d
        img2 = img2 * window_2d

    # Normalized cross-power spectrum; eps keeps near-zero bins finite.
    cross_power = cp.fft.fft2(img1) * cp.conj(cp.fft.fft2(img2))
    eps = cp.finfo(cp.float32).eps * regularization_eps_multiplier
    cross_power_norm = cross_power / (cp.abs(cross_power) + eps)
    correlation = cp.real(cp.fft.ifft2(cross_power_norm))

    # Integer peak: one device->host transfer of the flat argmax index.
    y_peak, x_peak = divmod(int(cp.argmax(correlation)), w)

    # Indices past the midpoint correspond to negative shifts (FFT wrap-around).
    dy = float(y_peak if y_peak <= h // 2 else y_peak - h)
    dx = float(x_peak if x_peak <= w // 2 else x_peak - w)

    if subpixel and subpixel_radius > 0:
        # The correlation surface is periodic, so gather the neighbourhood with
        # wrapped indices. A clipped slice drops half the neighbourhood for
        # peaks at index 0 or n-1 (small shifts around zero) and biases the
        # centroid. The radius is capped so the wrapped window never visits
        # the same row/column twice.
        ry = min(subpixel_radius, (h - 1) // 2)
        rx = min(subpixel_radius, (w - 1) // 2)
        off_y = cp.arange(-ry, ry + 1)
        off_x = cp.arange(-rx, rx + 1)
        rows = (y_peak + off_y) % h
        cols = (x_peak + off_x) % w
        # Negative side lobes are not mass: with signed weights the centroid
        # can land far outside the window when positive and negative parts
        # nearly cancel.
        region = cp.maximum(correlation[rows[:, None], cols[None, :]], 0)
        total_mass = float(cp.sum(region))
        if total_mass > 0:
            dy += float(cp.sum(off_y[:, None] * region)) / total_mass
            dx += float(cp.sum(off_x[None, :] * region)) / total_mass

    return dy, dx
# The stitching pipeline itself lives in the modular ``mist`` subpackage; its
# entry point is re-exported here so existing imports from this module keep
# working.
from .mist.mist_main import mist_compute_tile_positions

__all__ = [
    "mist_compute_tile_positions",
    "phase_correlation_gpu_only",
]