Coverage for openhcs/processing/backends/pos_gen/mist/mist_main.py: 5.1%

304 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Main MIST Implementation 

3 

4Full GPU-accelerated MIST implementation with zero CPU operations. 

5Orchestrates all MIST components for tile position computation. 

6""" 

7from __future__ import annotations 

8 

9import logging 

10from typing import TYPE_CHECKING, Tuple 

11 

12from openhcs.core.memory.decorators import cupy as cupy_func 

13from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs 

14from openhcs.core.utils import optional_import 

15 

16from .phase_correlation import phase_correlation_gpu_only, phase_correlation_nist_gpu 

17from .quality_metrics import ( 

18 compute_correlation_quality_gpu_aligned, 

19 compute_adaptive_quality_threshold, 

20 validate_translation_consistency, 

21 log_coordinate_transformation 

22) 

23from .position_reconstruction import build_mst_gpu, rebuild_positions_from_mst_gpu 

24 

25# For type checking only 

26if TYPE_CHECKING: 26 ↛ 27line 26 didn't jump to line 27 because the condition on line 26 was never true

27 import cupy as cp 

28 

29# Import CuPy as an optional dependency 

30cp = optional_import("cupy") 

31 

32logger = logging.getLogger(__name__) 

33 

34 

35def _convert_overlap_to_tile_coordinates( 

36 dy: float, dx: float, 

37 overlap_h: int, overlap_w: int, 

38 tile_h: int, tile_w: int, 

39 direction: str 

40) -> Tuple[float, float]: 

41 """ 

42 Convert overlap-region-relative displacements to tile-center coordinates. 

43 

44 Args: 

45 dy, dx: Phase correlation displacements in overlap region coordinates 

46 overlap_h, overlap_w: Overlap region dimensions 

47 tile_h, tile_w: Full tile dimensions 

48 direction: 'horizontal' or 'vertical' 

49 

50 Returns: 

51 (tile_dy, tile_dx): Displacements in tile-center coordinates 

52 """ 

53 if direction == 'horizontal': 

54 # For horizontal connections (left-right) 

55 # Expected displacement is approximately tile_w - overlap_w 

56 expected_dx = tile_w - overlap_w 

57 tile_dx = expected_dx + dx # Add phase correlation correction 

58 tile_dy = dy # Vertical should be minimal 

59 

60 elif direction == 'vertical': 

61 # For vertical connections (top-bottom) 

62 # Expected displacement is approximately tile_h - overlap_h 

63 expected_dy = tile_h - overlap_h 

64 tile_dy = expected_dy + dy # Add phase correlation correction 

65 tile_dx = dx # Horizontal should be minimal 

66 

67 else: 

68 raise ValueError(f"Invalid direction: {direction}. Must be 'horizontal' or 'vertical'") 

69 

70 return tile_dy, tile_dx 

71 

72 

73 

74 

75 

76def _validate_displacement_magnitude( 

77 tile_dx: float, tile_dy: float, 

78 expected_dx: float, expected_dy: float, 

79 direction: str, 

80 tolerance_factor: float = 2.0, 

81 tolerance_percent: float = 0.1 

82) -> bool: 

83 """ 

84 Validate that displacement magnitudes are reasonable. 

85 

86 Args: 

87 tile_dx, tile_dy: Computed tile-center displacements 

88 expected_dx, expected_dy: Expected displacements 

89 direction: 'horizontal' or 'vertical' 

90 tolerance_factor: How much deviation to allow 

91 

92 Returns: 

93 True if displacement is reasonable, False otherwise 

