Coverage for openhcs/processing/backends/pos_gen/mist/mist_main.py: 5.3%
305 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Main MIST Implementation
4Full GPU-accelerated MIST implementation with zero CPU operations.
5Orchestrates all MIST components for tile position computation.
6"""
7from __future__ import annotations
9import logging
10from typing import TYPE_CHECKING, Tuple
12from openhcs.constants.constants import DEFAULT_PATCH_SIZE, DEFAULT_SEARCH_RADIUS
13from openhcs.core.memory.decorators import cupy as cupy_func
14from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs
15from openhcs.core.utils import optional_import
17from .phase_correlation import phase_correlation_gpu_only, phase_correlation_nist_gpu
18from .quality_metrics import (
19 compute_correlation_quality_gpu_aligned,
20 compute_adaptive_quality_threshold,
21 validate_translation_consistency,
22 log_coordinate_transformation,
23 debug_phase_correlation_matrix
24)
25from .position_reconstruction import build_mst_gpu, rebuild_positions_from_mst_gpu
27# For type checking only
28if TYPE_CHECKING: 28 ↛ 29line 28 didn't jump to line 29 because the condition on line 28 was never true
29 import cupy as cp
31# Import CuPy as an optional dependency
32cp = optional_import("cupy")
34logger = logging.getLogger(__name__)
37def _convert_overlap_to_tile_coordinates(
38 dy: float, dx: float,
39 overlap_h: int, overlap_w: int,
40 tile_h: int, tile_w: int,
41 direction: str
42) -> Tuple[float, float]:
43 """
44 Convert overlap-region-relative displacements to tile-center coordinates.
46 Args:
47 dy, dx: Phase correlation displacements in overlap region coordinates
48 overlap_h, overlap_w: Overlap region dimensions
49 tile_h, tile_w: Full tile dimensions
50 direction: 'horizontal' or 'vertical'
52 Returns:
53 (tile_dy, tile_dx): Displacements in tile-center coordinates
54 """
55 if direction == 'horizontal':
56 # For horizontal connections (left-right)
57 # Expected displacement is approximately tile_w - overlap_w
58 expected_dx = tile_w - overlap_w
59 tile_dx = expected_dx + dx # Add phase correlation correction
60 tile_dy = dy # Vertical should be minimal
62 elif direction == 'vertical':
63 # For vertical connections (top-bottom)
64 # Expected displacement is approximately tile_h - overlap_h
65 expected_dy = tile_h - overlap_h
66 tile_dy = expected_dy + dy # Add phase correlation correction
67 tile_dx = dx # Horizontal should be minimal
69 else:
70 raise ValueError(f"Invalid direction: {direction}. Must be 'horizontal' or 'vertical'")
72 return tile_dy, tile_dx
78def _validate_displacement_magnitude(
79 tile_dx: float, tile_dy: float,
80 expected_dx: float, expected_dy: float,
81 direction: str,
82 tolerance_factor: float = 2.0,
83 tolerance_percent: float = 0.1
84) -> bool:
85 """
86 Validate that displacement magnitudes are reasonable.
88 Args:
89 tile_dx, tile_dy: Computed tile-center displacements
90 expected_dx, expected_dy: Expected displacements
91 direction: 'horizontal' or 'vertical'
92 tolerance_factor: How much deviation to allow
94 Returns:
95 True if displacement is reasonable, False otherwise
96 """
97 if direction == 'horizontal':
98 # For horizontal connections, dx should be close to expected_dx
99 dx_error = abs(tile_dx - expected_dx)
100 max_allowed_error = tolerance_factor * expected_dx * tolerance_percent
101 dx_valid = dx_error <= max_allowed_error
103 # dy should be small (minimal vertical drift relative to expected_dx, not expected_dy)
104 max_allowed_dy = tolerance_factor * expected_dx * tolerance_percent
105 dy_valid = abs(tile_dy) <= max_allowed_dy
107 return dx_valid and dy_valid
109 elif direction == 'vertical':
110 # For vertical connections, dy should be close to expected_dy
111 dy_error = abs(tile_dy - expected_dy)
112 max_allowed_error = tolerance_factor * expected_dy * tolerance_percent
113 dy_valid = dy_error <= max_allowed_error
115 # dx should be small (minimal horizontal drift relative to expected_dy, not expected_dx)
116 max_allowed_dx = tolerance_factor * expected_dy * tolerance_percent
117 dx_valid = abs(tile_dx) <= max_allowed_dx
119 return dy_valid and dx_valid
121 return False
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """
    Ensure ``array`` is an instance of ``cp.ndarray``.

    Args:
        array: Object to check
        name: Label used for the object in the error message

    Raises:
        TypeError: If ``array`` is not a ``cp.ndarray``
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def _global_optimization_gpu_only(
    positions: "cp.ndarray",  # type: ignore
    tile_grid: "cp.ndarray",  # type: ignore
    num_rows: int,
    num_cols: int,
    expected_dx: float,
    expected_dy: float,
    overlap_ratio: float,
    subpixel: bool,
    *,
    quality_threshold: float = 0.5,  # NIST Algorithm 15: ncc >= 0.5 for valid translations
    subpixel_radius: int = 3,
    regularization_eps_multiplier: float = 1000.0,
    anchor_tile_index: int = 0,
    debug_connection_limit: int = 3,
    debug_vertical_limit: int = 6,
    displacement_tolerance_factor: float = 2.0,
    displacement_tolerance_percent: float = 0.3,
    consistency_threshold_percent: float = 0.5,
    max_connections_multiplier: int = 2,
    adaptive_base_threshold: float = 0.3,
    adaptive_percentile_threshold: float = 0.25,
    translation_tolerance_factor: float = 0.2,
    translation_min_quality: float = 0.3,
    magnitude_threshold_multiplier: float = 1e-6,
    peak_candidates_multiplier: int = 4,
    min_peak_distance: int = 5,
    use_nist_robustness: bool = True,  # NIST Algorithm 2: Enable multi-peak PCIAM with interpretation testing
    n_peaks: int = 2,  # NIST Algorithm 2: n=2 peaks tested (manually selected based on experimental testing)
    use_nist_normalization: bool = True,  # NIST Algorithm 3: Use fc/abs(fc) normalization instead of regularized approach
    # NIST Algorithm 9: Stage model parameters
    overlap_uncertainty_percent: float = 3.0,  # NIST default: 3% overlap uncertainty (pou)
    outlier_threshold_multiplier: float = 1.5,  # NIST Algorithm 16: 1.5 × IQR for outlier detection
) -> "cp.ndarray":  # type: ignore
    """
    GPU-only global optimization using simplified MST approach.

    Every tile is correlated with its right and bottom neighbour over the
    overlap strip (``int(W * overlap_ratio)`` columns / ``int(H * overlap_ratio)``
    rows). A neighbour pair becomes a graph edge when its correlation quality
    is at least ``quality_threshold`` and its tile-space displacement passes
    ``_validate_displacement_magnitude``. The accepted edges are handed to
    ``build_mst_gpu`` and the resulting tree to ``rebuild_positions_from_mst_gpu``.

    The adaptive threshold and the translation-consistency check are computed
    and printed for diagnostics only; neither changes which edges are kept.

    Args:
        positions: Current tile positions; returned unchanged if no edge is accepted.
        tile_grid: Tile array indexed as ``tile_grid[row, col]``; tile height and
            width are read from ``shape[2]`` and ``shape[3]``.
        num_rows, num_cols: Grid layout; tile index is ``row * num_cols + col``.
        expected_dx, expected_dy: Expected tile spacing used for validation.
        overlap_ratio: Fraction of the tile size used as the overlap strip.
        subpixel, subpixel_radius, regularization_eps_multiplier: Forwarded to
            ``phase_correlation_gpu_only`` when ``use_nist_robustness`` is False.
        quality_threshold: Minimum quality for an edge to be considered.
        anchor_tile_index: Forwarded to ``rebuild_positions_from_mst_gpu``.
        debug_connection_limit, debug_vertical_limit: Debug printing stops once
            the accepted-edge counter reaches these limits (horizontal / vertical).
        displacement_tolerance_factor, displacement_tolerance_percent: Forwarded
            to ``_validate_displacement_magnitude``.
        consistency_threshold_percent: Fraction of consistent edges below which
            a low-consistency warning is printed.
        max_connections_multiplier: Edge buffers hold at least
            ``multiplier * num_tiles`` entries.
        adaptive_base_threshold, adaptive_percentile_threshold: Forwarded to
            ``compute_adaptive_quality_threshold``.
        translation_tolerance_factor, translation_min_quality: Forwarded to
            ``validate_translation_consistency``.
        use_nist_robustness, n_peaks, use_nist_normalization: Select and
            configure ``phase_correlation_nist_gpu``.
        magnitude_threshold_multiplier, peak_candidates_multiplier,
        min_peak_distance, overlap_uncertainty_percent,
        outlier_threshold_multiplier: Accepted for interface compatibility;
            not read by this function.

    Returns:
        Positions rebuilt from the MST, or ``positions`` itself when no edge
        was accepted.
    """
    H, W = tile_grid.shape[2], tile_grid.shape[3]
    num_tiles = num_rows * num_cols

    # Pre-allocate GPU arrays for connections.
    # Each tile has at most 2 neighbors (right, bottom), so the default
    # multiplier of 2 is always sufficient. A smaller multiplier could leave
    # the buffers shorter than the number of neighbour pairs in the grid, so
    # the capacity is never allowed to drop below that count.
    grid_edge_count = num_rows * (num_cols - 1) + num_cols * (num_rows - 1)
    max_connections = max(max_connections_multiplier * num_tiles, grid_edge_count)
    connection_from = cp.full(max_connections, -1, dtype=cp.int32)
    connection_to = cp.full(max_connections, -1, dtype=cp.int32)
    connection_dx = cp.zeros(max_connections, dtype=cp.float32)
    connection_dy = cp.zeros(max_connections, dtype=cp.float32)
    connection_quality = cp.zeros(max_connections, dtype=cp.float32)

    conn_idx = 0

    # Debug: Track quality filtering
    total_correlations = 0
    passed_threshold = 0
    all_qualities = []

    # Overlap strip sizes are the same for every tile pair; compute them once.
    overlap_w = int(W * overlap_ratio)
    overlap_h = int(H * overlap_ratio)

    # Debug: Print expected displacements and coordinate validation
    print(f"🔥 EXPECTED DISPLACEMENTS: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f"🔥 OVERLAP RATIO: {overlap_ratio}, H={H}, W={W}")
    print("🔥 COORDINATE VALIDATION:")
    print(f" Expected tile spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
    print(f" Overlap regions: H*ratio={H*overlap_ratio:.1f}, W*ratio={W*overlap_ratio:.1f}")
    print(f" Actual overlap: H={H*overlap_ratio:.1f}, W={W*overlap_ratio:.1f} pixels")

    # Debug: Check if images are black.
    # Only the tiles that are actually printed are reduced; slicing the range
    # follows the same rules as slicing a list of all tile indices.
    print(f"🔥 TILE STATS: First {debug_connection_limit} tiles - min/max/mean:")
    for i in range(num_tiles)[:debug_connection_limit]:
        stat_row, stat_col = divmod(i, num_cols)
        tile = tile_grid[stat_row, stat_col]
        tmin = float(cp.min(tile))
        tmax = float(cp.max(tile))
        tmean = float(cp.mean(tile))
        print(f" Tile {i}: [{tmin:.1f}, {tmax:.1f}], mean={tmean:.1f}")

    # Build connections (GPU operations)
    for r in range(num_rows):
        for c in range(num_cols):
            tile_idx = r * num_cols + c
            current_tile = tile_grid[r, c]

            # Horizontal connection
            if c < num_cols - 1:
                right_idx = r * num_cols + (c + 1)
                right_tile = tile_grid[r, c + 1]

                left_region = current_tile[:, -overlap_w:]  # Right edge of left tile
                right_region = right_tile[:, :overlap_w]  # Left edge of right tile

                # Debug: Check overlap region extraction
                if conn_idx < debug_connection_limit:
                    print(f"🔥 HORIZONTAL OVERLAP {conn_idx}: tiles {tile_idx}->{right_idx}")
                    print(f" overlap_w={overlap_w}, W={W}")
                    print(" Processing overlap regions (shapes not shown to avoid GPU sync)")

                if use_nist_robustness:
                    dy, dx, quality = phase_correlation_nist_gpu(
                        left_region, right_region,
                        direction='horizontal',
                        n_peaks=n_peaks,
                        use_nist_normalization=use_nist_normalization
                    )
                else:
                    dy, dx = phase_correlation_gpu_only(
                        left_region, right_region,  # Standardized: left_region first
                        subpixel=subpixel,
                        subpixel_radius=subpixel_radius,
                        regularization_eps_multiplier=regularization_eps_multiplier
                    )
                    # Compute quality after applying the shift
                    quality = compute_correlation_quality_gpu_aligned(left_region, right_region, dx, dy)

                # Debug: Track all quality values
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile coordinates.
                    # Only the width is read for a horizontal pair, so it is
                    # passed in both overlap slots.
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, overlap_w, overlap_w, H, W, 'horizontal'
                    )

                    # Log coordinate transformation for debugging
                    if conn_idx < debug_connection_limit:  # Only log first few for brevity
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'horizontal', (tile_idx, right_idx)
                        )

                    # Validate displacement magnitude
                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, float(expected_dx), 0.0, 'horizontal',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = right_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        # Debug: Print first few connections
                        if conn_idx < debug_connection_limit:
                            print(f"🔥 HORIZONTAL CONNECTION {conn_idx}: {tile_idx}->{right_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    else:
                        # Debug: Log rejected connections
                        if conn_idx < debug_connection_limit:
                            print(f"🔥 REJECTED HORIZONTAL {tile_idx}->{right_idx}: displacement invalid")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                            # Show validation details
                            dx_error = abs(tile_dx - expected_dx)
                            max_allowed_error = displacement_tolerance_factor * expected_dx * displacement_tolerance_percent
                            max_allowed_dy = displacement_tolerance_factor * expected_dx * displacement_tolerance_percent
                            print(f" dx_error={dx_error:.3f} vs max_allowed={max_allowed_error:.3f}")
                            print(f" abs(dy)={abs(tile_dy):.3f} vs max_allowed_dy={max_allowed_dy:.3f}")

            # Vertical connection
            if r < num_rows - 1:
                bottom_idx = (r + 1) * num_cols + c
                bottom_tile = tile_grid[r + 1, c]

                top_region = current_tile[-overlap_h:, :]  # Bottom edge of top tile
                bottom_region = bottom_tile[:overlap_h, :]  # Top edge of bottom tile

                if use_nist_robustness:
                    dy, dx, quality = phase_correlation_nist_gpu(
                        top_region, bottom_region,
                        direction='vertical',
                        n_peaks=n_peaks,
                        use_nist_normalization=use_nist_normalization
                    )
                else:
                    dy, dx = phase_correlation_gpu_only(
                        top_region, bottom_region,  # Standardized: top_region first
                        subpixel=subpixel,
                        subpixel_radius=subpixel_radius,
                        regularization_eps_multiplier=regularization_eps_multiplier
                    )
                    # Compute quality after applying the shift
                    quality = compute_correlation_quality_gpu_aligned(top_region, bottom_region, dx, dy)

                # Debug: Track all quality values
                total_correlations += 1
                all_qualities.append(quality)

                if quality >= quality_threshold:
                    # Convert overlap-region coordinates to tile coordinates.
                    # Only the height is read for a vertical pair, so it is
                    # passed in both overlap slots.
                    tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
                        dy, dx, overlap_h, overlap_h, H, W, 'vertical'
                    )

                    # Log coordinate transformation for debugging
                    if conn_idx < debug_vertical_limit:  # Only log first few for brevity
                        log_coordinate_transformation(
                            dy, dx, tile_dy, tile_dx, 'vertical', (tile_idx, bottom_idx)
                        )

                    # Validate displacement magnitude
                    displacement_valid = _validate_displacement_magnitude(
                        tile_dx, tile_dy, 0.0, float(expected_dy), 'vertical',
                        displacement_tolerance_factor, displacement_tolerance_percent
                    )

                    if displacement_valid:
                        passed_threshold += 1
                        connection_from[conn_idx] = tile_idx
                        connection_to[conn_idx] = bottom_idx
                        connection_dx[conn_idx] = tile_dx
                        connection_dy[conn_idx] = tile_dy
                        connection_quality[conn_idx] = quality

                        # Debug: Print first few connections
                        if conn_idx < debug_vertical_limit:  # Show a few more since we want to see vertical connections too
                            print(f"🔥 VERTICAL CONNECTION {conn_idx}: {tile_idx}->{bottom_idx}")
                            print(f" overlap coords: dx={float(dx):.3f}, dy={float(dy):.3f}")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" quality={float(quality):.6f}, displacement_valid={displacement_valid}")

                        conn_idx += 1
                    else:
                        # Debug: Log rejected connections
                        if conn_idx < debug_vertical_limit:
                            print(f"🔥 REJECTED VERTICAL {tile_idx}->{bottom_idx}: displacement invalid")
                            print(f" tile coords: dx={float(tile_dx):.3f}, dy={float(tile_dy):.3f}")
                            print(f" expected: dx={float(expected_dx):.3f}, dy={float(expected_dy):.3f}")
                            # Show validation details
                            dy_error = abs(tile_dy - expected_dy)
                            max_allowed_error = displacement_tolerance_factor * expected_dy * displacement_tolerance_percent
                            max_allowed_dx = displacement_tolerance_factor * expected_dy * displacement_tolerance_percent
                            print(f" dy_error={dy_error:.3f} vs max_allowed={max_allowed_error:.3f}")
                            print(f" abs(dx)={abs(tile_dx):.3f} vs max_allowed_dx={max_allowed_dx:.3f}")

    # Compute adaptive quality threshold if we have quality data
    if all_qualities:
        adaptive_threshold = compute_adaptive_quality_threshold(
            all_qualities, adaptive_base_threshold, adaptive_percentile_threshold
        )
        print(f"🔥 ADAPTIVE THRESHOLD: original={quality_threshold:.6f}, adaptive={adaptive_threshold:.6f}")

        # Re-filter connections with adaptive threshold if it's lower
        if adaptive_threshold < quality_threshold:
            print("🔥 RE-FILTERING with adaptive threshold...")
            # Note: In a full implementation, we'd re-process with the adaptive threshold
            # For now, we'll use the original threshold but log the adaptive one

    # Debug: Print quality filtering summary
    print(f"🔥 QUALITY FILTERING: {passed_threshold}/{total_correlations} connections passed threshold {quality_threshold}")
    if all_qualities:
        # Build the array once and reduce it three times.
        quality_array = cp.array(all_qualities)
        min_q = float(cp.min(quality_array))
        max_q = float(cp.max(quality_array))
        mean_q = float(cp.mean(quality_array))
        print(f"🔥 QUALITY RANGE: min={min_q:.6f}, max={max_q:.6f}, mean={mean_q:.6f}")

    # Nothing was accepted: leave the incoming positions untouched.
    if conn_idx == 0:
        return positions

    # Trim arrays to actual size (GPU)
    connection_from = connection_from[:conn_idx]
    connection_to = connection_to[:conn_idx]
    connection_dx = connection_dx[:conn_idx]
    connection_dy = connection_dy[:conn_idx]
    connection_quality = connection_quality[:conn_idx]

    # Validate translation consistency (Plan 03).
    # Each column is copied to the host in one transfer instead of one scalar
    # read per edge; the resulting (dy, dx, quality) tuples hold Python floats.
    translations = list(zip(
        connection_dy.get().tolist(),
        connection_dx.get().tolist(),
        connection_quality.get().tolist(),
    ))

    # Validate against expected spacing
    expected_spacing = (float(expected_dx), float(expected_dy))
    valid_flags = validate_translation_consistency(
        translations, expected_spacing, translation_tolerance_factor, translation_min_quality
    )

    num_valid = sum(valid_flags)
    print(f"🔥 TRANSLATION VALIDATION: {num_valid}/{len(translations)} connections are consistent")

    if num_valid < len(translations) * consistency_threshold_percent:  # Less than threshold% valid
        print(f"🔥 WARNING: Low translation consistency ({num_valid}/{len(translations)})")
        print(f"🔥 Expected spacing: dx={expected_spacing[0]:.1f}, dy={expected_spacing[1]:.1f}")
        print("🔥 Consider adjusting overlap_ratio or quality thresholds")

    # Build MST using refactored GPU Borůvka's algorithm
    mst_edges = build_mst_gpu(
        connection_from, connection_to, connection_dx,
        connection_dy, connection_quality, num_tiles
    )

    # Rebuild positions using MST (GPU)
    new_positions = rebuild_positions_from_mst_gpu(
        positions, mst_edges, num_tiles, anchor_tile_index
    )

    return new_positions
