Coverage for openhcs/processing/backends/pos_gen/mist/mist_main.py: 5.3%

305 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Main MIST Implementation 

3 

4Full GPU-accelerated MIST implementation with zero CPU operations. 

5Orchestrates all MIST components for tile position computation. 

6""" 

7from __future__ import annotations 

8 

9import logging 

10from typing import TYPE_CHECKING, Tuple 

11 

12from openhcs.constants.constants import DEFAULT_PATCH_SIZE, DEFAULT_SEARCH_RADIUS 

13from openhcs.core.memory.decorators import cupy as cupy_func 

14from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs 

15from openhcs.core.utils import optional_import 

16 

17from .phase_correlation import phase_correlation_gpu_only, phase_correlation_nist_gpu 

18from .quality_metrics import ( 

19 compute_correlation_quality_gpu_aligned, 

20 compute_adaptive_quality_threshold, 

21 validate_translation_consistency, 

22 log_coordinate_transformation, 

23 debug_phase_correlation_matrix 

24) 

25from .position_reconstruction import build_mst_gpu, rebuild_positions_from_mst_gpu 

26 

27# For type checking only 

28if TYPE_CHECKING: 28 ↛ 29line 28 didn't jump to line 29 because the condition on line 28 was never true

29 import cupy as cp 

30 

31# Import CuPy as an optional dependency 

32cp = optional_import("cupy") 

33 

34logger = logging.getLogger(__name__) 

35 

36 

37def _convert_overlap_to_tile_coordinates( 

38 dy: float, dx: float, 

39 overlap_h: int, overlap_w: int, 

40 tile_h: int, tile_w: int, 

41 direction: str 

42) -> Tuple[float, float]: 

43 """ 

44 Convert overlap-region-relative displacements to tile-center coordinates. 

45 

46 Args: 

47 dy, dx: Phase correlation displacements in overlap region coordinates 

48 overlap_h, overlap_w: Overlap region dimensions 

49 tile_h, tile_w: Full tile dimensions 

50 direction: 'horizontal' or 'vertical' 

51 

52 Returns: 

53 (tile_dy, tile_dx): Displacements in tile-center coordinates 

54 """ 

55 if direction == 'horizontal': 

56 # For horizontal connections (left-right) 

57 # Expected displacement is approximately tile_w - overlap_w 

58 expected_dx = tile_w - overlap_w 

59 tile_dx = expected_dx + dx # Add phase correlation correction 

60 tile_dy = dy # Vertical should be minimal 

61 

62 elif direction == 'vertical': 

63 # For vertical connections (top-bottom) 

64 # Expected displacement is approximately tile_h - overlap_h 

65 expected_dy = tile_h - overlap_h 

66 tile_dy = expected_dy + dy # Add phase correlation correction 

67 tile_dx = dx # Horizontal should be minimal 

68 

69 else: 

70 raise ValueError(f"Invalid direction: {direction}. Must be 'horizontal' or 'vertical'") 

71 

72 return tile_dy, tile_dx 

73 

74 

75 

76 

77 

78def _validate_displacement_magnitude( 

79 tile_dx: float, tile_dy: float, 

80 expected_dx: float, expected_dy: float, 

81 direction: str, 

82 tolerance_factor: float = 2.0, 

83 tolerance_percent: float = 0.1 

84) -> bool: 

85 """ 

86 Validate that displacement magnitudes are reasonable. 

87 

88 Args: 

89 tile_dx, tile_dy: Computed tile-center displacements 

90 expected_dx, expected_dy: Expected displacements 

91 direction: 'horizontal' or 'vertical' 

92 tolerance_factor: How much deviation to allow 

93 

94 Returns: 

95 True if displacement is reasonable, False otherwise 