94 """ 

95 if direction == 'horizontal': 

96 # For horizontal connections, dx should be close to expected_dx 

97 dx_error = abs(tile_dx - expected_dx) 

98 max_allowed_error = tolerance_factor * expected_dx * tolerance_percent 

99 dx_valid = dx_error <= max_allowed_error 

100 

101 # dy should be small (minimal vertical drift relative to expected_dx, not expected_dy) 

102 max_allowed_dy = tolerance_factor * expected_dx * tolerance_percent 

103 dy_valid = abs(tile_dy) <= max_allowed_dy 

104 

105 return dx_valid and dy_valid 

106 

107 elif direction == 'vertical': 

108 # For vertical connections, dy should be close to expected_dy 

109 dy_error = abs(tile_dy - expected_dy) 

110 max_allowed_error = tolerance_factor * expected_dy * tolerance_percent 

111 dy_valid = dy_error <= max_allowed_error 

112 

113 # dx should be small (minimal horizontal drift relative to expected_dy, not expected_dx) 

114 max_allowed_dx = tolerance_factor * expected_dy * tolerance_percent 

115 dx_valid = abs(tile_dx) <= max_allowed_dx 

116 

117 return dy_valid and dx_valid 

118 

119 return False 

120 

121 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """
    Guard that ``array`` is an instance of ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label used in the error message.

    Raises:
        TypeError: If ``array`` is not a ``cp.ndarray``.
    """
    is_cupy = isinstance(array, cp.ndarray)
    if is_cupy:
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

126 

127 

def _correlate_overlap_regions(
    region_a: "cp.ndarray",  # type: ignore
    region_b: "cp.ndarray",  # type: ignore
    direction: str,
    *,
    use_nist_robustness: bool,
    n_peaks: int,
    use_nist_normalization: bool,
    subpixel: bool,
    subpixel_radius: int,
    regularization_eps_multiplier: float,
):
    """
    Phase-correlate two overlap regions and return ``(dy, dx, quality)``.

    With ``use_nist_robustness`` the NIST multi-peak routine returns the
    quality directly. Otherwise the simplified single-peak routine is used,
    and quality is scored separately on the shift-aligned regions.
    """
    if use_nist_robustness:
        dy, dx, quality = phase_correlation_nist_gpu(
            region_a, region_b,
            direction=direction,
            n_peaks=n_peaks,
            use_nist_normalization=use_nist_normalization,
        )
        return dy, dx, quality

    dy, dx = phase_correlation_gpu_only(
        region_a, region_b,  # Standardized: left/top region first
        subpixel=subpixel,
        subpixel_radius=subpixel_radius,
        regularization_eps_multiplier=regularization_eps_multiplier,
    )
    # The simplified path reports no quality; score the alignment after the shift.
    quality = compute_correlation_quality_gpu_aligned(region_a, region_b, dx, dy)
    return dy, dx, quality


def _global_optimization_gpu_only(
    positions: "cp.ndarray",  # type: ignore
    tile_grid: "cp.ndarray",  # type: ignore
    num_rows: int,
    num_cols: int,
    expected_dx: float,
    expected_dy: float,
    overlap_ratio: float,
    subpixel: bool,
    *,
    quality_threshold: float = 0.5,  # NIST Algorithm 15: ncc >= 0.5 for valid translations
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0,
    anchor_tile_index: int = 0,
    debug_connection_limit: int = 3,
    debug_vertical_limit: int = 6,
    displacement_tolerance_factor: float = 2.0,
    displacement_tolerance_percent: float = 0.3,
    consistency_threshold_percent: float = 0.5,
    max_connections_multiplier: int = 2,
    adaptive_base_threshold: float = 0.3,
    adaptive_percentile_threshold: float = 0.25,
    translation_tolerance_factor: float = 0.2,
    translation_min_quality: float = 0.3,
    magnitude_threshold_multiplier: float = 1e-6,
    peak_candidates_multiplier: int = 4,
    min_peak_distance: int = 5,
    use_nist_robustness: bool = True,  # NIST Algorithm 2: Enable multi-peak PCIAM with interpretation testing
    n_peaks: int = 2,  # NIST Algorithm 2: n=2 peaks tested (manually selected based on experimental testing)
    use_nist_normalization: bool = True,  # NIST Algorithm 3: Use fc/abs(fc) normalization instead of regularized approach
    # NIST Algorithm 9: Stage model parameters
    overlap_uncertainty_percent: float = 3.0,  # NIST default: 3% overlap uncertainty (pou)
    outlier_threshold_multiplier: float = 1.5,  # NIST Algorithm 16: 1.5 × IQR for outlier detection
) -> "cp.ndarray":  # type: ignore
    """
    GPU-only global optimization using a simplified MST approach.

    Processing steps:

    1. Phase-correlate the overlap strips of every horizontally and
       vertically adjacent tile pair.
    2. Convert each measured shift to a tile-to-tile displacement.
    3. Keep a pair as a graph edge when its quality is ``>= quality_threshold``
       and its displacement passes ``_validate_displacement_magnitude``.
    4. Pass the accepted edges to ``build_mst_gpu`` and rebuild positions from
       the resulting MST.

    Diagnostic output is printed to stdout throughout.

    Args:
        positions: Current tile positions. They are passed to
            ``rebuild_positions_from_mst_gpu`` and returned unchanged when no
            edge is accepted. The layout is defined by that helper and is not
            established here.
        tile_grid: Tiles indexed as ``tile_grid[row, col]``. Tile height and
            width are read from ``shape[2]`` and ``shape[3]``.
        num_rows, num_cols: Grid dimensions. The tile index is
            ``row * num_cols + col``.
        expected_dx, expected_dy: Expected horizontal / vertical tile spacing,
            used to validate displacements.
        overlap_ratio: Fraction of tile width/height correlated between
            neighbours.
        subpixel: Forwarded to the simplified phase-correlation path.
        quality_threshold: Minimum correlation quality for an edge.
        subpixel_radius, regularization_eps_multiplier: Forwarded to the
            simplified phase-correlation path.
        anchor_tile_index: Forwarded to ``rebuild_positions_from_mst_gpu``.
        debug_connection_limit: Limit on debug output for horizontal pairs and
            on the number of tiles whose stats are printed.
        debug_vertical_limit: Limit on debug output for vertical pairs.
        displacement_tolerance_factor, displacement_tolerance_percent:
            Forwarded to ``_validate_displacement_magnitude``.
        consistency_threshold_percent: Fraction of consistent translations
            below which a warning is printed.
        max_connections_multiplier: Edge buffer size is
            ``max(multiplier * num_tiles, number of candidate edges)``.
        adaptive_base_threshold, adaptive_percentile_threshold: Forwarded to
            ``compute_adaptive_quality_threshold``. The result is reported only.
        translation_tolerance_factor, translation_min_quality: Forwarded to
            ``validate_translation_consistency``. The result is reported only.
        use_nist_robustness, n_peaks, use_nist_normalization: Select and
            configure the NIST phase-correlation path.
        magnitude_threshold_multiplier, peak_candidates_multiplier,
        min_peak_distance, overlap_uncertainty_percent,
        outlier_threshold_multiplier: Accepted for interface compatibility
            but not used by this function.

    Returns:
        Positions rebuilt from the MST, or ``positions`` itself when no
        connection survives filtering.

    Raises:
        ValueError: If ``overlap_ratio`` yields an overlap strip narrower than
            one pixel along an axis that has neighbouring tiles.
    """
    H, W = tile_grid.shape[2], tile_grid.shape[3]
    num_tiles = num_rows * num_cols

    # Overlap strip sizes are the same for every tile pair, so compute them once.
    overlap_w = int(W * overlap_ratio)
    overlap_h = int(H * overlap_ratio)
    # A 0 px strip would make ``tile[:, -0:]`` select the whole tile while
    # ``tile[:, :0]`` is empty, so the two regions could never be correlated.
    if num_cols > 1 and overlap_w < 1:
        raise ValueError(
            f"overlap_ratio={overlap_ratio} yields a horizontal overlap of "
            f"{overlap_w} px for tile width {W}; at least 1 px is required"
        )
    if num_rows > 1 and overlap_h < 1:
        raise ValueError(
            f"overlap_ratio={overlap_ratio} yields a vertical overlap of "
            f"{overlap_h} px for tile height {H}; at least 1 px is required"
        )

    # Each tile links at most to its right and bottom neighbours, which gives
    # exactly this many candidate edges. Sizing only by the multiplier could
    # index past the end of the buffers when max_connections_multiplier < 2.
    candidate_edges = num_rows * (num_cols - 1) + (num_rows - 1) * num_cols
    max_connections = max(max_connections_multiplier * num_tiles, candidate_edges)
    connection_from = cp.full(max_connections, -1, dtype=cp.int32)
    connection_to = cp.full(max_connections, -1, dtype=cp.int32)
    connection_dx = cp.zeros(max_connections, dtype=cp.float32)
    connection_dy = cp.zeros(max_connections, dtype=cp.float32)
    connection_quality = cp.zeros(max_connections, dtype=cp.float32)

    conn_idx = 0  # number of accepted connections written so far

    # Debug bookkeeping for quality filtering
    total_correlations = 0
    passed_threshold = 0
    all_qualities = []

    correlate_kwargs = dict(
        use_nist_robustness=use_nist_robustness,
        n_peaks=n_peaks,
        use_nist_normalization=use_nist_normalization,
        subpixel=subpixel,
        subpixel_radius=subpixel_radius,
        regularization_eps_multiplier=regularization_eps_multiplier,
    )

    # Debug: Print expected displacements and coordinate validation
    print(f"🔥 EXPECTED DISPLACEMENTS: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f"🔥 OVERLAP RATIO: {overlap_ratio}, H={H}, W={W}")
    print("🔥 COORDINATE VALIDATION:")
    print(f" Expected tile spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f" Overlap regions: H*ratio={H*overlap_ratio:.1f}, W*ratio={W*overlap_ratio:.1f}")
    print(f" Actual overlap: H={H*overlap_ratio:.1f}, W={W*overlap_ratio:.1f} pixels")

    # Debug: per-tile intensity stats (e.g. to spot black tiles). Only the
    # tiles that are printed are reduced, instead of 3 GPU syncs for every tile.
    print(f"🔥 TILE STATS: First {debug_connection_limit} tiles - min/max/mean:")
    for i in range(num_tiles)[:debug_connection_limit]:
        tile = tile_grid[i // num_cols, i % num_cols]
        tmin, tmax, tmean = float(cp.min(tile)), float(cp.max(tile)), float(cp.mean(tile))
        print(f" Tile {i}: [{tmin:.1f}, {tmax:.1f}], mean={tmean:.1f}")

    # Build connections (GPU operations)
    for r in range(num_rows):
        for c in range(num_cols):
            tile_idx = r * num_cols + c
            current_tile = tile_grid[r, c]

            # Horizontal connection: right strip of this tile vs left strip of its right neighbour
            if c < num_cols - 1:
                right_idx = r * num_cols + (c + 1)
                right_tile = tile_grid[r, c + 1]
                left_region = current_tile[:, -overlap_w:]
                right_region = right_tile[:, :overlap_w]

                if conn_idx < debug_connection_limit:
                    print(f"🔥 HORIZONTAL OVERLAP {conn_idx}: tiles {tile_idx}->{right_idx}")
                    print(f" overlap_w={overlap_w}, W={W}")
                    # Avoid .shape access which can cause GPU sync issues
                    print(" Processing overlap regions (shapes not shown to avoid GPU sync)")

                dy, dx, quality = _correlate_overlap_regions(
                    left_region, right_region, 'horizontal', **correlate_kwargs
                )
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile-to-tile coordinates
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, overlap_h, overlap_w, H, W, 'horizontal'
                    )

                    if conn_idx < debug_connection_limit:
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'horizontal', (tile_idx, right_idx)
                        )

                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, float(expected_dx), 0.0, 'horizontal',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = right_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        if conn_idx < debug_connection_limit:
                            print(f"🔥 HORIZONTAL CONNECTION {conn_idx}: {tile_idx}->{right_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    elif conn_idx < debug_connection_limit:
                        print(f"🔥 REJECTED HORIZONTAL {tile_idx}->{right_idx}: displacement invalid")
                        print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                        print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                        # Tolerance formula of the horizontal check (positive expected_dx)
                        dx_error = abs(tile_dx - expected_dx)
                        max_allowed = displacement_tolerance_factor * expected_dx * displacement_tolerance_percent
                        print(f" dx_error={dx_error:.3f} vs max_allowed={max_allowed:.3f}")
                        print(f" abs(dy)={abs(tile_dy):.3f} vs max_allowed_dy={max_allowed:.3f}")

            # Vertical connection: bottom strip of this tile vs top strip of the tile below
            if r < num_rows - 1:
                bottom_idx = (r + 1) * num_cols + c
                bottom_tile = tile_grid[r + 1, c]
                top_region = current_tile[-overlap_h:, :]
                bottom_region = bottom_tile[:overlap_h, :]

                dy, dx, quality = _correlate_overlap_regions(
                    top_region, bottom_region, 'vertical', **correlate_kwargs
                )
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile-to-tile coordinates
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, overlap_h, overlap_w, H, W, 'vertical'
                    )

                    if conn_idx < debug_vertical_limit:
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'vertical', (tile_idx, bottom_idx)
                        )

                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, 0.0, float(expected_dy), 'vertical',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = bottom_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        if conn_idx < debug_vertical_limit:
                            print(f"🔥 VERTICAL CONNECTION {conn_idx}: {tile_idx}->{bottom_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    elif conn_idx < debug_vertical_limit:
                        print(f"🔥 REJECTED VERTICAL {tile_idx}->{bottom_idx}: displacement invalid")
                        print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                        print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                        # Tolerance formula of the vertical check (positive expected_dy)
                        dy_error = abs(tile_dy - expected_dy)
                        max_allowed = displacement_tolerance_factor * expected_dy * displacement_tolerance_percent
                        print(f" dy_error={dy_error:.3f} vs max_allowed={max_allowed:.3f}")
                        print(f" abs(dx)={abs(tile_dx):.3f} vs max_allowed_dx={max_allowed:.3f}")

    # Adaptive quality threshold
    if all_qualities:
        adaptive_threshold = compute_adaptive_quality_threshold(
            all_qualities, adaptive_base_threshold, adaptive_percentile_threshold
        )
        print(f"🔥 ADAPTIVE THRESHOLD: original={quality_threshold:.6f}, adaptive={adaptive_threshold:.6f}")
        if adaptive_threshold < quality_threshold:
            # NOTE: diagnostic only. The connections above were filtered with
            # quality_threshold and are not re-processed with the adaptive value.
            print("🔥 RE-FILTERING with adaptive threshold...")

    # Quality filtering summary
    print(f"🔥 QUALITY FILTERING: {passed_threshold}/{total_correlations} connections passed threshold {quality_threshold}")
    if all_qualities:
        quality_values = cp.array(all_qualities)  # built once instead of three times
        min_q = float(cp.min(quality_values))
        max_q = float(cp.max(quality_values))
        mean_q = float(cp.mean(quality_values))
        print(f"🔥 QUALITY RANGE: min={min_q:.6f}, max={max_q:.6f}, mean={mean_q:.6f}")

    if conn_idx == 0:
        # No edge survived filtering; keep the incoming positions.
        return positions

    # Trim the preallocated buffers to the accepted connections (GPU)
    connection_from = connection_from[:conn_idx]
    connection_to = connection_to[:conn_idx]
    connection_dx = connection_dx[:conn_idx]
    connection_dy = connection_dy[:conn_idx]
    connection_quality = connection_quality[:conn_idx]

    # Validate translation consistency (Plan 03): one device->host copy per array
    translations = list(zip(
        connection_dy.tolist(),
        connection_dx.tolist(),
        connection_quality.tolist(),
    ))
    expected_spacing = (float(expected_dx), float(expected_dy))
    valid_flags = validate_translation_consistency(
        translations, expected_spacing, translation_tolerance_factor, translation_min_quality
    )

    num_valid = sum(valid_flags)
    print(f"🔥 TRANSLATION VALIDATION: {num_valid}/{len(translations)} connections are consistent")

    if num_valid < len(translations) * consistency_threshold_percent:  # Less than threshold% valid
        print(f"🔥 WARNING: Low translation consistency ({num_valid}/{len(translations)})")
        print(f"🔥 Expected spacing: dx={expected_spacing[0]:.1f}, dy={expected_spacing[1]:.1f}")
        print("🔥 Consider adjusting overlap_ratio or quality thresholds")

    # Build MST using refactored GPU Borůvka's algorithm
    mst_edges = build_mst_gpu(
        connection_from, connection_to, connection_dx,
        connection_dy, connection_quality, num_tiles
    )

    # Rebuild positions using MST (GPU)
    return rebuild_positions_from_mst_gpu(
        positions, mst_edges, num_tiles, anchor_tile_index
    )

442 

443 

444@special_inputs("grid_dimensions") 

445@special_outputs("positions") 

446@cupy_func 

447def mist_compute_tile_positions( 

448 image_stack: "cp.ndarray", # type: ignore 

449 grid_dimensions: Tuple[int, int], 

450 *, 

451 # === Input Validation Parameters === 

452 method: str = "phase_correlation", 

453 fft_backend: str = "cupy", 

454 

455 # === Core Algorithm Parameters === 

456 normalize: bool = True, 

457 verbose: bool = False, 

458 overlap_ratio: float = 0.1, 

459 subpixel: bool = True, 

460 refinement_iterations: int = 10, 

461 global_optimization: bool = True, 

462 anchor_tile_index: int = 0, 

463 

464 # === Refinement Tuning Parameters === 

465 refinement_damping: float = 0.5, 

466 correlation_weight_horizontal: float = 1.0, 

467 correlation_weight_vertical: float = 1.0, 

468 

469 # === Phase Correlation Parameters === 

470 subpixel_radius: int = 3, 

471 regularization_eps_multiplier: float = 1000.0, 

472 

473 # === MST Global Optimization Parameters === 

474 mst_quality_threshold: float = 0.5, # NIST Algorithm 15: ncc >= 0.5 for MST edge inclusion 

475 # NIST robustness parameters (Algorithms 2-5) 

476 use_nist_robustness: bool = True, # Enable full NIST PCIAM implementation 

477 n_peaks: int = 2, # NIST Algorithm 2: Test 2 peaks (experimentally determined) 

478 use_nist_normalization: bool = True, # NIST Algorithm 3: fc/abs(fc) normalization 

479 # Debugging and validation parameters 

480 debug_connection_limit: int = 3, 

481 debug_vertical_limit: int = 6, 

482 displacement_tolerance_factor: float = 2.0, 

483 displacement_tolerance_percent: float = 0.3, 

484 consistency_threshold_percent: float = 0.5, 

485 max_connections_multiplier: int = 2, 

486 # Quality metric tuning parameters 

487 adaptive_base_threshold: float = 0.3, 

488 adaptive_percentile_threshold: float = 0.25, 

489 translation_tolerance_factor: float = 0.2, 

490 translation_min_quality: float = 0.3, 

491 # Phase correlation tuning parameters 

492 magnitude_threshold_multiplier: float = 1e-6, 

493 peak_candidates_multiplier: int = 4, 

494 min_peak_distance: int = 5, 

495 **kwargs 

496) -> Tuple["cp.ndarray", "cp.ndarray"]: # type: ignore 

497 """ 

