Coverage for openhcs/processing/backends/pos_gen/mist_processor_cupy.py: 17.1%

58 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2MIST (Microscopy Image Stitching Tool) GPU Implementation 

3 

4This module provides GPU-accelerated MIST implementation using CuPy. 

5All legacy functions have been moved to the modular implementation in the mist/ subfolder. 

6""" 

7from __future__ import annotations 

8 

9import logging 

10from typing import TYPE_CHECKING, Tuple 

11 

12from openhcs.core.utils import optional_import 

13 

14# For type checking only 

15if TYPE_CHECKING: 15 ↛ 16line 15 didn't jump to line 16 because the condition on line 15 was never true

16 import cupy as cp 

17 

# CuPy is resolved through the project's optional-import helper instead of a hard
# ``import cupy`` so that importing this module does not itself require CuPy.
# NOTE(review): presumably ``optional_import`` returns a placeholder/None when CuPy
# is unavailable — confirm against ``openhcs.core.utils.optional_import``.
cp = optional_import("cupy")

logger = logging.getLogger(name=__name__)

22 

23 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless *array* is a ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label for the argument, used in the error message.

    Raises:
        TypeError: If *array* is not an instance of ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

28 

29 

def phase_correlation_gpu_only(
    image1: "cp.ndarray",  # type: ignore
    image2: "cp.ndarray",  # type: ignore
    *,
    window: bool = True,
    subpixel: bool = True,
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0,
) -> Tuple[float, float]:
    """
    Estimate the translation between two images by phase correlation on the GPU.

    Windowing, FFTs, spectrum normalisation and the peak search all run on the
    device; only the integer peak index and three scalar sums are copied to the host.

    The inverse FFT of the normalised cross-power spectrum
    ``F1 * conj(F2) / (|F1 * conj(F2)| + eps)`` peaks at the circular shift ``d``
    for which ``image1`` is approximately ``image2`` translated by ``d``.
    Peak indices past the midpoint of an axis are reported as negative shifts.

    Args:
        image1: First 2D image (CuPy array).
        image2: Second 2D image (CuPy array), same shape as ``image1``.
        window: Multiply both images by a separable Hann window before the FFT.
        subpixel: Refine the integer peak with an intensity-weighted centroid.
        subpixel_radius: Half-width, in pixels, of the centroid neighbourhood.
            It is clipped per axis so the periodic neighbourhood never overlaps
            itself; a negative value disables refinement.
        regularization_eps_multiplier: Multiplier on float32 machine epsilon that
            is added to the spectrum magnitude to avoid division by zero.

    Returns:
        ``(dy, dx)`` shift as Python floats.

    Raises:
        TypeError: If either input is not a CuPy array.
        ValueError: If the images differ in shape, are not 2D, or are empty.
    """
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")
    if image1.ndim != 2:
        raise ValueError(f"Images must be 2D, got shape {image1.shape}")
    if image1.size == 0:
        raise ValueError(f"Images must be non-empty, got shape {image1.shape}")

    h, w = image1.shape

    # Work in float32 and remove the DC component so the mean does not dominate
    # the spectrum.
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)
    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    if window:
        # cp.hanning returns float64; cast so the multiply keeps the images (and
        # therefore the FFTs) in single precision.
        window_2d = cp.outer(cp.hanning(h), cp.hanning(w)).astype(cp.float32)
        img1 = img1 * window_2d
        img2 = img2 * window_2d

    fft1 = cp.fft.fft2(img1)
    fft2 = cp.fft.fft2(img2)

    # Normalised cross-power spectrum; eps keeps near-zero magnitudes finite.
    cross_power = fft1 * cp.conj(fft2)
    eps = cp.finfo(cp.float32).eps * regularization_eps_multiplier
    cross_power_norm = cross_power / (cp.abs(cross_power) + eps)

    correlation = cp.real(cp.fft.ifft2(cross_power_norm))

    # argmax indexes the C-order flattened array, so divmod by the width yields
    # (row, col). This int() is the device->host sync for the integer peak.
    y_peak, x_peak = divmod(int(cp.argmax(correlation)), w)

    # Indices past the midpoint correspond to negative (wrapped) shifts.
    dy = float(y_peak if y_peak <= h // 2 else y_peak - h)
    dx = float(x_peak if x_peak <= w // 2 else x_peak - w)

    if subpixel and subpixel_radius >= 0:
        # The correlation surface is periodic: a peak on the border (e.g. a
        # near-zero shift at index 0) has its neighbours on the opposite edge.
        # Sample the neighbourhood with modular indexing, and clip the radius so
        # no pixel is sampled twice on small axes.
        ry = min(subpixel_radius, (h - 1) // 2)
        rx = min(subpixel_radius, (w - 1) // 2)
        y_off = cp.arange(-ry, ry + 1)
        x_off = cp.arange(-rx, rx + 1)
        region = correlation[((y_peak + y_off) % h)[:, None], ((x_peak + x_off) % w)[None, :]]

        # Negative side lobes would make the centroid ill-defined (the total mass
        # can approach zero), so weight by the positive part only.
        weights = cp.maximum(region, 0)

        # One transfer for all three moments, taken relative to the integer peak.
        total_mass, y_moment, x_moment = cp.stack(
            (
                weights.sum(),
                (y_off[:, None] * weights).sum(),
                (x_off[None, :] * weights).sum(),
            )
        ).tolist()
        if total_mass > 0:
            dy += y_moment / total_mass
            dx += x_moment / total_mass

    return dy, dx

121 

122 

# The MIST pipeline itself lives in the modular ``mist`` subpackage; it is
# re-exported here so existing imports of this module keep working.
from .mist.mist_main import mist_compute_tile_positions

__all__ = [
    "mist_compute_tile_positions",
    "phase_correlation_gpu_only",
]