96 """ 

97 if direction == 'horizontal': 

98 # For horizontal connections, dx should be close to expected_dx 

99 dx_error = abs(tile_dx - expected_dx) 

100 max_allowed_error = tolerance_factor * expected_dx * tolerance_percent 

101 dx_valid = dx_error <= max_allowed_error 

102 

103 # dy should be small (minimal vertical drift relative to expected_dx, not expected_dy) 

104 max_allowed_dy = tolerance_factor * expected_dx * tolerance_percent 

105 dy_valid = abs(tile_dy) <= max_allowed_dy 

106 

107 return dx_valid and dy_valid 

108 

109 elif direction == 'vertical': 

110 # For vertical connections, dy should be close to expected_dy 

111 dy_error = abs(tile_dy - expected_dy) 

112 max_allowed_error = tolerance_factor * expected_dy * tolerance_percent 

113 dy_valid = dy_error <= max_allowed_error 

114 

115 # dx should be small (minimal horizontal drift relative to expected_dy, not expected_dx) 

116 max_allowed_dx = tolerance_factor * expected_dy * tolerance_percent 

117 dx_valid = abs(tile_dx) <= max_allowed_dx 

118 

119 return dy_valid and dx_valid 

120 

121 return False 

122 

123 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise TypeError unless ``array`` is an instance of ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label used for the object in the error message.

    Raises:
        TypeError: If ``array`` is not a ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

128 

129 

def _global_optimization_gpu_only(
    positions: "cp.ndarray",  # type: ignore
    tile_grid: "cp.ndarray",  # type: ignore
    num_rows: int,
    num_cols: int,
    expected_dx: float,
    expected_dy: float,
    overlap_ratio: float,
    subpixel: bool,
    *,

    quality_threshold: float = 0.5,  # NIST Algorithm 15: ncc >= 0.5 for valid translations
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0,
    anchor_tile_index: int = 0,
    debug_connection_limit: int = 3,
    debug_vertical_limit: int = 6,
    displacement_tolerance_factor: float = 2.0,
    displacement_tolerance_percent: float = 0.3,
    consistency_threshold_percent: float = 0.5,
    max_connections_multiplier: int = 2,
    adaptive_base_threshold: float = 0.3,
    adaptive_percentile_threshold: float = 0.25,
    translation_tolerance_factor: float = 0.2,
    translation_min_quality: float = 0.3,
    magnitude_threshold_multiplier: float = 1e-6,
    peak_candidates_multiplier: int = 4,
    min_peak_distance: int = 5,
    use_nist_robustness: bool = True,  # NIST Algorithm 2: Enable multi-peak PCIAM with interpretation testing
    n_peaks: int = 2,  # NIST Algorithm 2: n=2 peaks tested (manually selected based on experimental testing)
    use_nist_normalization: bool = True,  # NIST Algorithm 3: Use fc/abs(fc) normalization instead of regularized approach

    # NIST Algorithm 9: Stage model parameters
    overlap_uncertainty_percent: float = 3.0,  # NIST default: 3% overlap uncertainty (pou)
    outlier_threshold_multiplier: float = 1.5,  # NIST Algorithm 16: 1.5 × IQR for outlier detection
) -> "cp.ndarray":  # type: ignore
    """
    GPU-only global optimization using simplified MST approach.

    For every tile, the overlap strip shared with its right neighbour and
    with its bottom neighbour is phase-correlated. A measured translation is
    kept as a graph edge when its quality reaches ``quality_threshold`` and
    its magnitude passes ``_validate_displacement_magnitude``. The accepted
    edges are handed to ``build_mst_gpu`` and tile positions are rebuilt from
    the resulting tree with ``rebuild_positions_from_mst_gpu``.

    Args:
        positions: Current tile positions; returned unchanged when no edge
            is accepted.
        tile_grid: Tiles indexed as ``tile_grid[row, col]``; axes 2 and 3
            are read as tile height and width.
        num_rows, num_cols: Grid layout; tile index is ``row * num_cols + col``.
        expected_dx, expected_dy: Expected spacing between horizontal /
            vertical neighbours, used for displacement validation.
        overlap_ratio: Fraction of the tile width/height used as the
            overlap strip (``int32(W * ratio)`` / ``int32(H * ratio)``).
        subpixel, subpixel_radius, regularization_eps_multiplier: Forwarded
            to ``phase_correlation_gpu_only`` (non-NIST path only).
        quality_threshold: Minimum quality for an edge to be considered.
        anchor_tile_index: Forwarded to ``rebuild_positions_from_mst_gpu``.
        debug_connection_limit, debug_vertical_limit: While the number of
            accepted edges is below these limits, per-connection debug
            output is printed (horizontal / vertical respectively).
        displacement_tolerance_factor, displacement_tolerance_percent:
            Forwarded to ``_validate_displacement_magnitude``.
        consistency_threshold_percent: If fewer than this fraction of the
            accepted edges are flagged consistent, a warning is printed.
        max_connections_multiplier: Requested edge-buffer capacity as a
            multiple of the tile count; the buffer is never made smaller
            than the number of candidate edges in the grid.
        adaptive_base_threshold, adaptive_percentile_threshold: Forwarded
            to ``compute_adaptive_quality_threshold``; the result is only
            reported, not applied.
        translation_tolerance_factor, translation_min_quality: Forwarded to
            ``validate_translation_consistency``.
        use_nist_robustness, n_peaks, use_nist_normalization: Select and
            configure ``phase_correlation_nist_gpu``.
        magnitude_threshold_multiplier, peak_candidates_multiplier,
            min_peak_distance, overlap_uncertainty_percent,
            outlier_threshold_multiplier: Accepted for interface
            compatibility; not read by this function.

    Returns:
        Positions rebuilt from the MST when at least one edge was accepted,
        otherwise ``positions`` as passed in.
    """
    # NOTE(review): this function emits unconditional debug output via print();
    # the strings are kept as-is so existing console output does not change.
    H, W = tile_grid.shape[2], tile_grid.shape[3]
    num_tiles = num_rows * num_cols

    # Pre-allocate GPU arrays for connections. Each tile contributes at most
    # two edges (right, bottom). The buffer must hold every candidate edge of
    # the grid regardless of the requested multiplier, otherwise a multiplier
    # below 2 overflows the arrays on larger grids.
    num_candidate_edges = (
        max(num_rows, 0) * max(num_cols - 1, 0)
        + max(num_cols, 0) * max(num_rows - 1, 0)
    )
    max_connections = max(max_connections_multiplier * num_tiles, num_candidate_edges)
    connection_from = cp.full(max_connections, -1, dtype=cp.int32)
    connection_to = cp.full(max_connections, -1, dtype=cp.int32)
    connection_dx = cp.zeros(max_connections, dtype=cp.float32)
    connection_dy = cp.zeros(max_connections, dtype=cp.float32)
    connection_quality = cp.zeros(max_connections, dtype=cp.float32)

    conn_idx = 0  # number of accepted edges so far / next free buffer slot

    # Debug: Track quality filtering
    total_correlations = 0
    passed_threshold = 0
    all_qualities = []

    # Overlap strip sizes depend only on the tile size and ratio, so compute
    # them once instead of once per tile.
    overlap_w = cp.int32(W * overlap_ratio)
    overlap_h = cp.int32(H * overlap_ratio)

    # Debug: Print expected displacements and coordinate validation
    print(f"🔥 EXPECTED DISPLACEMENTS: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f"🔥 OVERLAP RATIO: {overlap_ratio}, H={H}, W={W}")
    print(f"🔥 COORDINATE VALIDATION:")
    print(f" Expected tile spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f" Overlap regions: H*ratio={H*overlap_ratio:.1f}, W*ratio={W*overlap_ratio:.1f}")
    print(f" Actual overlap: H={H*overlap_ratio:.1f}, W={W*overlap_ratio:.1f} pixels")

    # Debug: Check if images are black. Statistics are computed only for the
    # tiles that are actually printed; each float() forces a device-to-host
    # transfer, so doing this for every tile was wasted work.
    grid_cells = [(row, col) for row in range(num_rows) for col in range(num_cols)]
    print(f"🔥 TILE STATS: First {debug_connection_limit} tiles - min/max/mean:")
    for i, (stat_row, stat_col) in enumerate(grid_cells[:debug_connection_limit]):
        tile = tile_grid[stat_row, stat_col]
        tmin = float(cp.min(tile))
        tmax = float(cp.max(tile))
        tmean = float(cp.mean(tile))
        print(f" Tile {i}: [{tmin:.1f}, {tmax:.1f}], mean={tmean:.1f}")

    # Build connections (GPU operations)
    for r in range(num_rows):
        for c in range(num_cols):
            tile_idx = r * num_cols + c
            current_tile = tile_grid[r, c]

            # Horizontal connection
            if c < num_cols - 1:
                right_idx = r * num_cols + (c + 1)
                right_tile = tile_grid[r, c + 1]

                left_region = current_tile[:, -overlap_w:]  # Right edge of left tile
                right_region = right_tile[:, :overlap_w]  # Left edge of right tile

                # Debug: Check overlap region extraction
                if conn_idx < debug_connection_limit:
                    print(f"🔥 HORIZONTAL OVERLAP {conn_idx}: tiles {tile_idx}->{right_idx}")
                    print(f" overlap_w={int(overlap_w)}, W={W}")
                    print(f" Processing overlap regions (shapes not shown to avoid GPU sync)")

                if use_nist_robustness:
                    dy, dx, quality = phase_correlation_nist_gpu(
                        left_region, right_region,
                        direction='horizontal',
                        n_peaks=n_peaks,
                        use_nist_normalization=use_nist_normalization
                    )
                else:
                    dy, dx = phase_correlation_gpu_only(
                        left_region, right_region,  # Standardized: left_region first
                        subpixel=subpixel,
                        subpixel_radius=subpixel_radius,
                        regularization_eps_multiplier=regularization_eps_multiplier
                    )
                    # Compute quality after applying the shift
                    quality = compute_correlation_quality_gpu_aligned(left_region, right_region, dx, dy)

                # Debug: Track all quality values
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile coordinates
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, int(overlap_h), int(overlap_w), H, W, 'horizontal'
                    )

                    # Log coordinate transformation for debugging
                    if conn_idx < debug_connection_limit:  # Only log first few for brevity
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'horizontal', (tile_idx, right_idx)
                        )

                    # Validate displacement magnitude
                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, float(expected_dx), 0.0, 'horizontal',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = right_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        # Debug: Print first few connections
                        if conn_idx < debug_connection_limit:
                            print(f"🔥 HORIZONTAL CONNECTION {conn_idx}: {tile_idx}->{right_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    else:
                        # Debug: Log rejected connections
                        if conn_idx < debug_connection_limit:
                            print(f"🔥 REJECTED HORIZONTAL {tile_idx}->{right_idx}: displacement invalid")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                            # Show validation details
                            dx_error = abs(tile_dx - expected_dx)
                            max_allowed_error = displacement_tolerance_factor * expected_dx * displacement_tolerance_percent
                            max_allowed_dy = displacement_tolerance_factor * expected_dx * displacement_tolerance_percent
                            print(f" dx_error={dx_error:.3f} vs max_allowed={max_allowed_error:.3f}")
                            print(f" abs(dy)={abs(tile_dy):.3f} vs max_allowed_dy={max_allowed_dy:.3f}")

            # Vertical connection
            if r < num_rows - 1:
                bottom_idx = (r + 1) * num_cols + c
                bottom_tile = tile_grid[r + 1, c]

                top_region = current_tile[-overlap_h:, :]  # Bottom edge of top tile
                bottom_region = bottom_tile[:overlap_h, :]  # Top edge of bottom tile

                if use_nist_robustness:
                    dy, dx, quality = phase_correlation_nist_gpu(
                        top_region, bottom_region,
                        direction='vertical',
                        n_peaks=n_peaks,
                        use_nist_normalization=use_nist_normalization
                    )
                else:
                    dy, dx = phase_correlation_gpu_only(
                        top_region, bottom_region,  # Standardized: top_region first
                        subpixel=subpixel,
                        subpixel_radius=subpixel_radius,
                        regularization_eps_multiplier=regularization_eps_multiplier
                    )
                    # Compute quality after applying the shift
                    quality = compute_correlation_quality_gpu_aligned(top_region, bottom_region, dx, dy)

                # Debug: Track all quality values
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile coordinates
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, int(overlap_h), int(overlap_w), H, W, 'vertical'
                    )

                    # Log coordinate transformation for debugging
                    if conn_idx < debug_vertical_limit:  # Only log first few for brevity
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'vertical', (tile_idx, bottom_idx)
                        )

                    # Validate displacement magnitude
                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, 0.0, float(expected_dy), 'vertical',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = bottom_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        # Debug: Print first few connections (a higher limit
                        # than horizontal so vertical edges show up as well)
                        if conn_idx < debug_vertical_limit:
                            print(f"🔥 VERTICAL CONNECTION {conn_idx}: {tile_idx}->{bottom_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    else:
                        # Debug: Log rejected connections
                        if conn_idx < debug_vertical_limit:
                            print(f"🔥 REJECTED VERTICAL {tile_idx}->{bottom_idx}: displacement invalid")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                            # Show validation details
                            dy_error = abs(tile_dy - expected_dy)
                            max_allowed_error = displacement_tolerance_factor * expected_dy * displacement_tolerance_percent
                            max_allowed_dx = displacement_tolerance_factor * expected_dy * displacement_tolerance_percent
                            print(f" dy_error={dy_error:.3f} vs max_allowed={max_allowed_error:.3f}")
                            print(f" abs(dx)={abs(tile_dx):.3f} vs max_allowed_dx={max_allowed_dx:.3f}")

    # Compute adaptive quality threshold if we have quality data
    if all_qualities:
        adaptive_threshold = compute_adaptive_quality_threshold(
            all_qualities, adaptive_base_threshold, adaptive_percentile_threshold
        )
        print(f"🔥 ADAPTIVE THRESHOLD: original={quality_threshold:.6f}, adaptive={adaptive_threshold:.6f}")

        # The adaptive threshold is only reported: edges are NOT re-filtered
        # with it, the original quality_threshold stays in effect.
        if adaptive_threshold < quality_threshold:
            print(f"🔥 RE-FILTERING with adaptive threshold...")

    # Debug: Print quality filtering summary
    print(f"🔥 QUALITY FILTERING: {passed_threshold}/{total_correlations} connections passed threshold {quality_threshold}")
    if all_qualities:
        quality_values = cp.array(all_qualities)  # build the device array once
        min_q = float(cp.min(quality_values))
        max_q = float(cp.max(quality_values))
        mean_q = float(cp.mean(quality_values))
        print(f"🔥 QUALITY RANGE: min={min_q:.6f}, max={max_q:.6f}, mean={mean_q:.6f}")

    # Without any accepted edge there is nothing to optimize.
    if conn_idx <= 0:
        return positions

    # Trim arrays to actual size (GPU)
    connection_from = connection_from[:conn_idx]
    connection_to = connection_to[:conn_idx]
    connection_dx = connection_dx[:conn_idx]
    connection_dy = connection_dy[:conn_idx]
    connection_quality = connection_quality[:conn_idx]

    # Validate translation consistency (Plan 03). The accepted edges are
    # copied to the host in three bulk transfers rather than element by element.
    translations = list(zip(
        connection_dy.tolist(),
        connection_dx.tolist(),
        connection_quality.tolist(),
    ))

    # Validate against expected spacing
    expected_spacing = (float(expected_dx), float(expected_dy))
    valid_flags = validate_translation_consistency(
        translations, expected_spacing, translation_tolerance_factor, translation_min_quality
    )

    num_valid = sum(valid_flags)
    print(f"🔥 TRANSLATION VALIDATION: {num_valid}/{len(translations)} connections are consistent")

    if num_valid < len(translations) * consistency_threshold_percent:  # Less than threshold% valid
        print(f"🔥 WARNING: Low translation consistency ({num_valid}/{len(translations)})")
        print(f"🔥 Expected spacing: dx={expected_spacing[0]:.1f}, dy={expected_spacing[1]:.1f}")
        print(f"🔥 Consider adjusting overlap_ratio or quality thresholds")

    # Build MST using refactored GPU Borůvka's algorithm
    mst_edges = build_mst_gpu(
        connection_from, connection_to, connection_dx,
        connection_dy, connection_quality, num_tiles
    )

    # Rebuild positions using MST (GPU)
    new_positions = rebuild_positions_from_mst_gpu(
        positions, mst_edges, num_tiles, anchor_tile_index
    )

    return new_positions

444 

445 

446@special_inputs("grid_dimensions") 

447@special_outputs("positions") 

448@cupy_func 

449def mist_compute_tile_positions( 

450 image_stack: "cp.ndarray", # type: ignore 

451 grid_dimensions: Tuple[int, int], 

452 *, 

453 # === Input Validation Parameters === 

454 method: str = "phase_correlation", 

455 fft_backend: str = "cupy", 

456 

457 # === Core Algorithm Parameters === 

458 normalize: bool = True, 

459 verbose: bool = False, 

460 overlap_ratio: float = 0.1, 

461 subpixel: bool = True, 

462 refinement_iterations: int = 10, 

463 global_optimization: bool = True, 

464 anchor_tile_index: int = 0, 

465 

466 # === Refinement Tuning Parameters === 

467 refinement_damping: float = 0.5, 

468 correlation_weight_horizontal: float = 1.0, 

469 correlation_weight_vertical: float = 1.0, 

470 

471 # === Phase Correlation Parameters === 

472 subpixel_radius: int = 3, 

473 regularization_eps_multiplier: float = 1000.0, 

474 

475 # === MST Global Optimization Parameters === 

476 mst_quality_threshold: float = 0.5, # NIST Algorithm 15: ncc >= 0.5 for MST edge inclusion 

477 # NIST robustness parameters (Algorithms 2-5) 

478 use_nist_robustness: bool = True, # Enable full NIST PCIAM implementation 

479 n_peaks: int = 2, # NIST Algorithm 2: Test 2 peaks (experimentally determined) 

480 use_nist_normalization: bool = True, # NIST Algorithm 3: fc/abs(fc) normalization 

481 # Debugging and validation parameters 

482 debug_connection_limit: int = 3, 

483 debug_vertical_limit: int = 6, 

484 displacement_tolerance_factor: float = 2.0, 

485 displacement_tolerance_percent: float = 0.3, 

486 consistency_threshold_percent: float = 0.5, 

487 max_connections_multiplier: int = 2, 

488 # Quality metric tuning parameters 

489 adaptive_base_threshold: float = 0.3, 

490 adaptive_percentile_threshold: float = 0.25, 

491 translation_tolerance_factor: float = 0.2, 

492 translation_min_quality: float = 0.3, 

493 # Phase correlation tuning parameters 

494 magnitude_threshold_multiplier: float = 1e-6, 

495 peak_candidates_multiplier: int = 4, 

496 min_peak_distance: int = 5, 

497 **kwargs 

498) -> Tuple["cp.ndarray", "cp.ndarray"]: # type: ignore 

499 """ 