498 Full GPU MIST implementation with zero CPU operations. 

499 

500 Performs microscopy image stitching using phase correlation and iterative refinement. 

501 The algorithm has three phases: 

502 1. Initial positioning using sequential phase correlation 

503 2. Iterative refinement with constraint optimization 

504 3. Global optimization using minimum spanning tree (MST) 

505 

506 Args: 

507 image_stack: 3D tensor (Z, Y, X) of tiles to stitch 

508 grid_dimensions: (num_cols, num_rows) grid layout of tiles 

509 

510 === Input Validation Parameters === 

511 method: Correlation method - must be "phase_correlation" 

512 fft_backend: FFT backend - must be "cupy" for GPU acceleration 

513 

514 === Core Algorithm Parameters (NIST Algorithms 1-3) === 

515 normalize: Normalize each tile to [0,1] range using (tile-min)/(max-min). 

516 True = better correlation accuracy, handles varying illumination. 

517 False = faster but poor results with uneven lighting. 

518 Used in NIST Algorithm 3 (PCM) preprocessing. 

519 verbose: Enable detailed logging of algorithm progress and timing 

520 overlap_ratio: Expected overlap between adjacent tiles as fraction (0.0-1.0). 

521 Defines correlation region size: overlap_w = int(W * overlap_ratio). 

