Coverage for openhcs/processing/backends/pos_gen/mist/quality_metrics.py: 8.8%
153 statements
1"""
2Quality Metrics for MIST Algorithm
4Functions for computing correlation quality and adaptive thresholds.
5"""
6from __future__ import annotations
8from typing import TYPE_CHECKING, List, Tuple, Dict
9import logging
11from openhcs.core.utils import optional_import
13# For type checking only
14if TYPE_CHECKING: 14 ↛ 15line 14 didn't jump to line 15 because the condition on line 14 was never true
15 import cupy as cp
17# Import CuPy as an optional dependency
18cp = optional_import("cupy")
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Validate that the input is a CuPy array."""
    if not isinstance(array, cp.ndarray):
        raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def compute_correlation_quality_gpu(region1: "cp.ndarray", region2: "cp.ndarray") -> float:  # type: ignore
    """GPU-only normalized cross-correlation quality metric."""
    # Validate input regions have same shape
    if region1.shape != region2.shape:
        return 0.0

    # Check for empty or single-pixel regions
    if region1.size <= 1:
        return 0.0

    r1_flat = region1.flatten()
    r2_flat = region2.flatten()

    # Normalize (GPU)
    r1_mean = cp.mean(r1_flat)
    r2_mean = cp.mean(r2_flat)
    r1_norm = r1_flat - r1_mean
    r2_norm = r2_flat - r2_mean

    # Correlation (GPU)
    numerator = cp.sum(r1_norm * r2_norm)
    denom1 = cp.sqrt(cp.sum(r1_norm ** 2))
    denom2 = cp.sqrt(cp.sum(r2_norm ** 2))

    # Avoid division by zero with more robust threshold (GPU)
    eps = cp.finfo(cp.float32).eps * 1000.0
    correlation = cp.where((denom1 > eps) & (denom2 > eps),
                           cp.abs(numerator / (denom1 * denom2)),
                           cp.float32(0.0))

    return float(correlation)
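
# Illustrative usage sketch (assumes CuPy is installed and a GPU is available;
# the arrays below are made up): comparing an overlap region against a lightly
# noised copy of itself should give a quality close to 1.0.
#
#   >>> import cupy as cp
#   >>> region = cp.random.rand(64, 64).astype(cp.float32)
#   >>> noisy = region + 0.01 * cp.random.rand(64, 64).astype(cp.float32)
#   >>> quality = compute_correlation_quality_gpu(region, noisy)
#   >>> 0.0 <= quality <= 1.0
#   True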
def compute_correlation_quality_gpu_aligned(region1: "cp.ndarray", region2: "cp.ndarray", dx: float, dy: float) -> float:  # type: ignore
    """
    GPU-only normalized cross-correlation quality metric after applying computed shift.

    This measures how well the regions align after applying the phase correlation shift.
    """
    # Convert shifts to integer pixels for alignment
    shift_x = int(round(dx))
    shift_y = int(round(dy))

    # Get region dimensions
    h1, w1 = region1.shape
    h2, w2 = region2.shape

    # Calculate overlap region after applying shift
    # For horizontal alignment: region1 is left, region2 is right
    # For vertical alignment: region1 is top, region2 is bottom

    # Determine overlap bounds considering the shift
    if abs(shift_x) >= min(w1, w2) or abs(shift_y) >= min(h1, h2):
        # No overlap after shift
        return 0.0

    # Calculate actual overlap region. Convention: positive shift means region2
    # sits to the right of / below region1; negative shift means the opposite.
    if shift_x >= 0:
        # region2 shifted right: region1's overlap starts at shift_x, region2's at 0
        x1_start = shift_x
        x1_end = min(w1, w2 + shift_x)
        x2_start = 0
        x2_end = min(w2, w1 - shift_x)
    else:
        # region2 shifted left: region1's overlap starts at 0, region2's at -shift_x
        x1_start = 0
        x1_end = min(w1, w2 + shift_x)
        x2_start = -shift_x
        x2_end = min(w2, w1 - shift_x)

    if shift_y >= 0:
        # region2 shifted down: region1's overlap starts at shift_y, region2's at 0
        y1_start = shift_y
        y1_end = min(h1, h2 + shift_y)
        y2_start = 0
        y2_end = min(h2, h1 - shift_y)
    else:
        # region2 shifted up: region1's overlap starts at 0, region2's at -shift_y
        y1_start = 0
        y1_end = min(h1, h2 + shift_y)
        y2_start = -shift_y
        y2_end = min(h2, h1 - shift_y)

    # Extract aligned overlap regions
    if x1_end <= x1_start or y1_end <= y1_start or x2_end <= x2_start or y2_end <= y2_start:
        return 0.0

    aligned_region1 = region1[y1_start:y1_end, x1_start:x1_end]
    aligned_region2 = region2[y2_start:y2_end, x2_start:x2_end]

    # Ensure regions have the same size
    min_h = min(aligned_region1.shape[0], aligned_region2.shape[0])
    min_w = min(aligned_region1.shape[1], aligned_region2.shape[1])

    if min_h <= 0 or min_w <= 0:
        return 0.0

    aligned_region1 = aligned_region1[:min_h, :min_w]
    aligned_region2 = aligned_region2[:min_h, :min_w]

    # Compute normalized cross-correlation on aligned regions
    return compute_correlation_quality_gpu(aligned_region1, aligned_region2)
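
# Illustrative usage sketch (assumes CuPy is available and that positive dx/dy
# follow the convention in the comments above, i.e. region2 sits to the right
# of / below region1): two crops of the same image should score near 1.0 once
# the known 48-pixel horizontal offset is applied.
#
#   >>> import cupy as cp
#   >>> img = cp.random.rand(128, 128).astype(cp.float32)
#   >>> left = img[:, :80]    # region1
#   >>> right = img[:, 48:]   # region2, offset 48 px to the right
#   >>> q = compute_correlation_quality_gpu_aligned(left, right, dx=48.0, dy=0.0)
#   >>> q > 0.99
#   True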
def compute_adaptive_threshold(correlations: "cp.ndarray") -> float:  # type: ignore
    """
    Compute an adaptive correlation threshold following ASHLAR's permutation-test
    approach: sample the supplied correlation values and use their 99th percentile.

    Args:
        correlations: Array of correlation values (CuPy array)

    Returns:
        Adaptive threshold value as float
    """
    _validate_cupy_array(correlations, "correlations")

    if len(correlations) == 0:
        return 0.0

    # For small arrays, use all values
    if len(correlations) <= 100:
        sample_correlations = correlations
    else:
        # Sample a random subset for efficiency
        n_samples = min(1000, len(correlations))
        indices = cp.random.choice(len(correlations), size=n_samples, replace=False)
        sample_correlations = correlations[indices]

    # Use the 99th percentile as the adaptive threshold
    threshold = cp.percentile(sample_correlations, 99.0)

    return float(threshold)
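
# Illustrative usage sketch (made-up values, CuPy assumed): the threshold is the
# 99th percentile of (a sample of) the supplied correlations, so a distribution
# with a few strong values yields a threshold near those values.
#
#   >>> import cupy as cp
#   >>> correlations = cp.concatenate([cp.full(95, 0.1), cp.full(5, 0.9)])
#   >>> compute_adaptive_threshold(correlations)   # -> 0.9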
def estimate_stage_parameters(
    displacements: "cp.ndarray",  # type: ignore
    expected_spacing: float
) -> tuple[float, float]:
    """
    Estimate repeatability and backlash from measured displacements.

    This implements MIST's key innovation for stage model estimation.

    Args:
        displacements: Array of measured displacements (CuPy array)
        expected_spacing: Expected spacing between tiles

    Returns:
        Tuple of (repeatability, backlash) as floats
    """
    _validate_cupy_array(displacements, "displacements")

    # Estimate repeatability as MAD (Median Absolute Deviation) of displacements
    median_displacement = cp.median(displacements)
    repeatability = cp.median(cp.abs(displacements - median_displacement))

    # Estimate systematic bias (backlash)
    backlash = cp.mean(displacements) - expected_spacing

    return float(repeatability), float(backlash)
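
# Illustrative usage sketch (made-up displacements, CuPy assumed): values spread
# around 198 px against an expected 200 px spacing give a MAD repeatability of
# 1 px and a backlash of about -2 px.
#
#   >>> import cupy as cp
#   >>> displacements = cp.array([196.0, 197.0, 198.0, 199.0, 200.0])
#   >>> estimate_stage_parameters(displacements, expected_spacing=200.0)
#   (1.0, -2.0)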
def compute_adaptive_quality_threshold(
    all_qualities: List[float],
    base_threshold: float = 0.3,
    percentile_threshold: float = 0.25
) -> float:
    """
    Compute adaptive quality threshold based on distribution of correlation values.

    Based on NIST stage model validation approach.
    """
    if not all_qualities:
        return base_threshold

    qualities_array = cp.array(all_qualities)

    # Remove invalid correlations
    valid_qualities = qualities_array[qualities_array >= 0]

    if len(valid_qualities) == 0:
        return base_threshold

    # Use percentile-based threshold
    percentile_value = float(cp.percentile(valid_qualities, percentile_threshold * 100))

    # Ensure minimum threshold
    adaptive_threshold = max(base_threshold, percentile_value)

    return adaptive_threshold
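
# Illustrative usage sketch (made-up quality lists): with strong correlations the
# 25th percentile (~0.84 here) replaces the base threshold; with weak ones the
# base threshold of 0.3 acts as the floor.
#
#   >>> compute_adaptive_quality_threshold([0.8, 0.85, 0.9, 0.95])   # -> ~0.84
#   >>> compute_adaptive_quality_threshold([0.05, 0.1, 0.15])        # -> 0.3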
def validate_translation_consistency(
    translations: List[Tuple[float, float, float]],
    expected_spacing: Tuple[float, float],
    tolerance_factor: float = 0.2,
    min_quality: float = 0.3
) -> List[bool]:
    """
    Validate translation consistency against expected grid spacing.

    Based on NIST stage model validation.
    """
    expected_dx, expected_dy = expected_spacing
    tolerance_dx = expected_dx * tolerance_factor
    tolerance_dy = expected_dy * tolerance_factor

    valid_flags = []

    for dy, dx, quality in translations:
        # Check if displacement is within expected range
        dx_valid = abs(dx - expected_dx) <= tolerance_dx
        dy_valid = abs(dy - expected_dy) <= tolerance_dy
        quality_valid = quality >= min_quality  # Minimum quality threshold

        is_valid = dx_valid and dy_valid and quality_valid
        valid_flags.append(is_valid)

    return valid_flags
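
# Illustrative usage sketch (made-up numbers): for horizontally adjacent tiles
# with an expected spacing of (dx=192, dy=0) and the default 20% tolerance, a
# near-grid, good-quality translation passes while a displaced or low-quality
# one is rejected.
#
#   >>> translations = [(0.0, 190.0, 0.8),   # (dy, dx, quality): near grid, good quality
#   ...                 (0.0, 120.0, 0.9),   # dx far from expected spacing
#   ...                 (0.0, 191.0, 0.1)]   # quality below min_quality
#   >>> validate_translation_consistency(translations, expected_spacing=(192.0, 0.0))
#   [True, False, False]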
def debug_phase_correlation_matrix(
    correlation_matrix: "cp.ndarray",
    peaks: List[Tuple[int, int, float]],
    save_path: Optional[str] = None
) -> None:
    """
    Create visualization of phase correlation matrix with detected peaks.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logging.warning("matplotlib not available, skipping correlation matrix visualization")
        return

    # Convert to CPU for visualization
    corr_cpu = cp.asnumpy(correlation_matrix)

    plt.figure(figsize=(10, 8))
    plt.imshow(corr_cpu, cmap='hot', interpolation='nearest')
    plt.colorbar(label='Correlation Value')

    # Mark detected peaks
    for i, (y, x, value) in enumerate(peaks):
        plt.plot(x, y, 'bo', markersize=8, label=f'Peak {i+1}: {value:.3f}')

    plt.legend()
    plt.title('Phase Correlation Matrix with Detected Peaks')
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    else:
        plt.show()

    plt.close()
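
# Illustrative usage sketch (hypothetical data and output path): render a small
# correlation surface with one fabricated peak and write it to disk. Requires
# matplotlib; without it the function logs a warning and returns.
#
#   >>> import cupy as cp
#   >>> corr = cp.random.rand(32, 32).astype(cp.float32)
#   >>> debug_phase_correlation_matrix(corr, peaks=[(16, 16, 0.87)],
#   ...                                save_path="corr_debug.png")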
def log_coordinate_transformation(
    original_dy: float, original_dx: float,
    tile_dy: float, tile_dx: float,
    direction: str,
    tile_index: Tuple[int, int]
) -> None:
    """
    Log coordinate transformation details for debugging.
    """
    logging.info(f"Coordinate Transform - Tile {tile_index}, Direction: {direction}")
    logging.info(f"  Original (overlap coords): dy={original_dy:.2f}, dx={original_dx:.2f}")
    logging.info(f"  Transformed (tile coords): dy={tile_dy:.2f}, dx={tile_dx:.2f}")
    logging.info(f"  Delta: dy_delta={tile_dy - original_dy:.2f}, dx_delta={tile_dx - original_dx:.2f}")
def benchmark_phase_correlation_methods(
    test_images: List[Tuple["cp.ndarray", "cp.ndarray"]],
    methods: Dict[str, Callable],
    num_iterations: int = 10
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark different phase correlation methods for performance and accuracy.
    """
    import time

    results = {}

    for method_name, method_func in methods.items():
        print(f"Benchmarking {method_name}...")

        times = []
        accuracies = []

        for iteration in range(num_iterations):
            start_time = time.time()

            total_error = 0.0
            num_pairs = 0

            for img1, img2 in test_images:
                try:
                    dy, dx = method_func(img1, img2)
                    # Compute error against known ground truth if available
                    # For now, just measure consistency
                    total_error += abs(dy) + abs(dx)  # Placeholder
                    num_pairs += 1
                except Exception as e:
                    print(f"Error in {method_name}: {e}")
                    continue

            elapsed_time = time.time() - start_time
            times.append(elapsed_time)

            if num_pairs > 0:
                avg_error = total_error / num_pairs
                accuracies.append(avg_error)

        results[method_name] = {
            'avg_time': sum(times) / len(times),
            'std_time': float(cp.std(cp.array(times))),
            'avg_accuracy': sum(accuracies) / len(accuracies) if accuracies else float('inf'),
            'std_accuracy': float(cp.std(cp.array(accuracies))) if len(accuracies) > 1 else 0.0
        }

    return results
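
# Illustrative usage sketch: `phase_correlate_fft` below is a hypothetical
# callable, not part of this module; any function mapping two CuPy images to a
# (dy, dx) pair can be benchmarked this way.
#
#   >>> import cupy as cp
#   >>> def phase_correlate_fft(img1, img2):
#   ...     corr = cp.abs(cp.fft.ifft2(cp.fft.fft2(img1) * cp.conj(cp.fft.fft2(img2))))
#   ...     peak = cp.unravel_index(cp.argmax(corr), corr.shape)
#   ...     return float(peak[0]), float(peak[1])
#   >>> pairs = [(cp.random.rand(64, 64), cp.random.rand(64, 64)) for _ in range(3)]
#   >>> results = benchmark_phase_correlation_methods(pairs, {"fft": phase_correlate_fft},
#   ...                                               num_iterations=2)
#   >>> sorted(results["fft"])
#   ['avg_accuracy', 'avg_time', 'std_accuracy', 'std_time']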