500 Full GPU MIST implementation with zero CPU operations. 

501 

502 Performs microscopy image stitching using phase correlation and iterative refinement. 

503 The algorithm has three phases: 

504 1. Initial positioning using sequential phase correlation 

505 2. Iterative refinement with constraint optimization 

506 3. Global optimization using minimum spanning tree (MST) 

507 

508 Args: 

509 image_stack: 3D tensor (Z, Y, X) of tiles to stitch 

510 grid_dimensions: (num_cols, num_rows) grid layout of tiles 

511 

512 === Input Validation Parameters === 

513 method: Correlation method - must be "phase_correlation" 

514 fft_backend: FFT backend - must be "cupy" for GPU acceleration 

515 

516 === Core Algorithm Parameters (NIST Algorithms 1-3) === 

517 normalize: Normalize each tile to [0,1] range using (tile-min)/(max-min). 

518 True = better correlation accuracy, handles varying illumination. 

519 False = faster but poor results with uneven lighting. 

520 Used in NIST Algorithm 3 (PCM) preprocessing. 

521 verbose: Enable detailed logging of algorithm progress and timing 

522 overlap_ratio: Expected overlap between adjacent tiles as fraction (0.0-1.0). 

523 Defines correlation region size: overlap_w = int(W * overlap_ratio). 