522 CRITICAL: Must match actual overlap in data or correlation fails. 

523 Higher (0.2-0.4) = more robust but slower. 

524 Lower (0.05-0.08) = faster but less accurate. 

525 Used in NIST Algorithm 10 (Compute Image Overlap). 

526 subpixel: Enable subpixel-accurate phase correlation for higher precision. 

527 True = center-of-mass interpolation around correlation peak. 

528 False = pixel-only accuracy (faster, less precise). 

529 Enhances NIST Algorithm 3 (PCM) with subpixel refinement. 

530 refinement_iterations: Number of iterative position refinement passes (0-50). 

531 Each iteration applies weighted position corrections. 

532 Higher = better convergence but much slower. 

533 0 = skip refinement (fastest, least accurate). 

534 Implements NIST Algorithm 21 (Bounded NCC Hill Climb). 

535 global_optimization: Enable MST-based global optimization phase. 

536 Uses minimum spanning tree to optimize tile positions globally. 

537 Significantly improves accuracy for large grids. 

538 Implements NIST Phase 3 (Image Composition). 

539 anchor_tile_index: Index of reference tile that remains fixed at origin (usually 0). 

540 All other positions calculated relative to this tile. 

541 Used in NIST MST position reconstruction. 

542 

543 === Refinement Tuning Parameters === 

