Coverage for openhcs/processing/backends/pos_gen/mist_processor_cupy.py: 17.1%
58 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2MIST (Microscopy Image Stitching Tool) GPU Implementation
4This module provides GPU-accelerated MIST implementation using CuPy.
5All legacy functions have been moved to the modular implementation in the mist/ subfolder.
6"""
7from __future__ import annotations
9import logging
10from typing import TYPE_CHECKING, Tuple
12from openhcs.core.utils import optional_import
14# For type checking only
15if TYPE_CHECKING: 15 ↛ 16line 15 didn't jump to line 16 because the condition on line 15 was never true
16 import cupy as cp
18# Import CuPy as an optional dependency
19cp = optional_import("cupy")
21logger = logging.getLogger(__name__)
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless *array* is a ``cp.ndarray`` instance.

    Args:
        array: Object to check.
        name: Label for the argument, used in the error message.

    Raises:
        TypeError: If *array* is not an instance of ``cp.ndarray``.
    """
    is_device_array = isinstance(array, cp.ndarray)
    if is_device_array:
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def phase_correlation_gpu_only(
    image1: "cp.ndarray",  # type: ignore
    image2: "cp.ndarray",  # type: ignore
    *,
    window: bool = True,
    subpixel: bool = True,
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0
) -> Tuple[float, float]:
    """
    Full GPU phase correlation with all operations on device.

    Computes the normalized cross-power spectrum
    ``F1 * conj(F2) / (|F1 * conj(F2)| + eps)`` and locates the maximum of
    its inverse FFT. The correlation surface is periodic. A peak index in
    the upper half of an axis (``index > size // 2``) is therefore
    reported as the negative shift ``index - size``.

    Args:
        image1: First image (2-D CuPy array).
        image2: Second image (2-D CuPy array, same shape as ``image1``).
        window: Multiply both images by a separable Hann window before the FFT.
        subpixel: Refine the integer peak with a centre-of-mass estimate.
        subpixel_radius: Half-width (in pixels) of the refinement neighbourhood.
            It is capped per axis so that no pixel is counted twice on small
            images. Values <= 0 disable the fractional correction.
        regularization_eps_multiplier: Multiplier applied to float32 machine
            epsilon. The result is added to the spectrum magnitude to avoid
            division by zero.

    Returns:
        (dy, dx) shift values as Python floats.

    Raises:
        TypeError: If either input is not a CuPy array.
        ValueError: If the images differ in shape or are not 2-D.
    """
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")
    if image1.ndim != 2:
        raise ValueError(f"Images must be 2-D, got shape {image1.shape}")

    # Work in float and remove the DC component so mean intensity does not
    # dominate the spectrum.
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)
    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    h, w = img1.shape

    # Optional separable Hann window to attenuate image borders before the FFT.
    if window:
        window_2d = cp.hanning(h).reshape(-1, 1) * cp.hanning(w).reshape(1, -1)
        img1 = img1 * window_2d
        img2 = img2 * window_2d

    # Normalized cross-power spectrum. eps keeps near-zero-magnitude
    # frequencies from producing inf/NaN.
    cross_power = cp.fft.fft2(img1) * cp.conj(cp.fft.fft2(img2))
    eps = cp.finfo(cp.float32).eps * regularization_eps_multiplier
    cross_power_norm = cross_power / (cp.abs(cross_power) + eps)
    correlation = cp.real(cp.fft.ifft2(cross_power_norm))

    # Integer peak location. Converting to Python ints gives plain
    # indices for the arithmetic below.
    y_idx, x_idx = cp.unravel_index(cp.argmax(correlation), correlation.shape)
    y_peak = int(y_idx)
    x_peak = int(x_idx)

    # Map circular peak indices to signed shifts.
    dy = float(y_peak if y_peak <= h // 2 else y_peak - h)
    dx = float(x_peak if x_peak <= w // 2 else x_peak - w)

    if subpixel:
        # Centre-of-mass refinement in a neighbourhood of the peak. The
        # correlation is periodic, so the neighbourhood wraps around the
        # array edges (modular indices) instead of being clipped. Clipping
        # made the window asymmetric near borders and biased the estimate.
        # The fractional offset is added to the already-wrapped integer
        # shift, so the result stays continuous around the peak.
        ry = max(0, min(subpixel_radius, (h - 1) // 2))
        rx = max(0, min(subpixel_radius, (w - 1) // 2))
        off_y = cp.arange(-ry, ry + 1)
        off_x = cp.arange(-rx, rx + 1)
        rows = (y_peak + off_y) % h
        cols = (x_peak + off_x) % w
        region = correlation[rows[:, None], cols[None, :]]

        # Clamp negative correlation values: signed weights can push the
        # centroid outside the neighbourhood.
        weights = cp.maximum(region, 0)
        total_mass = float(cp.sum(weights))
        if total_mass > 0:
            dy += float(cp.sum(off_y[:, None] * weights)) / total_mass
            dx += float(cp.sum(off_x[None, :] * weights)) / total_mass

    return dy, dx
# Re-export of the modular MIST entry point so callers importing it from this
# module keep working. The import sits after the definitions above
# (presumably deliberate, e.g. to avoid an import cycle — confirm before moving).
from .mist.mist_main import mist_compute_tile_positions  # noqa: E402

# Public API of this module.
__all__ = [
    "mist_compute_tile_positions",
    "phase_correlation_gpu_only",
]