524 CRITICAL: Must match actual overlap in data or correlation fails. 

525 Higher (0.2-0.4) = more robust but slower. 

526 Lower (0.05-0.08) = faster but less accurate. 

527 Used in NIST Algorithm 10 (Compute Image Overlap). 

528 subpixel: Enable subpixel-accurate phase correlation for higher precision. 

529 True = center-of-mass interpolation around correlation peak. 

530 False = pixel-only accuracy (faster, less precise). 

531 Enhances NIST Algorithm 3 (PCM) with subpixel refinement. 

532 refinement_iterations: Number of iterative position refinement passes (0-50). 

533 Each iteration applies weighted position corrections. 

534 Higher = better convergence but much slower. 

535 0 = skip refinement (fastest, least accurate). 

536 Implements NIST Algorithm 21 (Bounded NCC Hill Climb). 

537 global_optimization: Enable MST-based global optimization phase. 

538 Uses minimum spanning tree to optimize tile positions globally. 

539 Significantly improves accuracy for large grids. 

540 Implements NIST Phase 3 (Image Composition). 

541 anchor_tile_index: Index of reference tile that remains fixed at origin (usually 0). 

542 All other positions calculated relative to this tile. 

543 Used in NIST MST position reconstruction. 

544 

545 === Refinement Tuning Parameters === 