544 refinement_damping: Controls how aggressively positions are updated (0.0-1.0). 

545 Formula: new_pos = (1-damping)*old_pos + damping*correction. 

546 Higher (0.7-0.9) = faster convergence but may overshoot. 

547 Lower (0.1-0.3) = more stable but slower convergence. 

548 1.0 = full correction (may be unstable), 0.0 = no updates. 

549 correlation_weight_horizontal: Weight for horizontal tile constraints (>0). 

550 Higher values prioritize horizontal alignment accuracy. 

551 Typical range: 0.5-2.0. 

552 correlation_weight_vertical: Weight for vertical tile constraints (>0). 

553 Higher values prioritize vertical alignment accuracy. 

554 Typical range: 0.5-2.0. 

555 

556 === Phase Correlation Parameters (NIST Algorithm 3) === 

557 subpixel_radius: Radius around correlation peak for center-of-mass calculation. 

558 Extracts (2*radius+1)² region around peak for interpolation. 

559 Higher (5-10) = more accurate subpixel positioning but slower. 

560 Lower (1-2) = faster but less precise, may cause drift. 

561 0 = pixel-only accuracy (fastest, least precise). 

562 Enhances NIST Algorithm 3 (PCM) with subpixel precision. 

563 regularization_eps_multiplier: Prevents division by zero in phase correlation. 

564 Formula: eps = machine_epsilon * multiplier. 

565 Higher (10000+) = more stable with noisy images. 

566 Lower (100-500) = higher precision but may fail. 

567 Too low (<10) = risk of numerical instability. 

568 Used in NIST Algorithm 3 cross-power normalization. 

569 

570 === MST Global Optimization Parameters (NIST Algorithms 8-21) === 

571 mst_quality_threshold: Minimum correlation quality for MST edge inclusion (0.0-1.0). 

572 NIST Algorithm 15: ncc >= 0.5 for valid translations. 

573 Formula: if correlation_peak < threshold: reject_connection. 

574 NIST default: 0.5 (stricter quality control). 

575 Higher = fewer connections, lower = includes weak correlations. 

576 Too high = MST may fail, too low = includes noise. 

577 use_nist_robustness: Enable NIST robust phase correlation (Algorithm 2). 

578 True = multi-peak PCIAM with interpretation testing. 

579 False = simplified single-peak method (faster). 

580 n_peaks: Number of correlation peaks to analyze (NIST Algorithm 4). 

581 NIST default: n=2 (manually selected based on experimental testing). 

582 Higher = more robust peak selection but slower processing. 

583 use_nist_normalization: Apply NIST normalization method (Algorithm 3). 

584 True = fc/abs(fc) normalization (NIST standard). 

585 False = OpenHCS regularization method. 

586 

587 displacement_tolerance_factor: Multiplier for expected displacement tolerance. 

