Coverage for openhcs/processing/backends/pos_gen/mist/phase_correlation.py: 5.0%
198 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Phase Correlation Functions for MIST Algorithm
4GPU-accelerated phase correlation with subpixel accuracy.
5"""
6from __future__ import annotations
8from typing import TYPE_CHECKING, Tuple, List
10from openhcs.core.utils import optional_import
12# For type checking only
13if TYPE_CHECKING: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true
14 import cupy as cp
16# Import CuPy as an optional dependency
17cp = optional_import("cupy")
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless *array* is an instance of ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label for the argument, used in the error message.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def constrained_hill_climbing(
    correlation_surface: "cp.ndarray",  # type: ignore
    initial_peak: Tuple[int, int],
    max_shift: int
) -> Tuple[float, float]:
    """
    Locate the correlation maximum inside a window around an initial peak.

    No iterative climbing is performed. The function takes the argmax of the
    square window ``initial_peak ± max_shift``, clipped to the surface. When
    that maximum is not on the window border, it refines the position with a
    centre of mass over the maximum's 3x3 neighbourhood. Negative values are
    clipped to zero for the centre of mass, so the refined position stays
    inside the neighbourhood.

    Args:
        correlation_surface: 2D correlation surface (CuPy array)
        initial_peak: (y, x) index of the initial peak in the surface
        max_shift: Maximum allowed offset (pixels, per axis) from initial_peak

    Returns:
        (y, x) position of the refined peak, in the index coordinates of
        ``correlation_surface``. This is a position, not a signed shift. If
        the clipped window is empty, ``initial_peak`` is returned as floats.

    Raises:
        TypeError: If ``correlation_surface`` is not a CuPy array.
        ValueError: If ``correlation_surface`` is not 2D.
    """
    _validate_cupy_array(correlation_surface, "correlation_surface")

    if correlation_surface.ndim != 2:
        raise ValueError(f"Correlation surface must be 2D, got {correlation_surface.ndim}D")

    h, w = correlation_surface.shape
    y_init, x_init = initial_peak

    # Search bounds. Upper bounds are clamped to be >= the lower bounds:
    # a negative slice end would otherwise wrap around and select pixels far
    # from the requested window.
    y_min = max(0, y_init - max_shift)
    y_max = max(y_min, min(h, y_init + max_shift + 1))
    x_min = max(0, x_init - max_shift)
    x_max = max(x_min, min(w, x_init + max_shift + 1))

    region = correlation_surface[y_min:y_max, x_min:x_max]
    if region.size == 0:
        return float(y_init), float(x_init)

    # Bring the argmax back as host ints so it can index/slice directly.
    y_local, x_local = divmod(int(cp.argmax(region)), region.shape[1])
    y_peak = y_min + y_local
    x_peak = x_min + x_local

    # Subpixel refinement needs a full 3x3 neighbourhood inside the window.
    interior = (1 <= y_local < region.shape[0] - 1 and
                1 <= x_local < region.shape[1] - 1)
    if interior:
        neighborhood = region[y_local - 1:y_local + 2, x_local - 1:x_local + 2]
        # Negative correlation acts as negative "mass" and can move the
        # centre of mass outside the window; treat it as zero weight instead.
        weights = cp.maximum(neighborhood, 0)
        total_mass = float(cp.sum(weights))
        if total_mass > 0:
            y_coords, x_coords = cp.mgrid[0:3, 0:3]
            y_com = float(cp.sum(y_coords * weights)) / total_mass
            x_com = float(cp.sum(x_coords * weights)) / total_mass
            # Neighbourhood origin is one pixel above/left of the peak.
            return float(y_peak - 1 + y_com), float(x_peak - 1 + x_com)

    return float(y_peak), float(x_peak)
def phase_correlation_gpu_only(
    image1: "cp.ndarray",  # type: ignore
    image2: "cp.ndarray",  # type: ignore
    *,
    window: bool = True,
    subpixel: bool = True,
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0
) -> Tuple[float, float]:
    """
    Estimate the translation between two images by phase correlation on GPU.

    The images are converted to float32 and mean-centred. They are optionally
    Hann-windowed, and their normalized cross-power spectrum
    ``F1 * conj(F2) / (|F1 * conj(F2)| + t)`` is inverted to a correlation
    surface whose argmax gives the integer shift.

    Args:
        image1: First image (CuPy array)
        image2: Second image (CuPy array, same shape as image1)
        window: Apply a separable Hann window before the FFT
        subpixel: Refine the peak with a centre of mass
        subpixel_radius: Half-width of the centre-of-mass window (pixels).
            It is clamped so the window never exceeds the surface size.
        regularization_eps_multiplier: Multiplier on float32 machine epsilon
            giving the floor of the normalization threshold ``t``. ``t`` is
            the larger of that floor and ``mean(|cross_power|) * 1e-6``.

    Returns:
        (dy, dx) signed shift, following the sign convention of the peak of
        ``ifft2(fft2(image1) * conj(fft2(image2)))``. For the integer peak
        index ``p`` on an axis of length ``n``, the shift is ``p`` when
        ``p < n // 2`` and ``p - n`` otherwise. The subpixel correction is
        added to that signed value.

    Raises:
        TypeError: If either image is not a CuPy array.
        ValueError: If the images differ in shape.
    """
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")

    # Float32 copies with the DC component removed.
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)
    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    if window:
        h, w = img1.shape
        window_2d = cp.hanning(h).reshape(-1, 1) * cp.hanning(w).reshape(1, -1)
        img1 = img1 * window_2d
        img2 = img2 * window_2d

    fft1 = cp.fft.fft2(img1)
    fft2 = cp.fft.fft2(img2)

    # Normalized cross-power spectrum with a relative regularization floor.
    cross_power = fft1 * cp.conj(fft2)
    magnitude = cp.abs(cross_power)
    eps = cp.finfo(cp.float32).eps * regularization_eps_multiplier
    magnitude_threshold = cp.maximum(eps, cp.mean(magnitude) * 1e-6)
    cross_power_norm = cross_power / (magnitude + magnitude_threshold)

    correlation = cp.real(cp.fft.ifft2(cross_power_norm))
    h, w = correlation.shape

    # Integer peak as host ints.
    y_peak, x_peak = divmod(int(cp.argmax(correlation)), w)

    # Peaks in the upper half of an axis represent negative shifts.
    dy = float(y_peak if y_peak < h // 2 else y_peak - h)
    dx = float(x_peak if x_peak < w // 2 else x_peak - w)

    if subpixel:
        # The correlation surface is periodic. Gather the neighbourhood with
        # wrap-around indexing so peaks at/near the array edge (e.g. a zero
        # shift at index 0) are not biased by a clipped window. Clamp the
        # radius so no row/column is sampled twice.
        ry = max(0, min(subpixel_radius, (h - 1) // 2))
        rx = max(0, min(subpixel_radius, (w - 1) // 2))
        offsets_y = cp.arange(-ry, ry + 1)
        offsets_x = cp.arange(-rx, rx + 1)
        rows = (y_peak + offsets_y) % h
        cols = (x_peak + offsets_x) % w
        region = correlation[rows[:, None], cols[None, :]]

        # Negative sidelobes would act as negative mass and could drag the
        # centre of mass arbitrarily far; give them zero weight.
        weights = cp.maximum(region, 0)
        total_mass = float(cp.sum(weights))
        if total_mass > 0:
            # Offsets are relative to the integer peak, so adding them to the
            # already-signed shift cannot flip its sign across n/2.
            dy += float(cp.sum(offsets_y[:, None] * weights)) / total_mass
            dx += float(cp.sum(offsets_x[None, :] * weights)) / total_mass

    return dy, dx
def _signed_shift_aliases(index: int, size: int) -> List[int]:
    """Return the signed displacements consistent with FFT peak *index* on one axis.

    The first element is the canonical value: ``index`` if it lies in the
    lower half of the axis, otherwise ``index - size``. When that value is
    non-zero, its periodic alias (shifted by ``size`` towards the opposite
    sign) follows. A zero shift gets no alias, because ``±size`` leaves no
    overlap.
    """
    canonical = index if index < size // 2 else index - size
    if canonical == 0:
        return [0]
    alias = canonical - size if canonical > 0 else canonical + size
    return [canonical, alias]


def phase_correlation_nist_gpu(
    image1: "cp.ndarray",
    image2: "cp.ndarray",
    direction: str,
    n_peaks: int = 2,
    use_nist_normalization: bool = True
) -> Tuple[float, float, float]:
    """
    GPU implementation of NIST MIST-style phase correlation with peak disambiguation.

    The phase-correlation surface is computed and its ``n_peaks`` strongest,
    mutually separated peaks are located. For every peak, a set of candidate
    displacements is scored by normalized cross-correlation of the
    overlapping image regions, and the best-scoring candidate is returned.

    Because the FFT is periodic, a peak at index ``p`` on an axis of length
    ``n`` is consistent with both displacement ``p`` and ``p - n``. For
    example, a shift of 90% of the tile shows up as a small negative shift.
    Both aliases are scored, so large displacements are not folded into the
    wrong sign.

    Args:
        image1, image2: Input images (CuPy arrays of identical shape)
        direction: 'horizontal' or 'vertical'. This selects which coordinate
            is additionally tested with a flipped sign. Any other value
            produces no candidates and the result ``(0.0, 0.0, -1.0)``.
        n_peaks: Number of peaks to test (NIST default: 2)
        use_nist_normalization: If True, normalize the cross-power spectrum
            as ``fc / (|fc| + eps)``. Otherwise ``eps`` is replaced by the
            relative threshold ``max(eps, mean(|fc|) * 1e-6)``. Neither mode
            applies a window.

    Returns:
        (dy, dx, quality): Best displacement and its correlation quality.
        The result is ``(0.0, 0.0, -1.0)`` if no candidate scored above -1.

    Raises:
        TypeError: If either image is not a CuPy array.
        ValueError: If the images differ in shape.
    """
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")

    # Ensure float32 and remove DC component.
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)
    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    fft1 = cp.fft.fft2(img1)
    fft2 = cp.fft.fft2(img2)
    cross_power = fft1 * cp.conj(fft2)
    magnitude = cp.abs(cross_power)
    eps = cp.finfo(cp.float32).eps * 1000.0

    if use_nist_normalization:
        # NIST normalization: fc / abs(fc), with eps guarding division by zero.
        cross_power_norm = cross_power / (magnitude + eps)
    else:
        # OpenHCS approach: relative regularization threshold.
        magnitude_threshold = cp.maximum(eps, cp.mean(magnitude) * 1e-6)
        cross_power_norm = cross_power / (magnitude + magnitude_threshold)

    correlation = cp.real(cp.fft.ifft2(cross_power_norm))
    h, w = correlation.shape

    best_quality = -1.0
    best_dy, best_dx = 0.0, 0.0
    tested = set()  # displacements already scored (peaks can share candidates)

    for peak_y, peak_x, _ in _find_multiple_peaks_gpu(correlation, n_peaks):
        interpretations = _test_fft_interpretations(correlation, peak_y, peak_x, direction)
        for interp_y, interp_x in interpretations:
            # Canonical signed value first, so ties keep the folded result.
            for dy in _signed_shift_aliases(interp_y, h):
                for dx in _signed_shift_aliases(interp_x, w):
                    if (dy, dx) in tested:
                        continue
                    tested.add((dy, dx))

                    quality = _compute_interpretation_quality(img1, img2, dy, dx)
                    if quality > best_quality:
                        best_quality = quality
                        best_dy, best_dx = dy, dx

    return float(best_dy), float(best_dx), float(best_quality)
def _find_multiple_peaks_gpu(
    correlation_matrix: "cp.ndarray",
    n_peaks: int = 2,
    min_distance: int = 5
) -> List[Tuple[int, int, float]]:
    """
    Find up to ``n_peaks`` mutually separated maxima of a 2D correlation surface.

    Peaks are picked greedily in descending value. After each pick, every
    pixel whose Euclidean distance to it (in array index space, no
    wrap-around) is below ``min_distance`` is suppressed, so the pixels
    surrounding a broad main peak cannot crowd out a secondary peak. This is
    the same greedy rule as filtering a value-sorted candidate list, but
    applied to every pixel rather than a small candidate pool.

    Args:
        correlation_matrix: 2D correlation surface (CuPy array)
        n_peaks: Maximum number of peaks to return
        min_distance: Minimum Euclidean distance between returned peaks

    Returns:
        List of ``(y, x, value)`` tuples in descending order of value. The
        list has fewer than ``n_peaks`` entries if the surface runs out of
        unsuppressed pixels, and is empty when ``n_peaks <= 0`` or the
        surface is empty.
    """
    if n_peaks <= 0 or correlation_matrix.size == 0:
        return []

    h, w = correlation_matrix.shape

    # float64 copy (astype copies) so suppressed pixels can be set to -inf
    # without modifying the caller's array.
    work = correlation_matrix.astype(cp.float64)
    yy, xx = cp.ogrid[0:h, 0:w]
    radius_sq = min_distance * min_distance

    peaks: List[Tuple[int, int, float]] = []
    while len(peaks) < n_peaks:
        y, x = divmod(int(cp.argmax(work)), w)
        if float(work[y, x]) == float("-inf"):
            break  # every remaining pixel has been suppressed
        peaks.append((y, x, float(correlation_matrix[y, x])))

        if min_distance > 0:
            work[(yy - y) ** 2 + (xx - x) ** 2 < radius_sq] = -cp.inf
        # Always suppress the pick itself so it cannot be selected twice.
        work[y, x] = -cp.inf

    return peaks
def _test_fft_interpretations(
    correlation_matrix: "cp.ndarray",
    peak_y: int,
    peak_x: int,
    direction: str
) -> List[Tuple[int, int]]:
    """
    Generate candidate peak positions with a directional sign flip.

    For a ``'horizontal'`` pair, the y coordinate is tested with both signs
    (``peak_y`` and ``-peak_y``, each modulo the height). For a
    ``'vertical'`` pair, the x coordinate is flipped instead. All
    coordinates are reduced modulo the surface shape, so every result lies in
    ``[0, h) x [0, w)``.

    Offsets of a whole period (``+h`` / ``+w``) are not generated separately
    because they vanish under the modulo. Turning positions into signed
    displacements, including their periodic aliases, is left to the caller.

    Args:
        correlation_matrix: Phase correlation matrix (only its shape is used)
        peak_y, peak_x: Peak coordinates
        direction: 'horizontal' or 'vertical' for directional constraints

    Returns:
        Unique ``(y, x)`` positions in generation order, with the unflipped
        position first. The list is empty for any other ``direction``.
    """
    h, w = correlation_matrix.shape
    base = (peak_y % h, peak_x % w)

    if direction == 'horizontal':
        flipped = ((-peak_y) % h, peak_x % w)
    elif direction == 'vertical':
        flipped = (peak_y % h, (-peak_x) % w)
    else:
        return []

    # dict.fromkeys de-duplicates while preserving order. The flip is a no-op
    # when the coordinate is 0 or exactly half the axis length.
    return list(dict.fromkeys((base, flipped)))
def _compute_interpretation_quality(
    region1: "cp.ndarray",
    region2: "cp.ndarray",
    dy: float,
    dx: float
) -> float:
    """
    Score a candidate displacement by normalized cross-correlation of the overlap.

    The displacement is rounded to whole pixels. The overlap compares
    ``region1[y + dy, x + dx]`` with ``region2[y, x]``. Each overlap is
    centred on its own mean (Pearson correlation), so a brightness offset
    between the overlap and the rest of the tile does not bias the score.
    Only array methods are used, so CuPy and NumPy inputs both work.

    Args:
        region1, region2: 2D input image regions
        dy, dx: Displacement to test

    Returns:
        Correlation coefficient in [-1, 1]. It is -1.0 when there is no
        overlap or when either overlap has zero variance.
    """
    shift_y, shift_x = int(round(dy)), int(round(dx))
    h, w = region1.shape

    # A shift of at least the full size leaves nothing to compare.
    if abs(shift_y) >= h or abs(shift_x) >= w:
        return -1.0

    # Overlap bounds in each region.
    y1_start, y1_end = max(0, shift_y), min(h, h + shift_y)
    x1_start, x1_end = max(0, shift_x), min(w, w + shift_x)
    y2_start, y2_end = max(0, -shift_y), min(h, h - shift_y)
    x2_start, x2_end = max(0, -shift_x), min(w, w - shift_x)

    r1_overlap = region1[y1_start:y1_end, x1_start:x1_end]
    r2_overlap = region2[y2_start:y2_end, x2_start:x2_end]

    if r1_overlap.size == 0 or r2_overlap.size == 0:
        return -1.0

    # Bounds come from region1's shape; trim in case region2 is smaller.
    min_h = min(r1_overlap.shape[0], r2_overlap.shape[0])
    min_w = min(r1_overlap.shape[1], r2_overlap.shape[1])
    r1_overlap = r1_overlap[:min_h, :min_w]
    r2_overlap = r2_overlap[:min_h, :min_w]

    a = r1_overlap - r1_overlap.mean()
    b = r2_overlap - r2_overlap.mean()

    numerator = float((a * b).sum())
    denominator = (float((a * a).sum()) * float((b * b).sum())) ** 0.5

    if denominator == 0:
        return -1.0

    return numerator / denominator