546 refinement_damping: Controls how aggressively positions are updated (0.0-1.0). 

547 Formula: new_pos = (1-damping)*old_pos + damping*correction. 

548 Higher (0.7-0.9) = faster convergence but may overshoot. 

549 Lower (0.1-0.3) = more stable but slower convergence. 

550 1.0 = full correction (may be unstable), 0.0 = no updates. 

551 correlation_weight_horizontal: Weight for horizontal tile constraints (>0). 

552 Higher values prioritize horizontal alignment accuracy. 

553 Typical range: 0.5-2.0. 

554 correlation_weight_vertical: Weight for vertical tile constraints (>0). 

555 Higher values prioritize vertical alignment accuracy. 

556 Typical range: 0.5-2.0. 

557 

558 === Phase Correlation Parameters (NIST Algorithm 3) === 

559 subpixel_radius: Radius around correlation peak for center-of-mass calculation. 

560 Extracts (2*radius+1)² region around peak for interpolation. 

561 Higher (5-10) = more accurate subpixel positioning but slower. 

562 Lower (1-2) = faster but less precise, may cause drift. 

563 0 = pixel-only accuracy (fastest, least precise). 

564 Enhances NIST Algorithm 3 (PCM) with subpixel precision. 

565 regularization_eps_multiplier: Prevents division by zero in phase correlation. 

566 Formula: eps = machine_epsilon * multiplier. 

567 Higher (10000+) = more stable with noisy images. 

568 Lower (100-500) = higher precision but may fail. 

569 Too low (<10) = risk of numerical instability. 

570 Used in NIST Algorithm 3 cross-power normalization. 

571 

572 === MST Global Optimization Parameters (NIST Algorithms 8-21) === 

573 mst_quality_threshold: Minimum correlation quality for MST edge inclusion (0.0-1.0). 

574 NIST Algorithm 15: ncc >= 0.5 for valid translations. 

575 Formula: if correlation_peak < threshold: reject_connection. 

576 NIST default: 0.5 (stricter quality control). 

577 Higher = fewer connections, lower = includes weak correlations. 

578 Too high = MST may fail, too low = includes noise. 

579 use_nist_robustness: Enable NIST robust phase correlation (Algorithm 2). 

580 True = multi-peak PCIAM with interpretation testing. 

581 False = simplified single-peak method (faster). 

582 n_peaks: Number of correlation peaks to analyze (NIST Algorithm 4). 

583 NIST default: n=2 (manually selected based on experimental testing). 

584 Higher = more robust peak selection but slower processing. 

585 use_nist_normalization: Apply NIST normalization method (Algorithm 3). 

586 True = fc/abs(fc) normalization (NIST standard). 

587 False = OpenHCS regularization method. 

588 

589 displacement_tolerance_factor: Multiplier for expected displacement tolerance. 

590 NIST Algorithm 14: Stage model displacement validation. 

591 Formula: max_error = factor * expected_displacement * percent. 