588 NIST Algorithm 14: Stage model displacement validation. 

589 Formula: max_error = factor * expected_displacement * percent. 

590 Higher (3.0-5.0) = more permissive validation. 

591 Lower (1.0-1.5) = stricter validation. 

592 displacement_tolerance_percent: Percentage tolerance for displacement (0.0-1.0). 

593 NIST Algorithm 14: Displacement validation threshold. 

594 Formula: valid if |actual - expected| <= expected * percent. 

595 0.3 = ±30% deviation allowed from expected displacement. 

596 Higher = accepts larger deviations, lower = stricter. 

597 

598 debug_connection_limit: Max horizontal connections to log for debugging (0-10) 

599 debug_vertical_limit: Max vertical connections to log for debugging (0-10) 

600 consistency_threshold_percent: Translation consistency validation threshold (0.0-1.0). 

601 NIST Algorithm 17: Filter by repeatability. 

602 Formula: valid if |translation - median| <= median * threshold. 

603 0.5 = ±50% deviation from median allowed. 

604 Higher = more permissive, lower = stricter consistency. 

605 max_connections_multiplier: Maximum connections per tile in MST construction. 

606 Formula: max_connections = base_connections * multiplier. 

607 Prevents over-connected graphs that slow MST algorithms. 

608 2 = allow 2x normal connections, 1 = strict minimum. 

609 adaptive_base_threshold: Minimum quality threshold for adaptive quality metrics. 

610 NIST-inspired adaptive thresholding for challenging datasets. 

611 Formula: final_threshold = max(base_threshold, percentile_threshold). 

612 0.3 = minimum 30% correlation required regardless of distribution. 

613 Prevents threshold from becoming too permissive. 

614 adaptive_percentile_threshold: Percentile-based quality threshold (0.0-1.0). 

615 NIST Algorithm 9: Stage model validation approach. 

616 Formula: threshold = percentile(all_qualities, percentile * 100). 

617 0.25 = use 25th percentile of quality distribution. 

618 Lower = more permissive, higher = stricter selection. 

619 translation_tolerance_factor: Tolerance multiplier for translation validation. 

620 NIST Algorithm 14: Stage model displacement validation. 

621 Formula: max_error = expected_displacement * factor * percent. 

622 0.2 = allow 20% deviation from expected displacement. 

623 Higher = more permissive validation. 

624 translation_min_quality: Minimum correlation quality for translation acceptance. 

625 NIST Algorithm 15: Quality-based filtering threshold. 

626 Formula: accept if ncc >= min_quality. 

627 0.3 = require 30% normalized cross-correlation minimum. 

628 Higher = stricter quality, lower = more permissive. 

629 magnitude_threshold_multiplier: FFT magnitude threshold for numerical stability. 

630 NIST Algorithm 3: Cross-power spectrum normalization. 

631 Formula: threshold = mean(magnitude) * multiplier. 

632 1e-6 = very small threshold for numerical stability. 

633 Higher = more aggressive filtering of low-magnitude frequencies. 

634 peak_candidates_multiplier: Candidate peak search multiplier for robustness. 

635 NIST Algorithm 4: Multi-peak max search optimization. 

636 Formula: n_candidates = n_peaks * multiplier. 

637 4 = search 4x more candidates than needed for robust selection. 

638 Higher = more thorough search but slower processing. 

639 min_peak_distance: Minimum pixel distance between correlation peaks. 

640 NIST Algorithm 4: Prevents duplicate peak detection. 

641 Formula: reject if distance(peak1, peak2) < min_distance. 

642 5 = peaks must be ≥5 pixels apart to be considered distinct. 

643 Higher = fewer but more distinct peaks, lower = more peaks. 

644 

645 === NIST Mathematical Formulas === 

646 

647 Algorithm 3 (PCM): Peak Correlation Matrix 

648 F1 ← fft2D(I1), F2 ← fft2D(I2) 

649 FC ← F1 .* conj(F2) 

650 PCM ← ifft2D(FC ./ abs(FC)) 

651 

652 Algorithm 6 (NCC): Normalized Cross-Correlation 

653 I1 ← I1 - mean(I1), I2 ← I2 - mean(I2) 

654 ncc = (I1 · I2) / (|I1| * |I2|) 

655 

656 Algorithm 10 (Overlap): Image Overlap Computation 

657 overlap_percent = 100 - mu (where mu is mean translation) 

658 valid_range = [overlap ± overlap_uncertainty_percent] 

659 

660 Algorithm 16 (Outliers): Statistical Outlier Detection 

661 q1 = 25th percentile, q3 = 75th percentile 

662 IQR = q3 - q1 

663 outlier if: value < (q1 - 1.5*IQR) OR value > (q3 + 1.5*IQR) 

664 

665 Algorithm 21 (Hill Climb): Bounded Translation Refinement 

666 search_bounds = [current ± repeatability] 

667 ncc_surface[i,j] = ncc(extract_overlap(I1, j, i), extract_overlap(I2, -j, -i)) 

668 climb to local maximum within bounds 

669 

670 === NIST Performance Guidance === 

671 

672 Quality Threshold Tuning: 

673 - Start with NIST default: 0.5 (strict quality control) 

674 - Lower to 0.3-0.4 for noisy biological samples 

675 - Lower to 0.1-0.2 for very challenging datasets 

676 - Monitor MST edge count: need ≥(num_tiles-1) edges minimum 

677 

678 Peak Count Optimization: 

679 - NIST default: n=2 peaks (experimentally optimal) 

680 - Increase to 3-5 for highly repetitive patterns 

681 - Keep at 2 for most microscopy applications 

682 

683 Overlap Ratio Guidelines: 

684 - Must match actual image overlap precisely 

