Coverage for openhcs/processing/backends/pos_gen/mist_processor_cupy.py: 18.3%

59 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2MIST (Microscopy Image Stitching Tool) GPU Implementation 

3 

4This module provides GPU-accelerated MIST implementation using CuPy. 

5All legacy functions have been moved to the modular implementation in the mist/ subfolder. 

6""" 

7from __future__ import annotations 

8 

9import logging 

10from typing import TYPE_CHECKING, Tuple 

11 

12from openhcs.constants.constants import DEFAULT_PATCH_SIZE, DEFAULT_SEARCH_RADIUS 

13from openhcs.core.utils import optional_import 

14 

15# For type checking only 

16if TYPE_CHECKING: 16 ↛ 17line 16 didn't jump to line 17 because the condition on line 16 was never true

17 import cupy as cp 

18 

19# Import CuPy as an optional dependency 

20cp = optional_import("cupy") 

21 

22logger = logging.getLogger(__name__) 

23 

24 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless *array* is an instance of ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label for the argument, used in the error message.

    Raises:
        TypeError: If *array* is not a ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

29 

30 

def phase_correlation_gpu_only(
    image1: "cp.ndarray",  # type: ignore
    image2: "cp.ndarray",  # type: ignore
    *,
    window: bool = True,
    subpixel: bool = True,
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0
) -> Tuple[float, float]:
    """
    Full GPU phase correlation with all operations on device.

    The shift is taken from the peak of the inverse FFT of the normalized
    cross-power spectrum ``F1 * conj(F2) / (|F1 * conj(F2)| + eps)``.
    Peak indices past the midpoint of an axis are mapped to negative shifts.

    Args:
        image1: First image (2-D CuPy array)
        image2: Second image (2-D CuPy array, same shape as image1)
        window: Apply a separable 2-D Hann window before the FFT
        subpixel: Refine the integer peak with a center-of-mass estimate
        subpixel_radius: Half-width in pixels of the square neighborhood used
            for the center-of-mass refinement (clamped so the neighborhood
            never exceeds the image; values <= 0 disable refinement)
        regularization_eps_multiplier: Multiplier on float32 machine epsilon
            added to the cross-power magnitude to avoid division by zero

    Returns:
        (dy, dx) shift values as Python floats

    Raises:
        TypeError: If either image is not a CuPy array.
        ValueError: If the shapes differ or the images are not 2-D.
    """
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")
    if image1.ndim != 2:
        raise ValueError(f"Images must be 2-D, got shape {image1.shape}")

    h, w = image1.shape

    # Ensure float32 and remove DC component (all GPU operations)
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)
    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    # Apply Hann window (all GPU)
    if window:
        # cp.hanning returns float64; cast so multiplying does not silently
        # promote both images to float64.
        win_y = cp.hanning(h).astype(cp.float32).reshape(-1, 1)
        win_x = cp.hanning(w).astype(cp.float32).reshape(1, -1)
        window_2d = win_y * win_x
        img1 = img1 * window_2d
        img2 = img2 * window_2d

    # FFT operations (GPU)
    fft1 = cp.fft.fft2(img1)
    fft2 = cp.fft.fft2(img2)

    # Cross-power spectrum with configurable regularization (GPU)
    cross_power = fft1 * cp.conj(fft2)
    magnitude = cp.abs(cross_power)
    eps = cp.finfo(cp.float32).eps * regularization_eps_multiplier
    cross_power_norm = cross_power / (magnitude + eps)

    # Inverse FFT (GPU)
    correlation = cp.real(cp.fft.ifft2(cross_power_norm))

    # Integer peak location (int() syncs the two indices to the host)
    y_idx, x_idx = cp.unravel_index(cp.argmax(correlation), correlation.shape)
    y_peak = int(y_idx)
    x_peak = int(x_idx)

    # Convert to signed shifts: indices past the midpoint wrap to negative
    dy = float(y_peak if y_peak <= h // 2 else y_peak - h)
    dx = float(x_peak if x_peak <= w // 2 else x_peak - w)

    # Subpixel refinement (center of mass around the peak, on GPU)
    if subpixel and subpixel_radius > 0:
        # The correlation surface produced by ifft2 is periodic, so the
        # neighborhood must wrap around the array edges. Clamping it to the
        # array bounds makes the window one-sided whenever the peak lies near
        # index 0 (small shifts) and biases the center of mass.
        # Limit the radius so no index appears twice in the window.
        ry = min(subpixel_radius, (h - 1) // 2)
        rx = min(subpixel_radius, (w - 1) // 2)
        off_y = cp.arange(-ry, ry + 1)
        off_x = cp.arange(-rx, rx + 1)
        rows = (y_peak + off_y) % h
        cols = (x_peak + off_x) % w
        region = correlation[rows[:, None], cols[None, :]]

        total_mass = cp.sum(region)
        if total_mass > 0:
            # Center of mass expressed as an offset from the integer peak, so
            # it composes directly with the already-signed integer shift.
            dy += float(cp.sum(off_y[:, None] * region) / total_mass)
            dx += float(cp.sum(off_x[None, :] * region) / total_mass)

    return dy, dx

122 

123 

124# Import the modular MIST implementation 

125from .mist.mist_main import mist_compute_tile_positions 

126 

127# Re-export for backward compatibility 

128__all__ = ['mist_compute_tile_positions', 'phase_correlation_gpu_only']