592 Higher (3.0-5.0) = more permissive validation. 

593 Lower (1.0-1.5) = stricter validation. 

594 displacement_tolerance_percent: Percentage tolerance for displacement (0.0-1.0). 

595 NIST Algorithm 14: Displacement validation threshold. 

596 Formula: valid if |actual - expected| <= expected * percent. 

597 0.3 = ±30% deviation allowed from expected displacement. 

598 Higher = accepts larger deviations, lower = stricter. 

599 

600 debug_connection_limit: Max horizontal connections to log for debugging (0-10) 

601 debug_vertical_limit: Max vertical connections to log for debugging (0-10) 

602 consistency_threshold_percent: Translation consistency validation threshold (0.0-1.0). 

603 NIST Algorithm 17: Filter by repeatability. 

604 Formula: valid if |translation - median| <= median * threshold. 

605 0.5 = ±50% deviation from median allowed. 

606 Higher = more permissive, lower = stricter consistency. 

607 max_connections_multiplier: Maximum connections per tile in MST construction. 

608 Formula: max_connections = base_connections * multiplier. 

609 Prevents over-connected graphs that slow MST algorithms. 

610 2 = allow 2x normal connections, 1 = strict minimum. 

611 adaptive_base_threshold: Minimum quality threshold for adaptive quality metrics. 

612 NIST-inspired adaptive thresholding for challenging datasets. 

613 Formula: final_threshold = max(base_threshold, percentile_threshold). 

614 0.3 = minimum 30% correlation required regardless of distribution. 

615 Prevents threshold from becoming too permissive. 

616 adaptive_percentile_threshold: Percentile-based quality threshold (0.0-1.0). 

617 NIST Algorithm 9: Stage model validation approach. 

618 Formula: threshold = percentile(all_qualities, percentile * 100). 

619 0.25 = use 25th percentile of quality distribution. 

620 Lower = more permissive, higher = stricter selection. 

621 translation_tolerance_factor: Tolerance multiplier for translation validation. 

622 NIST Algorithm 14: Stage model displacement validation. 

623 Formula: max_error = expected_displacement * factor * percent. 

624 0.2 = allow 20% deviation from expected displacement. 

625 Higher = more permissive validation. 

626 translation_min_quality: Minimum correlation quality for translation acceptance. 

627 NIST Algorithm 15: Quality-based filtering threshold. 

628 Formula: accept if ncc >= min_quality. 

629 0.3 = require 30% normalized cross-correlation minimum. 

630 Higher = stricter quality, lower = more permissive. 

631 magnitude_threshold_multiplier: FFT magnitude threshold for numerical stability. 

632 NIST Algorithm 3: Cross-power spectrum normalization. 

633 Formula: threshold = mean(magnitude) * multiplier. 

634 1e-6 = very small threshold for numerical stability. 

635 Higher = more aggressive filtering of low-magnitude frequencies. 

636 peak_candidates_multiplier: Candidate peak search multiplier for robustness. 

637 NIST Algorithm 4: Multi-peak max search optimization. 

638 Formula: n_candidates = n_peaks * multiplier. 

639 4 = search 4x more candidates than needed for robust selection. 

640 Higher = more thorough search but slower processing. 

641 min_peak_distance: Minimum pixel distance between correlation peaks. 

642 NIST Algorithm 4: Prevents duplicate peak detection. 

643 Formula: reject if distance(peak1, peak2) < min_distance. 

644 5 = peaks must be ≥5 pixels apart to be considered distinct. 

645 Higher = fewer but more distinct peaks, lower = more peaks. 

646 

647 === NIST Mathematical Formulas === 

648 

649 Algorithm 3 (PCM): Peak Correlation Matrix 

650 F1 ← fft2D(I1), F2 ← fft2D(I2) 

651 FC ← F1 .* conj(F2) 

652 PCM ← ifft2D(FC ./ abs(FC)) 

653 

654 Algorithm 6 (NCC): Normalized Cross-Correlation 

655 I1 ← I1 - mean(I1), I2 ← I2 - mean(I2) 

656 ncc = (I1 · I2) / (|I1| * |I2|) 

657 

658 Algorithm 10 (Overlap): Image Overlap Computation 

659 overlap_percent = 100 - mu (where mu is mean translation) 

660 valid_range = [overlap ± overlap_uncertainty_percent] 

661 

662 Algorithm 16 (Outliers): Statistical Outlier Detection 

663 q1 = 25th percentile, q3 = 75th percentile 

664 IQR = q3 - q1 

665 outlier if: value < (q1 - 1.5*IQR) OR value > (q3 + 1.5*IQR) 

666 

667 Algorithm 21 (Hill Climb): Bounded Translation Refinement 

668 search_bounds = [current ± repeatability] 

669 ncc_surface[i,j] = ncc(extract_overlap(I1, j, i), extract_overlap(I2, -j, -i)) 

670 climb to local maximum within bounds 

671 

672 === NIST Performance Guidance === 

673 

674 Quality Threshold Tuning: 

675 - Start with NIST default: 0.5 (strict quality control) 

676 - Lower to 0.3-0.4 for noisy biological samples 

677 - Lower to 0.1-0.2 for very challenging datasets 

678 - Monitor MST edge count: need ≥(num_tiles-1) edges minimum 

679 

680 Peak Count Optimization: 

681 - NIST default: n=2 peaks (experimentally optimal) 

682 - Increase to 3-5 for highly repetitive patterns 

683 - Keep at 2 for most microscopy applications 

684 

685 Overlap Ratio Guidelines: 

686 - Must match actual image overlap precisely 

687 - Typical microscopy: 0.1-0.2 (10-20% overlap) 

688 - Higher overlap = more robust but slower processing 