685 - Typical microscopy: 0.1-0.2 (10-20% overlap) 

686 - Higher overlap = more robust but slower processing 

687 - Lower overlap = faster but less reliable alignment 

688 

689 Subpixel Refinement: 

690 - Enable for publication-quality results 

691 - Radius 3-5 optimal for most applications 

692 - Disable for speed-critical applications 

693 

694 Expected Performance: 

695 - With NIST defaults: High accuracy, moderate speed 

696 - Quality threshold 0.5: Strict filtering, fewer edges 

697 - Multi-peak robustness: 2-3x slower but more reliable 

698 - Global optimization: Essential for large grids (>3x3) 

699 

700 Returns: 

701 Tuple of (image_stack, positions) where: 

702 - image_stack: Original input tiles (potentially normalized) 

703 - positions: (Z, 2) array of tile positions in (x, y) format 

704 Positions are centered around origin 

705 

706 Raises: 

707 ValueError: If input validation fails (wrong method, backend, or dimensions) 

708 TypeError: If image_stack is not a CuPy array 

709 """ 

710 _validate_cupy_array(image_stack, "image_stack") 

711 

712 if image_stack.ndim != 3: 

713 raise ValueError(f"Input must be a 3D tensor, got {image_stack.ndim}D") 

714 

715 if fft_backend != "cupy": 

716 raise ValueError(f"FFT backend must be 'cupy', got '{fft_backend}'") 

717 

718 if method != "phase_correlation": 

719 raise ValueError(f"Only 'phase_correlation' method is supported, got '{method}'") 

720 

721 num_cols, num_rows = grid_dimensions 

722 Z, H, W = image_stack.shape 

723 

724 # VERY FIRST THING - Debug output to confirm function is called 

725 print("🔥🔥🔥 MIST FUNCTION ENTRY POINT - FUNCTION IS DEFINITELY BEING CALLED! 🔥🔥🔥") 

726 print(f"🔥 Image stack shape: {image_stack.shape}") 

727 print(f"🔥 Grid dimensions: {grid_dimensions}") 

728 

729 # Debug: Log the actual overlap_ratio parameter being used 

730 print(f"🔥 MIST FUNCTION CALLED WITH overlap_ratio={overlap_ratio}") 

731 print(f"🔥 Expected: 0.1 (10% overlap), Actual: {overlap_ratio}") 

732 

733 if Z != num_rows * num_cols: 

734 raise ValueError( 

735 f"Number of tiles ({Z}) does not match grid size ({num_rows}x{num_cols}={num_rows*num_cols})" 

736 ) 

737 

738 # Normalize on GPU 

739 tiles = image_stack.astype(cp.float32) 

740 if normalize: 

741 for z in range(Z): 

742 tile = tiles[z] 

743 tile_min = cp.min(tile) 

744 tile_max = cp.max(tile) 

745 tile_range = tile_max - tile_min 

746 # Use GPU conditional to avoid division by zero 

747 tiles[z] = cp.where(tile_range > 0, (tile - tile_min) / tile_range, tile) 

748 

749 # Reshape to grid (GPU operation) 

750 tile_grid = tiles.reshape(num_rows, num_cols, H, W) 

751 

752 # Calculate expected spacing (GPU) 

753 expected_dy = cp.float32(H * (1.0 - overlap_ratio)) 

754 expected_dx = cp.float32(W * (1.0 - overlap_ratio)) 

755 

756 # Initialize positions on GPU 

757 positions = cp.zeros((Z, 2), dtype=cp.float32) 

758 

759 if verbose: 

760 logger.info(f"GPU MIST: {num_rows}x{num_cols} grid, spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}") 

761 

762 # Phase 1: Initial positioning (all GPU) 

763 for r in range(num_rows): 

764 for c in range(num_cols): 

765 tile_idx = r * num_cols + c 

766 

767 if tile_idx == anchor_tile_index: 

768 positions[tile_idx] = cp.array([0.0, 0.0]) 

769 continue 

770 

771 current_tile = tile_grid[r, c] 

772 

773 # Position from left neighbor (GPU operations) 

774 if c > 0: 

775 left_idx = r * num_cols + (c - 1) 

776 left_tile = tile_grid[r, c - 1] 

777 

778 # Extract overlap regions (GPU) 

779 overlap_w = cp.int32(W * overlap_ratio) 

780 left_region = left_tile[:, -overlap_w:] 

781 current_region = current_tile[:, :overlap_w] 

782 

783 # GPU phase correlation 

784 dy, dx = phase_correlation_gpu_only( 

785 left_region, current_region, 

786 subpixel=subpixel, 

787 subpixel_radius=subpixel_radius, 

788 regularization_eps_multiplier=regularization_eps_multiplier 

789 ) 

790 

791 # Convert overlap-region coordinates to tile-center coordinates 

792 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

793 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal' 

794 ) 

795 

796 # Update position (GPU) 

797 new_x = positions[left_idx, 0] + tile_dx 

798 new_y = positions[left_idx, 1] + tile_dy 

799 positions[tile_idx] = cp.array([new_x, new_y]) 

800 

801 elif r > 0: # Position from top neighbor 

802 top_idx = (r - 1) * num_cols + c 

803 top_tile = tile_grid[r - 1, c] 

804 

805 # Extract overlap regions (GPU) 

806 overlap_h = cp.int32(H * overlap_ratio) 

807 top_region = top_tile[-overlap_h:, :] 

808 current_region = current_tile[:overlap_h, :] 

809 

810 # GPU phase correlation 

811 dy, dx = phase_correlation_gpu_only( 

812 top_region, current_region, 

813 subpixel=subpixel, 

814 subpixel_radius=subpixel_radius, 

815 regularization_eps_multiplier=regularization_eps_multiplier 

816 ) 

817 

818 # Convert overlap-region coordinates to tile-center coordinates 

819 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

820 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical' 

821 ) 

822 

823 # Update position (GPU) 

824 new_x = positions[top_idx, 0] + tile_dx 

825 new_y = positions[top_idx, 1] + tile_dy 

826 positions[tile_idx] = cp.array([new_x, new_y]) 

827 

828 # Phase 2: Refinement iterations (all GPU) 

829 for iteration in range(refinement_iterations): 

830 if verbose: 

831 logger.info(f"GPU refinement iteration {iteration + 1}/{refinement_iterations}") 

832 

833 position_corrections = cp.zeros_like(positions) 

834 correction_weights = cp.zeros(Z, dtype=cp.float32) 

835 

836 # Horizontal constraints (GPU) 

837 for r in range(num_rows): 

838 for c in range(num_cols - 1): 

839 left_idx = r * num_cols + c 

840 right_idx = r * num_cols + (c + 1) 

841 

842 left_tile = tile_grid[r, c] 

843 right_tile = tile_grid[r, c + 1] 

844 

845 overlap_w = cp.int32(W * overlap_ratio) 

846 left_region = left_tile[:, -overlap_w:] # Right edge of left tile 

847 right_region = right_tile[:, :overlap_w] # Left edge of right tile 

848 

849 dy, dx = phase_correlation_gpu_only( 

850 left_region, right_region, # Standardized: left_region first 

851 subpixel=subpixel, 

852 subpixel_radius=subpixel_radius, 

853 regularization_eps_multiplier=regularization_eps_multiplier 

854 ) 

855 

856 # Convert overlap-region coordinates to tile-center coordinates 

857 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

858 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal' 

859 ) 

860 

861 # Expected position (GPU) 

862 expected_right = positions[left_idx] + cp.array([tile_dx, tile_dy]) 

863 

864 # Accumulate updates (GPU) 

865 position_corrections[right_idx] += expected_right * correlation_weight_horizontal 

866 correction_weights[right_idx] += correlation_weight_horizontal 

867 

868 # Vertical constraints (GPU) 

869 for r in range(num_rows - 1): 

870 for c in range(num_cols): 

871 top_idx = r * num_cols + c 

872 bottom_idx = (r + 1) * num_cols + c 

873 

874 top_tile = tile_grid[r, c] 

875 bottom_tile = tile_grid[r + 1, c] 

876 

877 overlap_h = cp.int32(H * overlap_ratio) 

878 top_region = top_tile[-overlap_h:, :] # Bottom edge of top tile 

879 bottom_region = bottom_tile[:overlap_h, :] # Top edge of bottom tile 

880 

881 dy, dx = phase_correlation_gpu_only( 

882 top_region, bottom_region, # Standardized: top_region first 

883 subpixel=subpixel, 

884 subpixel_radius=subpixel_radius, 

885 regularization_eps_multiplier=regularization_eps_multiplier 

886 ) 

887 

888 # Convert overlap-region coordinates to tile-center coordinates 

889 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

890 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical' 

891 ) 

892 

893 # Expected position (GPU) 

894 expected_bottom = positions[top_idx] + cp.array([tile_dx, tile_dy]) 

895 

896 # Accumulate updates (GPU) 

897 position_corrections[bottom_idx] += expected_bottom * correlation_weight_vertical 

898 correction_weights[bottom_idx] += correlation_weight_vertical 

899 

900 # Apply corrections with damping (all GPU) 

901 for tile_idx in range(Z): 

902 if correction_weights[tile_idx] > 0 and tile_idx != anchor_tile_index: 

903 averaged_correction = position_corrections[tile_idx] / correction_weights[tile_idx] 

904 positions[tile_idx] = ((1 - refinement_damping) * positions[tile_idx] + 

905 refinement_damping * averaged_correction) 

906 

907 # Phase 3: Global optimization MST (GPU operations) 

908 print(f"🔥 PHASE 3: global_optimization={global_optimization}") 

909 if global_optimization: 

910 print("🔥 STARTING MST GLOBAL OPTIMIZATION") 

911 positions = _global_optimization_gpu_only( 

912 positions, tile_grid, num_rows, num_cols, 

913 expected_dx, expected_dy, overlap_ratio, subpixel, 

914 

915 quality_threshold=mst_quality_threshold, 

916 subpixel_radius=subpixel_radius, 

917 regularization_eps_multiplier=regularization_eps_multiplier, 

918 anchor_tile_index=anchor_tile_index, 

919 debug_connection_limit=debug_connection_limit, 

920 debug_vertical_limit=debug_vertical_limit, 

921 displacement_tolerance_factor=displacement_tolerance_factor, 

922 displacement_tolerance_percent=displacement_tolerance_percent, 

923 consistency_threshold_percent=consistency_threshold_percent, 

924 max_connections_multiplier=max_connections_multiplier, 

925 adaptive_base_threshold=adaptive_base_threshold, 

926 adaptive_percentile_threshold=adaptive_percentile_threshold, 

927 translation_tolerance_factor=translation_tolerance_factor, 

928 translation_min_quality=translation_min_quality, 

929 magnitude_threshold_multiplier=magnitude_threshold_multiplier, 

930 peak_candidates_multiplier=peak_candidates_multiplier, 

931 min_peak_distance=min_peak_distance, 

932 use_nist_robustness=use_nist_robustness, 

933 n_peaks=n_peaks, 

934 use_nist_normalization=use_nist_normalization 

935 ) 

936 

937 # Center positions (GPU) 

938 mean_pos = cp.mean(positions, axis=0) 

939 positions = positions - mean_pos 

940 

941 print(f"🔥 MIST COMPLETE: Returning {positions.shape} positions") 

942 return tiles, positions