Coverage for openhcs/processing/backends/pos_gen/mist/mist_main.py: 5.1%
304 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2Main MIST Implementation
4Full GPU-accelerated MIST implementation with zero CPU operations.
5Orchestrates all MIST components for tile position computation.
6"""
7from __future__ import annotations
9import logging
10from typing import TYPE_CHECKING, Tuple
12from openhcs.core.memory.decorators import cupy as cupy_func
13from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs
14from openhcs.core.utils import optional_import
16from .phase_correlation import phase_correlation_gpu_only, phase_correlation_nist_gpu
17from .quality_metrics import (
18 compute_correlation_quality_gpu_aligned,
19 compute_adaptive_quality_threshold,
20 validate_translation_consistency,
21 log_coordinate_transformation
22)
23from .position_reconstruction import build_mst_gpu, rebuild_positions_from_mst_gpu
25# For type checking only
26if TYPE_CHECKING: 26 ↛ 27line 26 didn't jump to line 27 because the condition on line 26 was never true
27 import cupy as cp
29# Import CuPy as an optional dependency
30cp = optional_import("cupy")
32logger = logging.getLogger(__name__)
35def _convert_overlap_to_tile_coordinates(
36 dy: float, dx: float,
37 overlap_h: int, overlap_w: int,
38 tile_h: int, tile_w: int,
39 direction: str
40) -> Tuple[float, float]:
41 """
42 Convert overlap-region-relative displacements to tile-center coordinates.
44 Args:
45 dy, dx: Phase correlation displacements in overlap region coordinates
46 overlap_h, overlap_w: Overlap region dimensions
47 tile_h, tile_w: Full tile dimensions
48 direction: 'horizontal' or 'vertical'
50 Returns:
51 (tile_dy, tile_dx): Displacements in tile-center coordinates
52 """
53 if direction == 'horizontal':
54 # For horizontal connections (left-right)
55 # Expected displacement is approximately tile_w - overlap_w
56 expected_dx = tile_w - overlap_w
57 tile_dx = expected_dx + dx # Add phase correlation correction
58 tile_dy = dy # Vertical should be minimal
60 elif direction == 'vertical':
61 # For vertical connections (top-bottom)
62 # Expected displacement is approximately tile_h - overlap_h
63 expected_dy = tile_h - overlap_h
64 tile_dy = expected_dy + dy # Add phase correlation correction
65 tile_dx = dx # Horizontal should be minimal
67 else:
68 raise ValueError(f"Invalid direction: {direction}. Must be 'horizontal' or 'vertical'")
70 return tile_dy, tile_dx
76def _validate_displacement_magnitude(
77 tile_dx: float, tile_dy: float,
78 expected_dx: float, expected_dy: float,
79 direction: str,
80 tolerance_factor: float = 2.0,
81 tolerance_percent: float = 0.1
82) -> bool:
83 """
84 Validate that displacement magnitudes are reasonable.
86 Args:
87 tile_dx, tile_dy: Computed tile-center displacements
88 expected_dx, expected_dy: Expected displacements
89 direction: 'horizontal' or 'vertical'
90 tolerance_factor: How much deviation to allow
92 Returns:
93 True if displacement is reasonable, False otherwise
94 """
95 if direction == 'horizontal':
96 # For horizontal connections, dx should be close to expected_dx
97 dx_error = abs(tile_dx - expected_dx)
98 max_allowed_error = tolerance_factor * expected_dx * tolerance_percent
99 dx_valid = dx_error <= max_allowed_error
101 # dy should be small (minimal vertical drift relative to expected_dx, not expected_dy)
102 max_allowed_dy = tolerance_factor * expected_dx * tolerance_percent
103 dy_valid = abs(tile_dy) <= max_allowed_dy
105 return dx_valid and dy_valid
107 elif direction == 'vertical':
108 # For vertical connections, dy should be close to expected_dy
109 dy_error = abs(tile_dy - expected_dy)
110 max_allowed_error = tolerance_factor * expected_dy * tolerance_percent
111 dy_valid = dy_error <= max_allowed_error
113 # dx should be small (minimal horizontal drift relative to expected_dy, not expected_dx)
114 max_allowed_dx = tolerance_factor * expected_dy * tolerance_percent
115 dx_valid = abs(tile_dx) <= max_allowed_dx
117 return dy_valid and dx_valid
119 return False
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """
    Ensure ``array`` is a CuPy ``ndarray``.

    Args:
        array: Object to check.
        name: Label used in the error message.

    Raises:
        TypeError: If ``array`` is not an instance of ``cp.ndarray``.
    """
    is_cupy = isinstance(array, cp.ndarray)
    if is_cupy:
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def _correlate_overlap_pair(
    region_a,
    region_b,
    direction: str,
    *,
    use_nist_robustness: bool,
    n_peaks: int,
    use_nist_normalization: bool,
    subpixel: bool,
    subpixel_radius: int,
    regularization_eps_multiplier: float,
):
    """
    Phase-correlate two overlap strips and return ``(dy, dx, quality)``.

    With ``use_nist_robustness`` the multi-peak NIST routine supplies all three
    values. Otherwise the single-peak routine supplies the shift, and quality is
    computed separately with ``compute_correlation_quality_gpu_aligned``.
    """
    if use_nist_robustness:
        dy, dx, quality = phase_correlation_nist_gpu(
            region_a, region_b,
            direction=direction,
            n_peaks=n_peaks,
            use_nist_normalization=use_nist_normalization,
        )
        return dy, dx, quality

    dy, dx = phase_correlation_gpu_only(
        region_a, region_b,  # Standardized: left/top region first
        subpixel=subpixel,
        subpixel_radius=subpixel_radius,
        regularization_eps_multiplier=regularization_eps_multiplier,
    )
    # Quality is measured after applying the recovered shift.
    quality = compute_correlation_quality_gpu_aligned(region_a, region_b, dx, dy)
    return dy, dx, quality


def _global_optimization_gpu_only(
    positions: "cp.ndarray",  # type: ignore
    tile_grid: "cp.ndarray",  # type: ignore
    num_rows: int,
    num_cols: int,
    expected_dx: float,
    expected_dy: float,
    overlap_ratio: float,
    subpixel: bool,
    *,
    quality_threshold: float = 0.5,  # NIST Algorithm 15: ncc >= 0.5 for valid translations
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0,
    anchor_tile_index: int = 0,
    debug_connection_limit: int = 3,
    debug_vertical_limit: int = 6,
    displacement_tolerance_factor: float = 2.0,
    displacement_tolerance_percent: float = 0.3,
    consistency_threshold_percent: float = 0.5,
    max_connections_multiplier: int = 2,
    adaptive_base_threshold: float = 0.3,
    adaptive_percentile_threshold: float = 0.25,
    translation_tolerance_factor: float = 0.2,
    translation_min_quality: float = 0.3,
    magnitude_threshold_multiplier: float = 1e-6,
    peak_candidates_multiplier: int = 4,
    min_peak_distance: int = 5,
    use_nist_robustness: bool = True,  # NIST Algorithm 2: Enable multi-peak PCIAM with interpretation testing
    n_peaks: int = 2,  # NIST Algorithm 2: n=2 peaks tested (manually selected based on experimental testing)
    use_nist_normalization: bool = True,  # NIST Algorithm 3: Use fc/abs(fc) normalization instead of regularized approach
    # NIST Algorithm 9: Stage model parameters
    overlap_uncertainty_percent: float = 3.0,  # NIST default: 3% overlap uncertainty (pou)
    outlier_threshold_multiplier: float = 1.5,  # NIST Algorithm 16: 1.5 x IQR for outlier detection
) -> "cp.ndarray":  # type: ignore
    """
    Refine tile positions using a spanning tree over neighbour correlations.

    For every tile, the overlap strip it shares with its right neighbour and
    with its bottom neighbour is phase-correlated. A connection is kept when
    both conditions hold:
      * its quality is ``>= quality_threshold``;
      * the resulting tile-to-tile displacement passes
        ``_validate_displacement_magnitude``.

    Kept connections go to ``build_mst_gpu``, and the final positions come
    from ``rebuild_positions_from_mst_gpu`` anchored at ``anchor_tile_index``.
    If no connection survives, ``positions`` is returned unchanged.

    Diagnostic output is always written to stdout via ``print``.

    Args:
        positions: Initial tile positions. They are passed to
            ``rebuild_positions_from_mst_gpu``, or returned as-is when no
            connection is kept.
        tile_grid: Tiles indexed as ``tile_grid[row, col]``; dims 2 and 3 are
            read as tile height and width. Presumably shaped
            ``(num_rows, num_cols, H, W)`` — TODO confirm against caller.
        num_rows, num_cols: Grid layout.
        expected_dx, expected_dy: Nominal horizontal / vertical tile spacing,
            used for the displacement and translation-consistency checks.
        overlap_ratio: Fraction of the tile width/height used as overlap strip.
        subpixel, subpixel_radius, regularization_eps_multiplier: Forwarded to
            the single-peak phase correlation (used when
            ``use_nist_robustness`` is False).
        quality_threshold: Minimum correlation quality for a connection.
        anchor_tile_index: Tile kept fixed during position reconstruction.
        debug_connection_limit: Number of tiles whose intensity stats are
            printed. Horizontal debug output is emitted while the accepted
            connection counter is below this value.
        debug_vertical_limit: Same counter limit, for vertical debug output.
        displacement_tolerance_factor, displacement_tolerance_percent:
            Forwarded to ``_validate_displacement_magnitude``.
        consistency_threshold_percent: A warning is printed when fewer than
            this fraction of kept connections are judged consistent by
            ``validate_translation_consistency``.
        max_connections_multiplier: Connection buffers hold
            ``max_connections_multiplier * num_tiles`` entries.
        adaptive_base_threshold, adaptive_percentile_threshold: Forwarded to
            ``compute_adaptive_quality_threshold``. The result is only
            reported, never applied.
        translation_tolerance_factor, translation_min_quality: Forwarded to
            ``validate_translation_consistency``.
        use_nist_robustness, n_peaks, use_nist_normalization: Select and
            configure the multi-peak NIST phase correlation.
        magnitude_threshold_multiplier, peak_candidates_multiplier,
        min_peak_distance, overlap_uncertainty_percent,
        outlier_threshold_multiplier: Accepted for interface compatibility;
            not read by this function.

    Returns:
        Refined positions from ``rebuild_positions_from_mst_gpu``, or the input
        ``positions`` when no connection was kept.

    Raises:
        ValueError: If ``overlap_ratio`` yields an overlap strip narrower than
            one pixel in a direction that has neighbours to correlate.
    """
    H, W = tile_grid.shape[2], tile_grid.shape[3]
    num_tiles = num_rows * num_cols

    # Overlap strip sizes are loop-invariant. A 0-pixel strip would make
    # ``tile[:, -0:]`` select the WHOLE tile while ``tile[:, :0]`` is empty,
    # so refuse that configuration instead of correlating mismatched regions.
    overlap_w = int(W * overlap_ratio)
    overlap_h = int(H * overlap_ratio)
    if num_cols > 1 and overlap_w < 1:
        raise ValueError(
            f"overlap_ratio={overlap_ratio} gives a horizontal overlap of {overlap_w} px "
            f"for tile width {W}; at least 1 px is required"
        )
    if num_rows > 1 and overlap_h < 1:
        raise ValueError(
            f"overlap_ratio={overlap_ratio} gives a vertical overlap of {overlap_h} px "
            f"for tile height {H}; at least 1 px is required"
        )

    # Pre-allocate GPU arrays for connections. Each tile has at most a right
    # and a bottom neighbour, hence the default multiplier of 2.
    max_connections = max_connections_multiplier * num_tiles
    connection_from = cp.full(max_connections, -1, dtype=cp.int32)
    connection_to = cp.full(max_connections, -1, dtype=cp.int32)
    connection_dx = cp.zeros(max_connections, dtype=cp.float32)
    connection_dy = cp.zeros(max_connections, dtype=cp.float32)
    connection_quality = cp.zeros(max_connections, dtype=cp.float32)

    conn_idx = 0  # number of accepted connections written so far

    # Debug: track quality filtering
    total_correlations = 0
    passed_threshold = 0
    all_qualities = []

    correlation_options = {
        "use_nist_robustness": use_nist_robustness,
        "n_peaks": n_peaks,
        "use_nist_normalization": use_nist_normalization,
        "subpixel": subpixel,
        "subpixel_radius": subpixel_radius,
        "regularization_eps_multiplier": regularization_eps_multiplier,
    }

    # Debug: print expected displacements and coordinate validation
    print(f"🔥 EXPECTED DISPLACEMENTS: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f"🔥 OVERLAP RATIO: {overlap_ratio}, H={H}, W={W}")
    print("🔥 COORDINATE VALIDATION:")
    print(f" Expected tile spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f" Overlap regions: H*ratio={H*overlap_ratio:.1f}, W*ratio={W*overlap_ratio:.1f}")
    print(f" Actual overlap: H={H*overlap_ratio:.1f}, W={W*overlap_ratio:.1f} pixels")

    # Debug: intensity stats (e.g. to spot all-black input). Only the tiles that
    # are actually printed are measured, since every stat forces a device sync.
    # Slicing the range keeps the original list-slice semantics of the limit.
    print(f"🔥 TILE STATS: First {debug_connection_limit} tiles - min/max/mean:")
    for i in range(num_tiles)[:debug_connection_limit]:
        r, c = divmod(i, num_cols)
        tile = tile_grid[r, c]
        tmin = float(cp.min(tile))
        tmax = float(cp.max(tile))
        tmean = float(cp.mean(tile))
        print(f" Tile {i}: [{tmin:.1f}, {tmax:.1f}], mean={tmean:.1f}")

    # Build connections (GPU operations)
    for r in range(num_rows):
        for c in range(num_cols):
            tile_idx = r * num_cols + c
            current_tile = tile_grid[r, c]

            # Horizontal connection: current tile -> right neighbour.
            if c < num_cols - 1:
                right_idx = r * num_cols + (c + 1)
                right_tile = tile_grid[r, c + 1]
                left_region = current_tile[:, -overlap_w:]  # Right edge of left tile
                right_region = right_tile[:, :overlap_w]  # Left edge of right tile

                if conn_idx < debug_connection_limit:
                    print(f"🔥 HORIZONTAL OVERLAP {conn_idx}: tiles {tile_idx}->{right_idx}")
                    print(f" overlap_w={int(overlap_w)}, W={W}")
                    print(" Processing overlap regions (shapes not shown to avoid GPU sync)")

                dy, dx, quality = _correlate_overlap_pair(
                    left_region, right_region, 'horizontal', **correlation_options
                )
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile-center coordinates
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, overlap_h, overlap_w, H, W, 'horizontal'
                    )

                    if conn_idx < debug_connection_limit:
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'horizontal', (tile_idx, right_idx)
                        )

                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, float(expected_dx), 0.0, 'horizontal',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = right_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        if conn_idx < debug_connection_limit:
                            print(f"🔥 HORIZONTAL CONNECTION {conn_idx}: {tile_idx}->{right_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    elif conn_idx < debug_connection_limit:
                        # Debug: explain why the connection was rejected. The same
                        # limit bounds the dx error and |dy| (see the validator).
                        print(f"🔥 REJECTED HORIZONTAL {tile_idx}->{right_idx}: displacement invalid")
                        print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                        print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                        dx_error = abs(tile_dx - expected_dx)
                        max_allowed = displacement_tolerance_factor * expected_dx * displacement_tolerance_percent
                        print(f" dx_error={dx_error:.3f} vs max_allowed={max_allowed:.3f}")
                        print(f" abs(dy)={abs(tile_dy):.3f} vs max_allowed_dy={max_allowed:.3f}")

            # Vertical connection: current tile -> bottom neighbour.
            if r < num_rows - 1:
                bottom_idx = (r + 1) * num_cols + c
                bottom_tile = tile_grid[r + 1, c]
                top_region = current_tile[-overlap_h:, :]  # Bottom edge of top tile
                bottom_region = bottom_tile[:overlap_h, :]  # Top edge of bottom tile

                dy, dx, quality = _correlate_overlap_pair(
                    top_region, bottom_region, 'vertical', **correlation_options
                )
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile-center coordinates
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, overlap_h, overlap_w, H, W, 'vertical'
                    )

                    if conn_idx < debug_vertical_limit:
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'vertical', (tile_idx, bottom_idx)
                        )

                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, 0.0, float(expected_dy), 'vertical',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = bottom_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        if conn_idx < debug_vertical_limit:
                            print(f"🔥 VERTICAL CONNECTION {conn_idx}: {tile_idx}->{bottom_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    elif conn_idx < debug_vertical_limit:
                        # Debug: explain why the connection was rejected. The same
                        # limit bounds the dy error and |dx| (see the validator).
                        print(f"🔥 REJECTED VERTICAL {tile_idx}->{bottom_idx}: displacement invalid")
                        print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                        print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                        dy_error = abs(tile_dy - expected_dy)
                        max_allowed = displacement_tolerance_factor * expected_dy * displacement_tolerance_percent
                        print(f" dy_error={dy_error:.3f} vs max_allowed={max_allowed:.3f}")
                        print(f" abs(dx)={abs(tile_dx):.3f} vs max_allowed_dx={max_allowed:.3f}")

    # Report an adaptive quality threshold. It is only logged: connections are
    # NOT re-processed with it.
    if all_qualities:
        adaptive_threshold = compute_adaptive_quality_threshold(
            all_qualities, adaptive_base_threshold, adaptive_percentile_threshold
        )
        print(f"🔥 ADAPTIVE THRESHOLD: original={quality_threshold:.6f}, adaptive={adaptive_threshold:.6f}")
        if adaptive_threshold < quality_threshold:
            print("🔥 RE-FILTERING with adaptive threshold...")

    # Debug: quality filtering summary
    print(f"🔥 QUALITY FILTERING: {passed_threshold}/{total_correlations} connections passed threshold {quality_threshold}")
    if all_qualities:
        quality_values = cp.array(all_qualities)  # built once, reused for all three stats
        min_q = float(cp.min(quality_values))
        max_q = float(cp.max(quality_values))
        mean_q = float(cp.mean(quality_values))
        print(f"🔥 QUALITY RANGE: min={min_q:.6f}, max={max_q:.6f}, mean={mean_q:.6f}")

    if conn_idx == 0:
        # Nothing survived filtering: keep the initial positions.
        return positions

    # Validate translation consistency (Plan 03). Use one bulk device->host copy
    # per array instead of three scalar syncs per connection.
    translations = list(zip(
        connection_dy[:conn_idx].tolist(),
        connection_dx[:conn_idx].tolist(),
        connection_quality[:conn_idx].tolist(),
    ))
    expected_spacing = (float(expected_dx), float(expected_dy))
    valid_flags = validate_translation_consistency(
        translations, expected_spacing, translation_tolerance_factor, translation_min_quality
    )

    num_valid = sum(valid_flags)
    print(f"🔥 TRANSLATION VALIDATION: {num_valid}/{len(translations)} connections are consistent")
    if num_valid < len(translations) * consistency_threshold_percent:  # Less than threshold% valid
        print(f"🔥 WARNING: Low translation consistency ({num_valid}/{len(translations)})")
        print(f"🔥 Expected spacing: dx={expected_spacing[0]:.1f}, dy={expected_spacing[1]:.1f}")
        print("🔥 Consider adjusting overlap_ratio or quality thresholds")

    # Trim buffers to the accepted connections (GPU)
    connection_from = connection_from[:conn_idx]
    connection_to = connection_to[:conn_idx]
    connection_dx = connection_dx[:conn_idx]
    connection_dy = connection_dy[:conn_idx]
    connection_quality = connection_quality[:conn_idx]

    # Build MST using the GPU Boruvka implementation
    mst_edges = build_mst_gpu(
        connection_from, connection_to, connection_dx,
        connection_dy, connection_quality, num_tiles
    )

    # Rebuild positions from the MST (GPU)
    return rebuild_positions_from_mst_gpu(
        positions, mst_edges, num_tiles, anchor_tile_index
    )
444@special_inputs("grid_dimensions")
445@special_outputs("positions")
446@cupy_func
447def mist_compute_tile_positions(
448 image_stack: "cp.ndarray", # type: ignore
449 grid_dimensions: Tuple[int, int],
450 *,
451 # === Input Validation Parameters ===
452 method: str = "phase_correlation",
453 fft_backend: str = "cupy",
455 # === Core Algorithm Parameters ===
456 normalize: bool = True,
457 verbose: bool = False,
458 overlap_ratio: float = 0.1,
459 subpixel: bool = True,
460 refinement_iterations: int = 10,
461 global_optimization: bool = True,
462 anchor_tile_index: int = 0,
464 # === Refinement Tuning Parameters ===
465 refinement_damping: float = 0.5,
466 correlation_weight_horizontal: float = 1.0,
467 correlation_weight_vertical: float = 1.0,
469 # === Phase Correlation Parameters ===
470 subpixel_radius: int = 3,
471 regularization_eps_multiplier: float = 1000.0,
473 # === MST Global Optimization Parameters ===
474 mst_quality_threshold: float = 0.5, # NIST Algorithm 15: ncc >= 0.5 for MST edge inclusion
475 # NIST robustness parameters (Algorithms 2-5)
476 use_nist_robustness: bool = True, # Enable full NIST PCIAM implementation
477 n_peaks: int = 2, # NIST Algorithm 2: Test 2 peaks (experimentally determined)
478 use_nist_normalization: bool = True, # NIST Algorithm 3: fc/abs(fc) normalization
479 # Debugging and validation parameters
480 debug_connection_limit: int = 3,
481 debug_vertical_limit: int = 6,
482 displacement_tolerance_factor: float = 2.0,
483 displacement_tolerance_percent: float = 0.3,
484 consistency_threshold_percent: float = 0.5,
485 max_connections_multiplier: int = 2,
486 # Quality metric tuning parameters
487 adaptive_base_threshold: float = 0.3,
488 adaptive_percentile_threshold: float = 0.25,
489 translation_tolerance_factor: float = 0.2,
490 translation_min_quality: float = 0.3,
491 # Phase correlation tuning parameters
492 magnitude_threshold_multiplier: float = 1e-6,
493 peak_candidates_multiplier: int = 4,
494 min_peak_distance: int = 5,
495 **kwargs
496) -> Tuple["cp.ndarray", "cp.ndarray"]: # type: ignore
497 """
498 Full GPU MIST implementation with zero CPU operations.
500 Performs microscopy image stitching using phase correlation and iterative refinement.
501 The algorithm has three phases:
502 1. Initial positioning using sequential phase correlation
503 2. Iterative refinement with constraint optimization
504 3. Global optimization using minimum spanning tree (MST)
506 Args:
507 image_stack: 3D tensor (Z, Y, X) of tiles to stitch
508 grid_dimensions: (num_cols, num_rows) grid layout of tiles
510 === Input Validation Parameters ===
511 method: Correlation method - must be "phase_correlation"
512 fft_backend: FFT backend - must be "cupy" for GPU acceleration
514 === Core Algorithm Parameters (NIST Algorithms 1-3) ===
515 normalize: Normalize each tile to [0,1] range using (tile-min)/(max-min).
516 True = better correlation accuracy, handles varying illumination.
517 False = faster but poor results with uneven lighting.
518 Used in NIST Algorithm 3 (PCM) preprocessing.
519 verbose: Enable detailed logging of algorithm progress and timing
520 overlap_ratio: Expected overlap between adjacent tiles as fraction (0.0-1.0).
521 Defines correlation region size: overlap_w = int(W * overlap_ratio).
522 CRITICAL: Must match actual overlap in data or correlation fails.
523 Higher (0.2-0.4) = more robust but slower.
524 Lower (0.05-0.08) = faster but less accurate.
525 Used in NIST Algorithm 10 (Compute Image Overlap).
526 subpixel: Enable subpixel-accurate phase correlation for higher precision.
527 True = center-of-mass interpolation around correlation peak.
528 False = pixel-only accuracy (faster, less precise).
529 Enhances NIST Algorithm 3 (PCM) with subpixel refinement.
530 refinement_iterations: Number of iterative position refinement passes (0-50).
531 Each iteration applies weighted position corrections.
532 Higher = better convergence but much slower.
533 0 = skip refinement (fastest, least accurate).
534 Implements NIST Algorithm 21 (Bounded NCC Hill Climb).
535 global_optimization: Enable MST-based global optimization phase.
536 Uses minimum spanning tree to optimize tile positions globally.
537 Significantly improves accuracy for large grids.
538 Implements NIST Phase 3 (Image Composition).
539 anchor_tile_index: Index of reference tile that remains fixed at origin (usually 0).
540 All other positions calculated relative to this tile.
541 Used in NIST MST position reconstruction.
543 === Refinement Tuning Parameters ===
544 refinement_damping: Controls how aggressively positions are updated (0.0-1.0).
545 Formula: new_pos = (1-damping)*old_pos + damping*correction.
546 Higher (0.7-0.9) = faster convergence but may overshoot.
547 Lower (0.1-0.3) = more stable but slower convergence.
548 1.0 = full correction (may be unstable), 0.0 = no updates.
549 correlation_weight_horizontal: Weight for horizontal tile constraints (>0).
550 Higher values prioritize horizontal alignment accuracy.
551 Typical range: 0.5-2.0.
552 correlation_weight_vertical: Weight for vertical tile constraints (>0).
553 Higher values prioritize vertical alignment accuracy.
554 Typical range: 0.5-2.0.
556 === Phase Correlation Parameters (NIST Algorithm 3) ===
557 subpixel_radius: Radius around correlation peak for center-of-mass calculation.
558 Extracts (2*radius+1)² region around peak for interpolation.
559 Higher (5-10) = more accurate subpixel positioning but slower.
560 Lower (1-2) = faster but less precise, may cause drift.
561 0 = pixel-only accuracy (fastest, least precise).
562 Enhances NIST Algorithm 3 (PCM) with subpixel precision.
563 regularization_eps_multiplier: Prevents division by zero in phase correlation.
564 Formula: eps = machine_epsilon * multiplier.
565 Higher (10000+) = more stable with noisy images.
566 Lower (100-500) = higher precision but may fail.
567 Too low (<10) = risk of numerical instability.
568 Used in NIST Algorithm 3 cross-power normalization.
570 === MST Global Optimization Parameters (NIST Algorithms 8-21) ===
571 mst_quality_threshold: Minimum correlation quality for MST edge inclusion (0.0-1.0).
572 NIST Algorithm 15: ncc >= 0.5 for valid translations.
573 Formula: if correlation_peak < threshold: reject_connection.
574 NIST default: 0.5 (stricter quality control).
575 Higher = fewer connections, lower = includes weak correlations.
576 Too high = MST may fail, too low = includes noise.
577 use_nist_robustness: Enable NIST robust phase correlation (Algorithm 2).
578 True = multi-peak PCIAM with interpretation testing.
579 False = simplified single-peak method (faster).
580 n_peaks: Number of correlation peaks to analyze (NIST Algorithm 4).
581 NIST default: n=2 (manually selected based on experimental testing).
582 Higher = more robust peak selection but slower processing.
583 use_nist_normalization: Apply NIST normalization method (Algorithm 3).
584 True = fc/abs(fc) normalization (NIST standard).
585 False = OpenHCS regularization method.
587 displacement_tolerance_factor: Multiplier for expected displacement tolerance.
588 NIST Algorithm 14: Stage model displacement validation.
589 Formula: max_error = factor * expected_displacement * percent.
590 Higher (3.0-5.0) = more permissive validation.
591 Lower (1.0-1.5) = stricter validation.
592 displacement_tolerance_percent: Percentage tolerance for displacement (0.0-1.0).
593 NIST Algorithm 14: Displacement validation threshold.
594 Formula: valid if |actual - expected| <= expected * percent.
595 0.3 = ±30% deviation allowed from expected displacement.
596 Higher = accepts larger deviations, lower = stricter.
598 debug_connection_limit: Max horizontal connections to log for debugging (0-10)
599 debug_vertical_limit: Max vertical connections to log for debugging (0-10)
600 consistency_threshold_percent: Translation consistency validation threshold (0.0-1.0).
601 NIST Algorithm 17: Filter by repeatability.
602 Formula: valid if |translation - median| <= median * threshold.
603 0.5 = ±50% deviation from median allowed.
604 Higher = more permissive, lower = stricter consistency.
605 max_connections_multiplier: Maximum connections per tile in MST construction.
606 Formula: max_connections = base_connections * multiplier.
607 Prevents over-connected graphs that slow MST algorithms.
608 2 = allow 2x normal connections, 1 = strict minimum.
609 adaptive_base_threshold: Minimum quality threshold for adaptive quality metrics.
610 NIST-inspired adaptive thresholding for challenging datasets.
611 Formula: final_threshold = max(base_threshold, percentile_threshold).
612 0.3 = minimum 30% correlation required regardless of distribution.
613 Prevents threshold from becoming too permissive.
614 adaptive_percentile_threshold: Percentile-based quality threshold (0.0-1.0).
615 NIST Algorithm 9: Stage model validation approach.
616 Formula: threshold = percentile(all_qualities, percentile * 100).
617 0.25 = use 25th percentile of quality distribution.
618 Lower = more permissive, higher = stricter selection.
619 translation_tolerance_factor: Tolerance multiplier for translation validation.
620 NIST Algorithm 14: Stage model displacement validation.
621 Formula: max_error = expected_displacement * factor * percent.
622 0.2 = allow 20% deviation from expected displacement.
623 Higher = more permissive validation.
624 translation_min_quality: Minimum correlation quality for translation acceptance.
625 NIST Algorithm 15: Quality-based filtering threshold.
626 Formula: accept if ncc >= min_quality.
627 0.3 = require 30% normalized cross-correlation minimum.
628 Higher = stricter quality, lower = more permissive.
629 magnitude_threshold_multiplier: FFT magnitude threshold for numerical stability.
630 NIST Algorithm 3: Cross-power spectrum normalization.
631 Formula: threshold = mean(magnitude) * multiplier.
632 1e-6 = very small threshold for numerical stability.
633 Higher = more aggressive filtering of low-magnitude frequencies.
634 peak_candidates_multiplier: Candidate peak search multiplier for robustness.
635 NIST Algorithm 4: Multi-peak max search optimization.
636 Formula: n_candidates = n_peaks * multiplier.
637 4 = search 4x more candidates than needed for robust selection.
638 Higher = more thorough search but slower processing.
639 min_peak_distance: Minimum pixel distance between correlation peaks.
640 NIST Algorithm 4: Prevents duplicate peak detection.
641 Formula: reject if distance(peak1, peak2) < min_distance.
642 5 = peaks must be ≥5 pixels apart to be considered distinct.
643 Higher = fewer but more distinct peaks, lower = more peaks.
645 === NIST Mathematical Formulas ===
647 Algorithm 3 (PCM): Peak Correlation Matrix
648 F1 ← fft2D(I1), F2 ← fft2D(I2)
649 FC ← F1 .* conj(F2)
650 PCM ← ifft2D(FC ./ abs(FC))
652 Algorithm 6 (NCC): Normalized Cross-Correlation
653 I1 ← I1 - mean(I1), I2 ← I2 - mean(I2)
654 ncc = (I1 · I2) / (|I1| * |I2|)
656 Algorithm 10 (Overlap): Image Overlap Computation
657 overlap_percent = 100 - mu (where mu is mean translation)
658 valid_range = [overlap ± overlap_uncertainty_percent]
660 Algorithm 16 (Outliers): Statistical Outlier Detection
661 q1 = 25th percentile, q3 = 75th percentile
662 IQR = q3 - q1
663 outlier if: value < (q1 - 1.5*IQR) OR value > (q3 + 1.5*IQR)
665 Algorithm 21 (Hill Climb): Bounded Translation Refinement
666 search_bounds = [current ± repeatability]
667 ncc_surface[i,j] = ncc(extract_overlap(I1, j, i), extract_overlap(I2, -j, -i))
668 climb to local maximum within bounds
670 === NIST Performance Guidance ===
672 Quality Threshold Tuning:
673 - Start with NIST default: 0.5 (strict quality control)
674 - Lower to 0.3-0.4 for noisy biological samples
675 - Lower to 0.1-0.2 for very challenging datasets
676 - Monitor MST edge count: need ≥(num_tiles-1) edges minimum
678 Peak Count Optimization:
679 - NIST default: n=2 peaks (experimentally optimal)
680 - Increase to 3-5 for highly repetitive patterns
681 - Keep at 2 for most microscopy applications
683 Overlap Ratio Guidelines:
684 - Must match actual image overlap precisely
685 - Typical microscopy: 0.1-0.2 (10-20% overlap)
686 - Higher overlap = more robust but slower processing
687 - Lower overlap = faster but less reliable alignment
689 Subpixel Refinement:
690 - Enable for publication-quality results
691 - Radius 3-5 optimal for most applications
692 - Disable for speed-critical applications
694 Expected Performance:
695 - With NIST defaults: High accuracy, moderate speed
696 - Quality threshold 0.5: Strict filtering, fewer edges
697 - Multi-peak robustness: 2-3x slower but more reliable
698 - Global optimization: Essential for large grids (>3x3)
700 Returns:
701 Tuple of (image_stack, positions) where:
702 - image_stack: Original input tiles (potentially normalized)
703 - positions: (Z, 2) array of tile positions in (x, y) format
704 Positions are centered around origin
706 Raises:
707 ValueError: If input validation fails (wrong method, backend, or dimensions)
708 TypeError: If image_stack is not a CuPy array
709 """
710 _validate_cupy_array(image_stack, "image_stack")
712 if image_stack.ndim != 3:
713 raise ValueError(f"Input must be a 3D tensor, got {image_stack.ndim}D")
715 if fft_backend != "cupy":
716 raise ValueError(f"FFT backend must be 'cupy', got '{fft_backend}'")
718 if method != "phase_correlation":
719 raise ValueError(f"Only 'phase_correlation' method is supported, got '{method}'")
721 num_cols, num_rows = grid_dimensions
722 Z, H, W = image_stack.shape
724 # VERY FIRST THING - Debug output to confirm function is called
725 print("🔥🔥🔥 MIST FUNCTION ENTRY POINT - FUNCTION IS DEFINITELY BEING CALLED! 🔥🔥🔥")
726 print(f"🔥 Image stack shape: {image_stack.shape}")
727 print(f"🔥 Grid dimensions: {grid_dimensions}")
729 # Debug: Log the actual overlap_ratio parameter being used
730 print(f"🔥 MIST FUNCTION CALLED WITH overlap_ratio={overlap_ratio}")
731 print(f"🔥 Expected: 0.1 (10% overlap), Actual: {overlap_ratio}")
733 if Z != num_rows * num_cols:
734 raise ValueError(
735 f"Number of tiles ({Z}) does not match grid size ({num_rows}x{num_cols}={num_rows*num_cols})"
736 )
738 # Normalize on GPU
739 tiles = image_stack.astype(cp.float32)
740 if normalize:
741 for z in range(Z):
742 tile = tiles[z]
743 tile_min = cp.min(tile)
744 tile_max = cp.max(tile)
745 tile_range = tile_max - tile_min
746 # Use GPU conditional to avoid division by zero
747 tiles[z] = cp.where(tile_range > 0, (tile - tile_min) / tile_range, tile)
749 # Reshape to grid (GPU operation)
750 tile_grid = tiles.reshape(num_rows, num_cols, H, W)
752 # Calculate expected spacing (GPU)
753 expected_dy = cp.float32(H * (1.0 - overlap_ratio))
754 expected_dx = cp.float32(W * (1.0 - overlap_ratio))
756 # Initialize positions on GPU
757 positions = cp.zeros((Z, 2), dtype=cp.float32)
759 if verbose:
760 logger.info(f"GPU MIST: {num_rows}x{num_cols} grid, spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
762 # Phase 1: Initial positioning (all GPU)
763 for r in range(num_rows):
764 for c in range(num_cols):
765 tile_idx = r * num_cols + c
767 if tile_idx == anchor_tile_index:
768 positions[tile_idx] = cp.array([0.0, 0.0])
769 continue
771 current_tile = tile_grid[r, c]
773 # Position from left neighbor (GPU operations)
774 if c > 0:
775 left_idx = r * num_cols + (c - 1)
776 left_tile = tile_grid[r, c - 1]
778 # Extract overlap regions (GPU)
779 overlap_w = cp.int32(W * overlap_ratio)
780 left_region = left_tile[:, -overlap_w:]
781 current_region = current_tile[:, :overlap_w]
783 # GPU phase correlation
784 dy, dx = phase_correlation_gpu_only(
785 left_region, current_region,
786 subpixel=subpixel,
787 subpixel_radius=subpixel_radius,
788 regularization_eps_multiplier=regularization_eps_multiplier
789 )
791 # Convert overlap-region coordinates to tile-center coordinates
792 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
793 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal'
794 )
796 # Update position (GPU)
797 new_x = positions[left_idx, 0] + tile_dx
798 new_y = positions[left_idx, 1] + tile_dy
799 positions[tile_idx] = cp.array([new_x, new_y])
801 elif r > 0: # Position from top neighbor
802 top_idx = (r - 1) * num_cols + c
803 top_tile = tile_grid[r - 1, c]
805 # Extract overlap regions (GPU)
806 overlap_h = cp.int32(H * overlap_ratio)
807 top_region = top_tile[-overlap_h:, :]
808 current_region = current_tile[:overlap_h, :]
810 # GPU phase correlation
811 dy, dx = phase_correlation_gpu_only(
812 top_region, current_region,
813 subpixel=subpixel,
814 subpixel_radius=subpixel_radius,
815 regularization_eps_multiplier=regularization_eps_multiplier
816 )
818 # Convert overlap-region coordinates to tile-center coordinates
819 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
820 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical'
821 )
823 # Update position (GPU)
824 new_x = positions[top_idx, 0] + tile_dx
825 new_y = positions[top_idx, 1] + tile_dy
826 positions[tile_idx] = cp.array([new_x, new_y])
828 # Phase 2: Refinement iterations (all GPU)
829 for iteration in range(refinement_iterations):
830 if verbose:
831 logger.info(f"GPU refinement iteration {iteration + 1}/{refinement_iterations}")
833 position_corrections = cp.zeros_like(positions)
834 correction_weights = cp.zeros(Z, dtype=cp.float32)
836 # Horizontal constraints (GPU)
837 for r in range(num_rows):
838 for c in range(num_cols - 1):
839 left_idx = r * num_cols + c
840 right_idx = r * num_cols + (c + 1)
842 left_tile = tile_grid[r, c]
843 right_tile = tile_grid[r, c + 1]
845 overlap_w = cp.int32(W * overlap_ratio)
846 left_region = left_tile[:, -overlap_w:] # Right edge of left tile
847 right_region = right_tile[:, :overlap_w] # Left edge of right tile
849 dy, dx = phase_correlation_gpu_only(
850 left_region, right_region, # Standardized: left_region first
851 subpixel=subpixel,
852 subpixel_radius=subpixel_radius,
853 regularization_eps_multiplier=regularization_eps_multiplier
854 )
856 # Convert overlap-region coordinates to tile-center coordinates
857 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
858 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal'
859 )
861 # Expected position (GPU)
862 expected_right = positions[left_idx] + cp.array([tile_dx, tile_dy])
864 # Accumulate updates (GPU)
865 position_corrections[right_idx] += expected_right * correlation_weight_horizontal
866 correction_weights[right_idx] += correlation_weight_horizontal
868 # Vertical constraints (GPU)
869 for r in range(num_rows - 1):
870 for c in range(num_cols):
871 top_idx = r * num_cols + c
872 bottom_idx = (r + 1) * num_cols + c
874 top_tile = tile_grid[r, c]
875 bottom_tile = tile_grid[r + 1, c]
877 overlap_h = cp.int32(H * overlap_ratio)
878 top_region = top_tile[-overlap_h:, :] # Bottom edge of top tile
879 bottom_region = bottom_tile[:overlap_h, :] # Top edge of bottom tile
881 dy, dx = phase_correlation_gpu_only(
882 top_region, bottom_region, # Standardized: top_region first
883 subpixel=subpixel,
884 subpixel_radius=subpixel_radius,
885 regularization_eps_multiplier=regularization_eps_multiplier
886 )
888 # Convert overlap-region coordinates to tile-center coordinates
889 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
890 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical'
891 )
893 # Expected position (GPU)
894 expected_bottom = positions[top_idx] + cp.array([tile_dx, tile_dy])
896 # Accumulate updates (GPU)
897 position_corrections[bottom_idx] += expected_bottom * correlation_weight_vertical
898 correction_weights[bottom_idx] += correlation_weight_vertical
900 # Apply corrections with damping (all GPU)
901 for tile_idx in range(Z):
902 if correction_weights[tile_idx] > 0 and tile_idx != anchor_tile_index:
903 averaged_correction = position_corrections[tile_idx] / correction_weights[tile_idx]
904 positions[tile_idx] = ((1 - refinement_damping) * positions[tile_idx] +
905 refinement_damping * averaged_correction)
907 # Phase 3: Global optimization MST (GPU operations)
908 print(f"🔥 PHASE 3: global_optimization={global_optimization}")
909 if global_optimization:
910 print("🔥 STARTING MST GLOBAL OPTIMIZATION")
911 positions = _global_optimization_gpu_only(
912 positions, tile_grid, num_rows, num_cols,
913 expected_dx, expected_dy, overlap_ratio, subpixel,
915 quality_threshold=mst_quality_threshold,
916 subpixel_radius=subpixel_radius,
917 regularization_eps_multiplier=regularization_eps_multiplier,
918 anchor_tile_index=anchor_tile_index,
919 debug_connection_limit=debug_connection_limit,
920 debug_vertical_limit=debug_vertical_limit,
921 displacement_tolerance_factor=displacement_tolerance_factor,
922 displacement_tolerance_percent=displacement_tolerance_percent,
923 consistency_threshold_percent=consistency_threshold_percent,
924 max_connections_multiplier=max_connections_multiplier,
925 adaptive_base_threshold=adaptive_base_threshold,
926 adaptive_percentile_threshold=adaptive_percentile_threshold,
927 translation_tolerance_factor=translation_tolerance_factor,
928 translation_min_quality=translation_min_quality,
929 magnitude_threshold_multiplier=magnitude_threshold_multiplier,
930 peak_candidates_multiplier=peak_candidates_multiplier,
931 min_peak_distance=min_peak_distance,
932 use_nist_robustness=use_nist_robustness,
933 n_peaks=n_peaks,
934 use_nist_normalization=use_nist_normalization
935 )
937 # Center positions (GPU)
938 mean_pos = cp.mean(positions, axis=0)
939 positions = positions - mean_pos
941 print(f"🔥 MIST COMPLETE: Returning {positions.shape} positions")
942 return tiles, positions