689 - Lower overlap = faster but less reliable alignment 

690 

691 Subpixel Refinement: 

692 - Enable for publication-quality results 

693 - Radius 3-5 optimal for most applications 

694 - Disable for speed-critical applications 

695 

696 Expected Performance: 

697 - With NIST defaults: High accuracy, moderate speed 

698 - Quality threshold 0.5: Strict filtering, fewer edges 

699 - Multi-peak robustness: 2-3x slower but more reliable 

700 - Global optimization: Essential for large grids (>3x3) 

701 

702 Returns: 

703 Tuple of (image_stack, positions) where: 

704 - image_stack: Input tiles cast to float32 (min-max normalized per tile when normalize is enabled) 

705 - positions: (Z, 2) array of tile positions in (x, y) format 

706 Positions are centered around origin 

707 

708 Raises: 

709 ValueError: If input validation fails (wrong method, backend, or dimensions, or the tile count does not match the grid size) 

710 TypeError: If image_stack is not a CuPy array 

711 """ 

712 _validate_cupy_array(image_stack, "image_stack") 

713 

714 if image_stack.ndim != 3: 

715 raise ValueError(f"Input must be a 3D tensor, got {image_stack.ndim}D") 

716 

717 if fft_backend != "cupy": 

718 raise ValueError(f"FFT backend must be 'cupy', got '{fft_backend}'") 

719 

720 if method != "phase_correlation": 

721 raise ValueError(f"Only 'phase_correlation' method is supported, got '{method}'") 

722 

723 num_cols, num_rows = grid_dimensions 

724 Z, H, W = image_stack.shape 

725 

726 # Debug output to confirm the function is called (runs after the argument checks above) 

727 print("🔥🔥🔥 MIST FUNCTION ENTRY POINT - FUNCTION IS DEFINITELY BEING CALLED! 🔥🔥🔥") 

728 print(f"🔥 Image stack shape: {image_stack.shape}") 

729 print(f"🔥 Grid dimensions: {grid_dimensions}") 

730 

731 # Debug: Log the actual overlap_ratio parameter being used 

732 print(f"🔥 MIST FUNCTION CALLED WITH overlap_ratio={overlap_ratio}") 

733 print(f"🔥 Expected: 0.1 (10% overlap), Actual: {overlap_ratio}") 

734 

735 if Z != num_rows * num_cols: 

736 raise ValueError( 

737 f"Number of tiles ({Z}) does not match grid size ({num_rows}x{num_cols}={num_rows*num_cols})" 

738 ) 

739 

740 # Normalize on GPU 

741 tiles = image_stack.astype(cp.float32) 

742 if normalize: 

743 for z in range(Z): 

744 tile = tiles[z] 

745 tile_min = cp.min(tile) 

746 tile_max = cp.max(tile) 

747 tile_range = tile_max - tile_min 

748 # Use GPU conditional to avoid division by zero 

749 tiles[z] = cp.where(tile_range > 0, (tile - tile_min) / tile_range, tile) 

750 

751 # Reshape to grid (GPU operation) 

752 tile_grid = tiles.reshape(num_rows, num_cols, H, W) 

753 

754 # Calculate expected spacing (GPU) 

755 expected_dy = cp.float32(H * (1.0 - overlap_ratio)) 

756 expected_dx = cp.float32(W * (1.0 - overlap_ratio)) 

757 

758 # Initialize positions on GPU 

759 positions = cp.zeros((Z, 2), dtype=cp.float32) 

760 

761 if verbose: 

762 logger.info(f"GPU MIST: {num_rows}x{num_cols} grid, spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}") 

763 

764 # Phase 1: Initial positioning (all GPU) 

765 for r in range(num_rows): 

766 for c in range(num_cols): 

767 tile_idx = r * num_cols + c 

768 

769 if tile_idx == anchor_tile_index: 

770 positions[tile_idx] = cp.array([0.0, 0.0]) 

771 continue 

772 

773 current_tile = tile_grid[r, c] 

774 

775 # Position from left neighbor (GPU operations) 

776 if c > 0: 

777 left_idx = r * num_cols + (c - 1) 

778 left_tile = tile_grid[r, c - 1] 

779 

780 # Extract overlap regions (GPU) 

781 overlap_w = cp.int32(W * overlap_ratio) 

782 left_region = left_tile[:, -overlap_w:] 

783 current_region = current_tile[:, :overlap_w] 

784 

785 # GPU phase correlation 

786 dy, dx = phase_correlation_gpu_only( 

787 left_region, current_region, 

788 subpixel=subpixel, 

789 subpixel_radius=subpixel_radius, 

790 regularization_eps_multiplier=regularization_eps_multiplier 

791 ) 

792 

793 # Convert overlap-region coordinates to tile-center coordinates 

794 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

795 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal' 

796 ) 

797 

798 # Update position (GPU) 

799 new_x = positions[left_idx, 0] + tile_dx 

800 new_y = positions[left_idx, 1] + tile_dy 

801 positions[tile_idx] = cp.array([new_x, new_y]) 

802 

803 elif r > 0: # Position from top neighbor 

804 top_idx = (r - 1) * num_cols + c 

805 top_tile = tile_grid[r - 1, c] 

806 

807 # Extract overlap regions (GPU) 

808 overlap_h = cp.int32(H * overlap_ratio) 

809 top_region = top_tile[-overlap_h:, :] 

810 current_region = current_tile[:overlap_h, :] 

811 

812 # GPU phase correlation 

813 dy, dx = phase_correlation_gpu_only( 

814 top_region, current_region, 

815 subpixel=subpixel, 

816 subpixel_radius=subpixel_radius, 

817 regularization_eps_multiplier=regularization_eps_multiplier 

818 ) 

819 

820 # Convert overlap-region coordinates to tile-center coordinates 

821 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

822 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical' 

823 ) 

824 

825 # Update position (GPU) 

826 new_x = positions[top_idx, 0] + tile_dx 

827 new_y = positions[top_idx, 1] + tile_dy 

828 positions[tile_idx] = cp.array([new_x, new_y]) 

829 

830 # Phase 2: Refinement iterations (all GPU) 

831 for iteration in range(refinement_iterations): 

832 if verbose: 

833 logger.info(f"GPU refinement iteration {iteration + 1}/{refinement_iterations}") 

834 

835 position_corrections = cp.zeros_like(positions) 

836 correction_weights = cp.zeros(Z, dtype=cp.float32) 

837 

838 # Horizontal constraints (GPU) 

839 for r in range(num_rows): 

840 for c in range(num_cols - 1): 

841 left_idx = r * num_cols + c 

842 right_idx = r * num_cols + (c + 1) 

843 

844 left_tile = tile_grid[r, c] 

845 right_tile = tile_grid[r, c + 1] 

846 

847 overlap_w = cp.int32(W * overlap_ratio) 

848 left_region = left_tile[:, -overlap_w:] # Right edge of left tile 

849 right_region = right_tile[:, :overlap_w] # Left edge of right tile 

850 

851 dy, dx = phase_correlation_gpu_only( 

852 left_region, right_region, # Standardized: left_region first 

853 subpixel=subpixel, 

854 subpixel_radius=subpixel_radius, 

855 regularization_eps_multiplier=regularization_eps_multiplier 

856 ) 

857 

858 # Convert overlap-region coordinates to tile-center coordinates 

859 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

860 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal' 

861 ) 

862 

863 # Expected position (GPU) 

864 expected_right = positions[left_idx] + cp.array([tile_dx, tile_dy]) 

865 

866 # Accumulate updates (GPU) 

867 position_corrections[right_idx] += expected_right * correlation_weight_horizontal 

868 correction_weights[right_idx] += correlation_weight_horizontal 

869 

870 # Vertical constraints (GPU) 

871 for r in range(num_rows - 1): 

872 for c in range(num_cols): 

873 top_idx = r * num_cols + c 

874 bottom_idx = (r + 1) * num_cols + c 

875 

876 top_tile = tile_grid[r, c] 

877 bottom_tile = tile_grid[r + 1, c] 

878 

879 overlap_h = cp.int32(H * overlap_ratio) 

880 top_region = top_tile[-overlap_h:, :] # Bottom edge of top tile 

881 bottom_region = bottom_tile[:overlap_h, :] # Top edge of bottom tile 

882 

883 dy, dx = phase_correlation_gpu_only( 

884 top_region, bottom_region, # Standardized: top_region first 

885 subpixel=subpixel, 

886 subpixel_radius=subpixel_radius, 

887 regularization_eps_multiplier=regularization_eps_multiplier 

888 ) 

889 

890 # Convert overlap-region coordinates to tile-center coordinates 

891 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates( 

892 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical' 

893 ) 

894 

895 # Expected position (GPU) 

896 expected_bottom = positions[top_idx] + cp.array([tile_dx, tile_dy]) 

897 

898 # Accumulate updates (GPU) 

899 position_corrections[bottom_idx] += expected_bottom * correlation_weight_vertical 

900 correction_weights[bottom_idx] += correlation_weight_vertical 

901 

902 # Apply corrections with damping (all GPU) 

903 for tile_idx in range(Z): 

904 if correction_weights[tile_idx] > 0 and tile_idx != anchor_tile_index: 

905 averaged_correction = position_corrections[tile_idx] / correction_weights[tile_idx] 

906 positions[tile_idx] = ((1 - refinement_damping) * positions[tile_idx] + 

907 refinement_damping * averaged_correction) 

908 

909 # Phase 3: Global optimization MST (GPU operations) 

910 print(f"🔥 PHASE 3: global_optimization={global_optimization}") 

911 if global_optimization: 

912 print(f"🔥 STARTING MST GLOBAL OPTIMIZATION") 

913 positions = _global_optimization_gpu_only( 

914 positions, tile_grid, num_rows, num_cols, 

915 expected_dx, expected_dy, overlap_ratio, subpixel, 

916 

917 quality_threshold=mst_quality_threshold, 

918 subpixel_radius=subpixel_radius, 

919 regularization_eps_multiplier=regularization_eps_multiplier, 

920 anchor_tile_index=anchor_tile_index, 

921 debug_connection_limit=debug_connection_limit, 

922 debug_vertical_limit=debug_vertical_limit, 

923 displacement_tolerance_factor=displacement_tolerance_factor, 

924 displacement_tolerance_percent=displacement_tolerance_percent, 

925 consistency_threshold_percent=consistency_threshold_percent, 

926 max_connections_multiplier=max_connections_multiplier, 

927 adaptive_base_threshold=adaptive_base_threshold, 

928 adaptive_percentile_threshold=adaptive_percentile_threshold, 

929 translation_tolerance_factor=translation_tolerance_factor, 

930 translation_min_quality=translation_min_quality, 

931 magnitude_threshold_multiplier=magnitude_threshold_multiplier, 

932 peak_candidates_multiplier=peak_candidates_multiplier, 

933 min_peak_distance=min_peak_distance, 

934 use_nist_robustness=use_nist_robustness, 

935 n_peaks=n_peaks, 

936 use_nist_normalization=use_nist_normalization 

937 ) 

938 

939 # Center positions (GPU) 

940 mean_pos = cp.mean(positions, axis=0) 

941 positions = positions - mean_pos 

942 

943 print(f"🔥 MIST COMPLETE: Returning {positions.shape} positions") 

944 return tiles, positions