Coverage for openhcs/processing/backends/pos_gen/mist/phase_correlation.py: 5.0%

198 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Phase Correlation Functions for MIST Algorithm 

3 

4GPU-accelerated phase correlation with subpixel accuracy. 

5""" 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING, Tuple, List 

9 

10from openhcs.core.utils import optional_import 

11 

12# For type checking only 

13if TYPE_CHECKING: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true

14 import cupy as cp 

15 

16# Import CuPy as an optional dependency 

17cp = optional_import("cupy") 

18 

19 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless ``array`` is a ``cp.ndarray`` instance.

    Args:
        array: Object to check.
        name: Label used for the object in the error message.

    Raises:
        TypeError: If ``array`` is not an instance of ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

24 

25 

def constrained_hill_climbing(
    correlation_surface: "cp.ndarray",  # type: ignore
    initial_peak: Tuple[int, int],
    max_shift: int
) -> Tuple[float, float]:
    """
    Locate the strongest correlation value near an initial peak estimate.

    The search is limited to a window of +/- ``max_shift`` pixels around
    ``initial_peak`` (clipped to the surface). The window maximum is found
    with a single argmax; when that maximum does not lie on the window
    border it is refined to subpixel precision with a 3x3 centre of mass.

    Args:
        correlation_surface: 2D correlation surface (CuPy array)
        initial_peak: (y, x) coordinates of initial peak
        max_shift: Maximum allowed distance (in pixels, per axis) from
            the initial peak

    Returns:
        (y, x) position of the peak in surface coordinates. This is an
        absolute position, not a displacement. If the search window is
        empty, ``initial_peak`` is returned unchanged (as floats).

    Raises:
        TypeError: If ``correlation_surface`` is not a CuPy array.
        ValueError: If ``correlation_surface`` is not 2D.
    """
    _validate_cupy_array(correlation_surface, "correlation_surface")

    if correlation_surface.ndim != 2:
        raise ValueError(f"Correlation surface must be 2D, got {correlation_surface.ndim}D")

    h, w = correlation_surface.shape
    y_init, x_init = initial_peak

    # Work with plain Python ints so the window bounds are valid slice
    # indices even if the caller passes device or NumPy scalars.
    y_center, x_center = int(y_init), int(x_init)

    # Define search bounds. Every bound is clamped to >= 0: a negative
    # stop index would wrap around and select rows/columns from the far
    # end of the surface instead of producing an empty window.
    y_min = max(0, y_center - max_shift)
    y_max = max(0, min(h, y_center + max_shift + 1))
    x_min = max(0, x_center - max_shift)
    x_max = max(0, min(w, x_center + max_shift + 1))

    # Extract constrained region
    region = correlation_surface[y_min:y_max, x_min:x_max]

    if region.size == 0:
        return float(y_init), float(x_init)

    region_h, region_w = region.shape

    # Find peak in constrained region. The flat argmax index is brought to
    # the host once and split into (row, col) there, so the comparisons and
    # slices below never operate on device scalars.
    flat_peak = int(cp.argmax(region))
    y_peak_local, x_peak_local = divmod(flat_peak, region_w)

    # Subpixel refinement using center of mass in 3x3 neighborhood; only
    # possible when the full neighborhood lies inside the window.
    if (1 <= y_peak_local < region_h - 1 and
            1 <= x_peak_local < region_w - 1):

        neighborhood = region[y_peak_local - 1:y_peak_local + 2,
                              x_peak_local - 1:x_peak_local + 2]

        # NOTE: correlation values are used as weights unchanged; they are
        # not clipped, so negative neighbours pull the centre away from them.
        total_mass = cp.sum(neighborhood)
        if total_mass > 0:
            y_coords, x_coords = cp.mgrid[0:3, 0:3]
            y_com = float(cp.sum(y_coords * neighborhood) / total_mass)
            x_com = float(cp.sum(x_coords * neighborhood) / total_mass)

            # The neighborhood origin sits one pixel before the peak.
            y_refined = y_min + y_peak_local - 1 + y_com
            x_refined = x_min + x_peak_local - 1 + x_com

            return float(y_refined), float(x_refined)

    # Integer peak position in surface coordinates
    return float(y_min + y_peak_local), float(x_min + x_peak_local)

93 

94 

def phase_correlation_gpu_only(
    image1: "cp.ndarray",  # type: ignore
    image2: "cp.ndarray",  # type: ignore
    *,
    window: bool = True,
    subpixel: bool = True,
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0
) -> Tuple[float, float]:
    """
    Estimate the translation between two images by phase correlation.

    Both images are converted to float32, mean-subtracted, optionally
    multiplied by a separable Hann window, and correlated through the
    normalized cross-power spectrum. The location of the correlation
    maximum is converted to a signed shift; indices in the second half of
    an axis are treated as negative shifts.

    Args:
        image1: First image (2D CuPy array)
        image2: Second image (2D CuPy array, same shape as ``image1``)
        window: Apply Hann window
        subpixel: Refine the peak with a center of mass over a window of
            +/- ``subpixel_radius`` pixels (clipped at the array edges,
            without wrap-around)
        subpixel_radius: Radius for subpixel interpolation
        regularization_eps_multiplier: Multiplier applied to float32
            machine epsilon to form the lower bound of the term added to
            the spectrum magnitude before division

    Returns:
        (dy, dx) shift values

    Raises:
        TypeError: If either image is not a CuPy array.
        ValueError: If the shapes differ or the images are not 2D.
    """
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")

    # Fail early with a clear message instead of a tuple-unpacking error.
    if image1.ndim != 2:
        raise ValueError(f"Images must be 2D, got {image1.ndim}D")

    # Ensure float32 and remove DC component
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)

    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    h, w = img1.shape

    # Apply Hann window as the outer product of two 1D windows
    if window:
        win_y = cp.hanning(h).reshape(-1, 1)
        win_x = cp.hanning(w).reshape(1, -1)
        window_2d = win_y * win_x
        img1 = img1 * window_2d
        img2 = img2 * window_2d

    # FFT operations
    fft1 = cp.fft.fft2(img1)
    fft2 = cp.fft.fft2(img2)

    # Cross-power spectrum with configurable regularization
    cross_power = fft1 * cp.conj(fft2)
    magnitude = cp.abs(cross_power)

    # Relative threshold: never below eps, otherwise scaled to the mean
    # spectrum magnitude.
    eps = cp.finfo(cp.float32).eps * regularization_eps_multiplier
    magnitude_threshold = cp.maximum(eps, cp.mean(magnitude) * 1e-6)
    cross_power_norm = cross_power / (magnitude + magnitude_threshold)

    # Inverse FFT
    correlation = cp.real(cp.fft.ifft2(cross_power_norm))

    # Find peak. The flat index is transferred to the host once; all index
    # arithmetic below uses Python ints, which are valid slice bounds.
    flat_peak = int(cp.argmax(correlation))
    y_peak, x_peak = divmod(flat_peak, w)

    # Convert to signed shifts: peaks in the second half of an axis
    # represent negative shifts.
    dy = y_peak if y_peak < h // 2 else y_peak - h
    dx = x_peak if x_peak < w // 2 else x_peak - w

    # Subpixel refinement
    if subpixel:
        # Builtin min/max keep the bounds as host ints (the same approach as
        # constrained_hill_climbing) rather than creating device scalars.
        y_min = max(0, y_peak - subpixel_radius)
        y_max = min(h, y_peak + subpixel_radius + 1)
        x_min = max(0, x_peak - subpixel_radius)
        x_max = min(w, x_peak + subpixel_radius + 1)

        region = correlation[y_min:y_max, x_min:x_max]

        total_mass = cp.sum(region)
        if total_mass > 0:
            # Center of mass in region-local coordinates
            region_h, region_w = region.shape
            y_local, x_local = cp.mgrid[0:region_h, 0:region_w]

            y_com_local = float(cp.sum(y_local * region) / total_mass)
            x_com_local = float(cp.sum(x_local * region) / total_mass)

            # Convert local COM to global coordinates
            y_com = y_min + y_com_local
            x_com = x_min + x_com_local

            # Apply same FFT coordinate conversion for subpixel values
            dy = y_com if y_com < h // 2 else y_com - h
            dx = x_com if x_com < w // 2 else x_com - w

    return float(dy), float(dx)

199 

200 

def phase_correlation_nist_gpu(
    image1: "cp.ndarray",
    image2: "cp.ndarray",
    direction: str,
    n_peaks: int = 2,
    use_nist_normalization: bool = True
) -> Tuple[float, float, float]:
    """
    Phase correlation that scores several candidate peaks and keeps the best.

    The normalized cross-power spectrum of the mean-subtracted images is
    inverted to a correlation matrix. Up to ``n_peaks`` peaks are taken
    from it, each peak is expanded into candidate coordinates by
    ``_test_fft_interpretations``, and every candidate displacement is
    scored with ``_compute_interpretation_quality``. No window function
    is applied in this function.

    Args:
        image1, image2: Input images (2D CuPy arrays of equal shape)
        direction: 'horizontal' or 'vertical' for directional constraints.
            Any other value produces no candidates, in which case the
            result is (0.0, 0.0, -1.0).
        n_peaks: Number of peaks to test (NIST default: 2)
        use_nist_normalization: If True, divide the cross-power spectrum
            by ``abs(fc) + eps`` with a fixed ``eps``; if False, use the
            larger of ``eps`` and 1e-6 times the mean magnitude instead.

    Returns:
        (dy, dx, quality): Best displacement and correlation quality.
        ``quality`` is -1.0 when no candidate scored higher than -1.0.

    Raises:
        TypeError: If either image is not a CuPy array.
        ValueError: If the image shapes differ.
    """
    # Same input checks as phase_correlation_gpu_only.
    _validate_cupy_array(image1, "image1")
    _validate_cupy_array(image2, "image2")

    if image1.shape != image2.shape:
        raise ValueError(f"Images must have the same shape, got {image1.shape} and {image2.shape}")

    # Ensure float32 and remove DC component
    img1 = image1.astype(cp.float32)
    img2 = image2.astype(cp.float32)

    img1 = img1 - cp.mean(img1)
    img2 = img2 - cp.mean(img2)

    # FFT operations
    fft1 = cp.fft.fft2(img1)
    fft2 = cp.fft.fft2(img2)

    # Cross-power spectrum
    cross_power = fft1 * cp.conj(fft2)
    magnitude = cp.abs(cross_power)

    # Small epsilon prevents division by zero in both normalization modes.
    eps = cp.finfo(cp.float32).eps * 1000
    if use_nist_normalization:
        # NIST normalization: fc / abs(fc)
        regularizer = eps
    else:
        # Relative threshold tied to the mean spectrum magnitude
        regularizer = cp.maximum(eps, cp.mean(magnitude) * 1e-6)
    cross_power_norm = cross_power / (magnitude + regularizer)

    # Inverse FFT to get correlation matrix
    correlation = cp.real(cp.fft.ifft2(cross_power_norm))

    # Find multiple peaks
    peaks = _find_multiple_peaks_gpu(correlation, n_peaks)

    # Loop-invariant: the matrix shape does not change per candidate.
    h, w = correlation.shape

    best_quality = -1.0
    best_dy, best_dx = 0.0, 0.0

    # Test each peak with multiple interpretations
    for peak_y, peak_x, _peak_value in peaks:
        interpretations = _test_fft_interpretations(
            correlation, peak_y, peak_x, direction
        )

        for interp_y, interp_x in interpretations:
            # Convert to signed displacements: indices in the second half
            # of an axis are treated as negative shifts.
            dy = interp_y if interp_y < h // 2 else interp_y - h
            dx = interp_x if interp_x < w // 2 else interp_x - w

            # Compute quality for this interpretation
            quality = _compute_interpretation_quality(img1, img2, dy, dx)

            if quality > best_quality:
                best_quality = quality
                best_dy, best_dx = dy, dx

    return float(best_dy), float(best_dx), float(best_quality)

277 

278 

def _find_multiple_peaks_gpu(
    correlation_matrix: "cp.ndarray",
    n_peaks: int = 2,
    min_distance: int = 5
) -> List[Tuple[int, int, float]]:
    """
    Select the strongest correlation values subject to a minimum spacing.

    The ``4 * n_peaks`` largest entries are taken as candidates, sorted by
    value in descending order, and accepted greedily; a candidate is
    skipped when it lies closer than ``min_distance`` (Euclidean, in
    pixels, without wrap-around) to an already accepted peak.

    Args:
        correlation_matrix: 2D correlation matrix
        n_peaks: Maximum number of peaks to return
        min_distance: Minimum Euclidean distance between returned peaks

    Returns:
        List of (y, x, value) tuples, strongest first. May hold fewer than
        ``n_peaks`` entries when too many candidates are filtered out, and
        is empty when ``n_peaks <= 0`` or the matrix has no elements.
    """
    _, w = correlation_matrix.shape
    flat_corr = correlation_matrix.flatten()

    # Nothing to select. This also protects the slice below: with zero
    # candidates, ``[-0:]`` would select the whole array instead of nothing.
    if n_peaks <= 0 or flat_corr.size == 0:
        return []

    # Find top candidates (more than needed)
    n_candidates = min(n_peaks * 4, flat_corr.size)
    top_indices = cp.argpartition(flat_corr, -n_candidates)[-n_candidates:]

    # Move the candidates to the host in two bulk conversions rather than
    # indexing and converting one element at a time.
    flat_positions = top_indices.tolist()
    peak_values = flat_corr[top_indices].tolist()

    # flatten() is row-major, so a flat index splits into (row, col) with w.
    candidates = [
        (position // w, position % w, float(value))
        for position, value in zip(flat_positions, peak_values)
    ]
    candidates.sort(key=lambda p: p[2], reverse=True)

    # Apply minimum distance constraint. Squared distances are compared so
    # the check is exact integer arithmetic on the host; a non-positive
    # min_distance can never reject a candidate.
    min_distance_sq = min_distance ** 2
    selected_peaks = []
    for y, x, value in candidates:
        too_close = min_distance > 0 and any(
            (y - sel_y) ** 2 + (x - sel_x) ** 2 < min_distance_sq
            for sel_y, sel_x, _ in selected_peaks
        )
        if too_close:
            continue

        selected_peaks.append((y, x, value))
        if len(selected_peaks) >= n_peaks:
            break

    return selected_peaks

325 

326 

def _test_fft_interpretations(
    correlation_matrix: "cp.ndarray",
    peak_y: int,
    peak_x: int,
    direction: str
) -> List[Tuple[int, int]]:
    """
    Generate candidate peak coordinates with directional constraints.

    For 'horizontal' the candidates are the peak itself and the peak with
    its y coordinate negated; for 'vertical' the peak itself and the peak
    with its x coordinate negated. All coordinates are wrapped into the
    matrix with a modulo, and duplicates are removed while keeping order.

    Adding whole periods (``h`` or ``w``) before the modulo cannot create
    further candidates, so at most two coordinates are returned.
    NOTE(review): an earlier comment here referred to 16 interpretations
    (NIST Algorithm 5); confirm whether a wider candidate set was intended.

    Args:
        correlation_matrix: Phase correlation matrix (only its 2D shape
            is used)
        peak_y, peak_x: Integer peak coordinates
        direction: 'horizontal' or 'vertical' for directional constraints

    Returns:
        List of unique (y, x) interpretation coordinates; empty for any
        other ``direction`` value.
    """
    h, w = correlation_matrix.shape

    if direction == 'horizontal':
        # Left-right pairs: keep x, test both signs of y
        candidates = [
            (peak_y % h, peak_x % w),
            ((-peak_y) % h, peak_x % w),
        ]
    elif direction == 'vertical':
        # Up-down pairs: keep y, test both signs of x
        candidates = [
            (peak_y % h, peak_x % w),
            (peak_y % h, (-peak_x) % w),
        ]
    else:
        candidates = []

    # dict keys keep first-seen order, which removes duplicates stably.
    return list(dict.fromkeys(candidates))

375 

376 

def _compute_interpretation_quality(
    region1: "cp.ndarray",
    region2: "cp.ndarray",
    dy: float,
    dx: float
) -> float:
    """
    Score a candidate displacement by correlating the overlapping pixels.

    Each region has its own overall mean subtracted. The displacement is
    rounded to whole pixels, the overlapping windows of both regions are
    cut out, and the normalized dot product of the two windows is returned.

    Args:
        region1, region2: Input image regions
        dy, dx: Displacement to test

    Returns:
        Normalized cross-correlation of the overlap, or -1.0 when the
        overlap is empty or either window has zero norm.
    """
    centered1 = region1 - cp.mean(region1)
    centered2 = region2 - cp.mean(region2)

    row_shift = int(round(dy))
    col_shift = int(round(dx))
    rows, cols = centered1.shape

    # Overlap windows: region1 is cut from the shifted side, region2 from
    # the opposite side.
    rows1 = slice(max(0, row_shift), min(rows, rows + row_shift))
    cols1 = slice(max(0, col_shift), min(cols, cols + col_shift))
    rows2 = slice(max(0, -row_shift), min(rows, rows - row_shift))
    cols2 = slice(max(0, -col_shift), min(cols, cols - col_shift))

    patch1 = centered1[rows1, cols1]
    patch2 = centered2[rows2, cols2]

    if patch1.size == 0 or patch2.size == 0:
        return -1.0

    # Trim both windows to their common extent before flattening.
    common_rows = min(patch1.shape[0], patch2.shape[0])
    common_cols = min(patch1.shape[1], patch2.shape[1])
    vec1 = patch1[:common_rows, :common_cols].flatten()
    vec2 = patch2[:common_rows, :common_cols].flatten()

    inner = cp.dot(vec1, vec2)
    scale = cp.linalg.norm(vec1) * cp.linalg.norm(vec2)

    if scale == 0:
        return -1.0

    return float(inner / scale)