446@special_inputs("grid_dimensions")
447@special_outputs("positions")
448@cupy_func
449def mist_compute_tile_positions(
450 image_stack: "cp.ndarray", # type: ignore
451 grid_dimensions: Tuple[int, int],
452 *,
453 # === Input Validation Parameters ===
454 method: str = "phase_correlation",
455 fft_backend: str = "cupy",
457 # === Core Algorithm Parameters ===
458 normalize: bool = True,
459 verbose: bool = False,
460 overlap_ratio: float = 0.1,
461 subpixel: bool = True,
462 refinement_iterations: int = 10,
463 global_optimization: bool = True,
464 anchor_tile_index: int = 0,
466 # === Refinement Tuning Parameters ===
467 refinement_damping: float = 0.5,
468 correlation_weight_horizontal: float = 1.0,
469 correlation_weight_vertical: float = 1.0,
471 # === Phase Correlation Parameters ===
472 subpixel_radius: int = 3,
473 regularization_eps_multiplier: float = 1000.0,
475 # === MST Global Optimization Parameters ===
476 mst_quality_threshold: float = 0.5, # NIST Algorithm 15: ncc >= 0.5 for MST edge inclusion
477 # NIST robustness parameters (Algorithms 2-5)
478 use_nist_robustness: bool = True, # Enable full NIST PCIAM implementation
479 n_peaks: int = 2, # NIST Algorithm 2: Test 2 peaks (experimentally determined)
480 use_nist_normalization: bool = True, # NIST Algorithm 3: fc/abs(fc) normalization
481 # Debugging and validation parameters
482 debug_connection_limit: int = 3,
483 debug_vertical_limit: int = 6,
484 displacement_tolerance_factor: float = 2.0,
485 displacement_tolerance_percent: float = 0.3,
486 consistency_threshold_percent: float = 0.5,
487 max_connections_multiplier: int = 2,
488 # Quality metric tuning parameters
489 adaptive_base_threshold: float = 0.3,
490 adaptive_percentile_threshold: float = 0.25,
491 translation_tolerance_factor: float = 0.2,
492 translation_min_quality: float = 0.3,
493 # Phase correlation tuning parameters
494 magnitude_threshold_multiplier: float = 1e-6,
495 peak_candidates_multiplier: int = 4,
496 min_peak_distance: int = 5,
497 **kwargs
498) -> Tuple["cp.ndarray", "cp.ndarray"]: # type: ignore
499 """
500 Full GPU MIST implementation with zero CPU operations.
502 Performs microscopy image stitching using phase correlation and iterative refinement.
503 The algorithm has three phases:
504 1. Initial positioning using sequential phase correlation
505 2. Iterative refinement with constraint optimization
506 3. Global optimization using minimum spanning tree (MST)
508 Args:
509 image_stack: 3D tensor (Z, Y, X) of tiles to stitch
510 grid_dimensions: (num_cols, num_rows) grid layout of tiles
512 === Input Validation Parameters ===
513 method: Correlation method - must be "phase_correlation"
514 fft_backend: FFT backend - must be "cupy" for GPU acceleration
516 === Core Algorithm Parameters (NIST Algorithms 1-3) ===
517 normalize: Normalize each tile to [0,1] range using (tile-min)/(max-min).
518 True = better correlation accuracy, handles varying illumination.
519 False = faster but poor results with uneven lighting.
520 Used in NIST Algorithm 3 (PCM) preprocessing.
521 verbose: Enable detailed logging of algorithm progress and timing
522 overlap_ratio: Expected overlap between adjacent tiles as fraction (0.0-1.0).
523 Defines correlation region size: overlap_w = int(W * overlap_ratio).
524 CRITICAL: Must match actual overlap in data or correlation fails.
525 Higher (0.2-0.4) = more robust but slower.
526 Lower (0.05-0.08) = faster but less accurate.
527 Used in NIST Algorithm 10 (Compute Image Overlap).
528 subpixel: Enable subpixel-accurate phase correlation for higher precision.
529 True = center-of-mass interpolation around correlation peak.
530 False = pixel-only accuracy (faster, less precise).
531 Enhances NIST Algorithm 3 (PCM) with subpixel refinement.
532 refinement_iterations: Number of iterative position refinement passes (0-50).
533 Each iteration applies weighted position corrections.
534 Higher = better convergence but much slower.
535 0 = skip refinement (fastest, least accurate).
536 Implements NIST Algorithm 21 (Bounded NCC Hill Climb).
537 global_optimization: Enable MST-based global optimization phase.
538 Uses minimum spanning tree to optimize tile positions globally.
539 Significantly improves accuracy for large grids.
540 Implements NIST Phase 3 (Image Composition).
541 anchor_tile_index: Index of reference tile that remains fixed at origin (usually 0).
542 All other positions calculated relative to this tile.
543 Used in NIST MST position reconstruction.
545 === Refinement Tuning Parameters ===
546 refinement_damping: Controls how aggressively positions are updated (0.0-1.0).
547 Formula: new_pos = (1-damping)*old_pos + damping*correction.
548 Higher (0.7-0.9) = faster convergence but may overshoot.
549 Lower (0.1-0.3) = more stable but slower convergence.
550 1.0 = full correction (may be unstable), 0.0 = no updates.
551 correlation_weight_horizontal: Weight for horizontal tile constraints (>0).
552 Higher values prioritize horizontal alignment accuracy.
553 Typical range: 0.5-2.0.
554 correlation_weight_vertical: Weight for vertical tile constraints (>0).
555 Higher values prioritize vertical alignment accuracy.
556 Typical range: 0.5-2.0.
558 === Phase Correlation Parameters (NIST Algorithm 3) ===
559 subpixel_radius: Radius around correlation peak for center-of-mass calculation.
560 Extracts (2*radius+1)² region around peak for interpolation.
561 Higher (5-10) = more accurate subpixel positioning but slower.
562 Lower (1-2) = faster but less precise, may cause drift.
563 0 = pixel-only accuracy (fastest, least precise).
564 Enhances NIST Algorithm 3 (PCM) with subpixel precision.
565 regularization_eps_multiplier: Prevents division by zero in phase correlation.
566 Formula: eps = machine_epsilon * multiplier.
567 Higher (10000+) = more stable with noisy images.
568 Lower (100-500) = higher precision but may fail.
569 Too low (<10) = risk of numerical instability.
570 Used in NIST Algorithm 3 cross-power normalization.
572 === MST Global Optimization Parameters (NIST Algorithms 8-21) ===
573 mst_quality_threshold: Minimum correlation quality for MST edge inclusion (0.0-1.0).
574 NIST Algorithm 15: ncc >= 0.5 for valid translations.
575 Formula: if correlation_peak < threshold: reject_connection.
576 NIST default: 0.5 (stricter quality control).
577 Higher = fewer connections, lower = includes weak correlations.
578 Too high = MST may fail, too low = includes noise.
579 use_nist_robustness: Enable NIST robust phase correlation (Algorithm 2).
580 True = multi-peak PCIAM with interpretation testing.
581 False = simplified single-peak method (faster).
582 n_peaks: Number of correlation peaks to analyze (NIST Algorithm 4).
583 NIST default: n=2 (manually selected based on experimental testing).
584 Higher = more robust peak selection but slower processing.
585 use_nist_normalization: Apply NIST normalization method (Algorithm 3).
586 True = fc/abs(fc) normalization (NIST standard).
587 False = OpenHCS regularization method.
589 displacement_tolerance_factor: Multiplier for expected displacement tolerance.
590 NIST Algorithm 14: Stage model displacement validation.
591 Formula: max_error = factor * expected_displacement * percent.
592 Higher (3.0-5.0) = more permissive validation.
593 Lower (1.0-1.5) = stricter validation.
594 displacement_tolerance_percent: Percentage tolerance for displacement (0.0-1.0).
595 NIST Algorithm 14: Displacement validation threshold.
596 Formula: valid if |actual - expected| <= expected * percent.
597 0.3 = ±30% deviation allowed from expected displacement.
598 Higher = accepts larger deviations, lower = stricter.
600 debug_connection_limit: Max horizontal connections to log for debugging (0-10)
601 debug_vertical_limit: Max vertical connections to log for debugging (0-10)
602 consistency_threshold_percent: Translation consistency validation threshold (0.0-1.0).
603 NIST Algorithm 17: Filter by repeatability.
604 Formula: valid if |translation - median| <= median * threshold.
605 0.5 = ±50% deviation from median allowed.
606 Higher = more permissive, lower = stricter consistency.
607 max_connections_multiplier: Maximum connections per tile in MST construction.
608 Formula: max_connections = base_connections * multiplier.
609 Prevents over-connected graphs that slow MST algorithms.
610 2 = allow 2x normal connections, 1 = strict minimum.
611 adaptive_base_threshold: Minimum quality threshold for adaptive quality metrics.
612 NIST-inspired adaptive thresholding for challenging datasets.
613 Formula: final_threshold = max(base_threshold, percentile_threshold).
614 0.3 = minimum 30% correlation required regardless of distribution.
615 Prevents threshold from becoming too permissive.
616 adaptive_percentile_threshold: Percentile-based quality threshold (0.0-1.0).
617 NIST Algorithm 9: Stage model validation approach.
618 Formula: threshold = percentile(all_qualities, percentile * 100).
619 0.25 = use 25th percentile of quality distribution.
620 Lower = more permissive, higher = stricter selection.
621 translation_tolerance_factor: Tolerance multiplier for translation validation.
622 NIST Algorithm 14: Stage model displacement validation.
623 Formula: max_error = expected_displacement * factor * percent.
624 0.2 = allow 20% deviation from expected displacement.
625 Higher = more permissive validation.
626 translation_min_quality: Minimum correlation quality for translation acceptance.
627 NIST Algorithm 15: Quality-based filtering threshold.
628 Formula: accept if ncc >= min_quality.
629 0.3 = require 30% normalized cross-correlation minimum.
630 Higher = stricter quality, lower = more permissive.
631 magnitude_threshold_multiplier: FFT magnitude threshold for numerical stability.
632 NIST Algorithm 3: Cross-power spectrum normalization.
633 Formula: threshold = mean(magnitude) * multiplier.
634 1e-6 = very small threshold for numerical stability.
635 Higher = more aggressive filtering of low-magnitude frequencies.
636 peak_candidates_multiplier: Candidate peak search multiplier for robustness.
637 NIST Algorithm 4: Multi-peak max search optimization.
638 Formula: n_candidates = n_peaks * multiplier.
639 4 = search 4x more candidates than needed for robust selection.
640 Higher = more thorough search but slower processing.
641 min_peak_distance: Minimum pixel distance between correlation peaks.
642 NIST Algorithm 4: Prevents duplicate peak detection.
643 Formula: reject if distance(peak1, peak2) < min_distance.
644 5 = peaks must be ≥5 pixels apart to be considered distinct.
645 Higher = fewer but more distinct peaks, lower = more peaks.
647 === NIST Mathematical Formulas ===
649 Algorithm 3 (PCM): Peak Correlation Matrix
650 F1 ← fft2D(I1), F2 ← fft2D(I2)
651 FC ← F1 .* conj(F2)
652 PCM ← ifft2D(FC ./ abs(FC))
654 Algorithm 6 (NCC): Normalized Cross-Correlation
655 I1 ← I1 - mean(I1), I2 ← I2 - mean(I2)
656 ncc = (I1 · I2) / (|I1| * |I2|)
658 Algorithm 10 (Overlap): Image Overlap Computation
659 overlap_percent = 100 - mu (where mu is mean translation)
660 valid_range = [overlap ± overlap_uncertainty_percent]
662 Algorithm 16 (Outliers): Statistical Outlier Detection
663 q1 = 25th percentile, q3 = 75th percentile
664 IQR = q3 - q1
665 outlier if: value < (q1 - 1.5*IQR) OR value > (q3 + 1.5*IQR)
667 Algorithm 21 (Hill Climb): Bounded Translation Refinement
668 search_bounds = [current ± repeatability]
669 ncc_surface[i,j] = ncc(extract_overlap(I1, j, i), extract_overlap(I2, -j, -i))
670 climb to local maximum within bounds
672 === NIST Performance Guidance ===
674 Quality Threshold Tuning:
675 - Start with NIST default: 0.5 (strict quality control)
676 - Lower to 0.3-0.4 for noisy biological samples
677 - Lower to 0.1-0.2 for very challenging datasets
678 - Monitor MST edge count: need ≥(num_tiles-1) edges minimum
680 Peak Count Optimization:
681 - NIST default: n=2 peaks (experimentally optimal)
682 - Increase to 3-5 for highly repetitive patterns
683 - Keep at 2 for most microscopy applications
685 Overlap Ratio Guidelines:
686 - Must match actual image overlap precisely
687 - Typical microscopy: 0.1-0.2 (10-20% overlap)
688 - Higher overlap = more robust but slower processing
689 - Lower overlap = faster but less reliable alignment
691 Subpixel Refinement:
692 - Enable for publication-quality results
693 - Radius 3-5 optimal for most applications
694 - Disable for speed-critical applications
696 Expected Performance:
697 - With NIST defaults: High accuracy, moderate speed
698 - Quality threshold 0.5: Strict filtering, fewer edges
699 - Multi-peak robustness: 2-3x slower but more reliable
700 - Global optimization: Essential for large grids (>3x3)
702 Returns:
703 Tuple of (image_stack, positions) where:
704 - image_stack: Original input tiles (potentially normalized)
705 - positions: (Z, 2) array of tile positions in (x, y) format
706 Positions are centered around origin
708 Raises:
709 ValueError: If input validation fails (wrong method, backend, or dimensions)
710 TypeError: If image_stack is not a CuPy array
711 """
712 _validate_cupy_array(image_stack, "image_stack")
714 if image_stack.ndim != 3:
715 raise ValueError(f"Input must be a 3D tensor, got {image_stack.ndim}D")
717 if fft_backend != "cupy":
718 raise ValueError(f"FFT backend must be 'cupy', got '{fft_backend}'")
720 if method != "phase_correlation":
721 raise ValueError(f"Only 'phase_correlation' method is supported, got '{method}'")
723 num_cols, num_rows = grid_dimensions
724 Z, H, W = image_stack.shape
726 # VERY FIRST THING - Debug output to confirm function is called
727 print("🔥🔥🔥 MIST FUNCTION ENTRY POINT - FUNCTION IS DEFINITELY BEING CALLED! 🔥🔥🔥")
728 print(f"🔥 Image stack shape: {image_stack.shape}")
729 print(f"🔥 Grid dimensions: {grid_dimensions}")
731 # Debug: Log the actual overlap_ratio parameter being used
732 print(f"🔥 MIST FUNCTION CALLED WITH overlap_ratio={overlap_ratio}")
733 print(f"🔥 Expected: 0.1 (10% overlap), Actual: {overlap_ratio}")
735 if Z != num_rows * num_cols:
736 raise ValueError(
737 f"Number of tiles ({Z}) does not match grid size ({num_rows}x{num_cols}={num_rows*num_cols})"
738 )
740 # Normalize on GPU
741 tiles = image_stack.astype(cp.float32)
742 if normalize:
743 for z in range(Z):
744 tile = tiles[z]
745 tile_min = cp.min(tile)
746 tile_max = cp.max(tile)
747 tile_range = tile_max - tile_min
748 # Use GPU conditional to avoid division by zero
749 tiles[z] = cp.where(tile_range > 0, (tile - tile_min) / tile_range, tile)
751 # Reshape to grid (GPU operation)
752 tile_grid = tiles.reshape(num_rows, num_cols, H, W)
754 # Calculate expected spacing (GPU)
755 expected_dy = cp.float32(H * (1.0 - overlap_ratio))
756 expected_dx = cp.float32(W * (1.0 - overlap_ratio))
758 # Initialize positions on GPU
759 positions = cp.zeros((Z, 2), dtype=cp.float32)
761 if verbose:
762 logger.info(f"GPU MIST: {num_rows}x{num_cols} grid, spacing: dx={float(expected_dx):.1f}, dy={float(expected_dy):.1f}")
764 # Phase 1: Initial positioning (all GPU)
765 for r in range(num_rows):
766 for c in range(num_cols):
767 tile_idx = r * num_cols + c
769 if tile_idx == anchor_tile_index:
770 positions[tile_idx] = cp.array([0.0, 0.0])
771 continue
773 current_tile = tile_grid[r, c]
775 # Position from left neighbor (GPU operations)
776 if c > 0:
777 left_idx = r * num_cols + (c - 1)
778 left_tile = tile_grid[r, c - 1]
780 # Extract overlap regions (GPU)
781 overlap_w = cp.int32(W * overlap_ratio)
782 left_region = left_tile[:, -overlap_w:]
783 current_region = current_tile[:, :overlap_w]
785 # GPU phase correlation
786 dy, dx = phase_correlation_gpu_only(
787 left_region, current_region,
788 subpixel=subpixel,
789 subpixel_radius=subpixel_radius,
790 regularization_eps_multiplier=regularization_eps_multiplier
791 )
793 # Convert overlap-region coordinates to tile-center coordinates
794 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
795 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal'
796 )
798 # Update position (GPU)
799 new_x = positions[left_idx, 0] + tile_dx
800 new_y = positions[left_idx, 1] + tile_dy
801 positions[tile_idx] = cp.array([new_x, new_y])
803 elif r > 0: # Position from top neighbor
804 top_idx = (r - 1) * num_cols + c
805 top_tile = tile_grid[r - 1, c]
807 # Extract overlap regions (GPU)
808 overlap_h = cp.int32(H * overlap_ratio)
809 top_region = top_tile[-overlap_h:, :]
810 current_region = current_tile[:overlap_h, :]
812 # GPU phase correlation
813 dy, dx = phase_correlation_gpu_only(
814 top_region, current_region,
815 subpixel=subpixel,
816 subpixel_radius=subpixel_radius,
817 regularization_eps_multiplier=regularization_eps_multiplier
818 )
820 # Convert overlap-region coordinates to tile-center coordinates
821 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
822 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical'
823 )
825 # Update position (GPU)
826 new_x = positions[top_idx, 0] + tile_dx
827 new_y = positions[top_idx, 1] + tile_dy
828 positions[tile_idx] = cp.array([new_x, new_y])
830 # Phase 2: Refinement iterations (all GPU)
831 for iteration in range(refinement_iterations):
832 if verbose:
833 logger.info(f"GPU refinement iteration {iteration + 1}/{refinement_iterations}")
835 position_corrections = cp.zeros_like(positions)
836 correction_weights = cp.zeros(Z, dtype=cp.float32)
838 # Horizontal constraints (GPU)
839 for r in range(num_rows):
840 for c in range(num_cols - 1):
841 left_idx = r * num_cols + c
842 right_idx = r * num_cols + (c + 1)
844 left_tile = tile_grid[r, c]
845 right_tile = tile_grid[r, c + 1]
847 overlap_w = cp.int32(W * overlap_ratio)
848 left_region = left_tile[:, -overlap_w:] # Right edge of left tile
849 right_region = right_tile[:, :overlap_w] # Left edge of right tile
851 dy, dx = phase_correlation_gpu_only(
852 left_region, right_region, # Standardized: left_region first
853 subpixel=subpixel,
854 subpixel_radius=subpixel_radius,
855 regularization_eps_multiplier=regularization_eps_multiplier
856 )
858 # Convert overlap-region coordinates to tile-center coordinates
859 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
860 dy, dx, int(overlap_w), int(overlap_w), H, W, 'horizontal'
861 )
863 # Expected position (GPU)
864 expected_right = positions[left_idx] + cp.array([tile_dx, tile_dy])
866 # Accumulate updates (GPU)
867 position_corrections[right_idx] += expected_right * correlation_weight_horizontal
868 correction_weights[right_idx] += correlation_weight_horizontal
870 # Vertical constraints (GPU)
871 for r in range(num_rows - 1):
872 for c in range(num_cols):
873 top_idx = r * num_cols + c
874 bottom_idx = (r + 1) * num_cols + c
876 top_tile = tile_grid[r, c]
877 bottom_tile = tile_grid[r + 1, c]
879 overlap_h = cp.int32(H * overlap_ratio)
880 top_region = top_tile[-overlap_h:, :] # Bottom edge of top tile
881 bottom_region = bottom_tile[:overlap_h, :] # Top edge of bottom tile
883 dy, dx = phase_correlation_gpu_only(
884 top_region, bottom_region, # Standardized: top_region first
885 subpixel=subpixel,
886 subpixel_radius=subpixel_radius,
887 regularization_eps_multiplier=regularization_eps_multiplier
888 )
890 # Convert overlap-region coordinates to tile-center coordinates
891 tile_dy, tile_dx = _convert_overlap_to_tile_coordinates(
892 dy, dx, int(overlap_h), int(overlap_h), H, W, 'vertical'
893 )
895 # Expected position (GPU)
896 expected_bottom = positions[top_idx] + cp.array([tile_dx, tile_dy])
898 # Accumulate updates (GPU)
899 position_corrections[bottom_idx] += expected_bottom * correlation_weight_vertical
900 correction_weights[bottom_idx] += correlation_weight_vertical
902 # Apply corrections with damping (all GPU)
903 for tile_idx in range(Z):
904 if correction_weights[tile_idx] > 0 and tile_idx != anchor_tile_index:
905 averaged_correction = position_corrections[tile_idx] / correction_weights[tile_idx]
906 positions[tile_idx] = ((1 - refinement_damping) * positions[tile_idx] +
907 refinement_damping * averaged_correction)
909 # Phase 3: Global optimization MST (GPU operations)
910 print(f"🔥 PHASE 3: global_optimization={global_optimization}")
911 if global_optimization:
912 print(f"🔥 STARTING MST GLOBAL OPTIMIZATION")
913 positions = _global_optimization_gpu_only(
914 positions, tile_grid, num_rows, num_cols,
915 expected_dx, expected_dy, overlap_ratio, subpixel,
917 quality_threshold=mst_quality_threshold,
918 subpixel_radius=subpixel_radius,
919 regularization_eps_multiplier=regularization_eps_multiplier,
920 anchor_tile_index=anchor_tile_index,
921 debug_connection_limit=debug_connection_limit,
922 debug_vertical_limit=debug_vertical_limit,
923 displacement_tolerance_factor=displacement_tolerance_factor,
924 displacement_tolerance_percent=displacement_tolerance_percent,
925 consistency_threshold_percent=consistency_threshold_percent,
926 max_connections_multiplier=max_connections_multiplier,
927 adaptive_base_threshold=adaptive_base_threshold,
928 adaptive_percentile_threshold=adaptive_percentile_threshold,
929 translation_tolerance_factor=translation_tolerance_factor,
930 translation_min_quality=translation_min_quality,
931 magnitude_threshold_multiplier=magnitude_threshold_multiplier,
932 peak_candidates_multiplier=peak_candidates_multiplier,
933 min_peak_distance=min_peak_distance,
934 use_nist_robustness=use_nist_robustness,
935 n_peaks=n_peaks,
936 use_nist_normalization=use_nist_normalization
937 )
939 # Center positions (GPU)
940 mean_pos = cp.mean(positions, axis=0)
941 positions = positions - mean_pos
943 print(f"🔥 MIST COMPLETE: Returning {positions.shape} positions")
944 return tiles, positions