Coverage for openhcs/processing/backends/pos_gen/mist/quality_metrics.py: 8.8%

153 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Quality Metrics for MIST Algorithm 

3 

4Functions for computing correlation quality and adaptive thresholds. 

5""" 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING, List, Tuple, Dict 

9import logging 

10 

11from openhcs.core.utils import optional_import 

12 

13# For type checking only 

14if TYPE_CHECKING: 14 ↛ 15line 14 didn't jump to line 15 because the condition on line 14 was never true

15 import cupy as cp 

16 

17# Import CuPy as an optional dependency 

18cp = optional_import("cupy") 

19 

20 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless *array* is an instance of ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label used in the error message to identify the argument.

    Raises:
        TypeError: If *array* is not a ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

25 

26 

def compute_correlation_quality_gpu(region1: "cp.ndarray", region2: "cp.ndarray") -> float:  # type: ignore
    """GPU-only normalized cross-correlation quality metric.

    Returns the absolute Pearson correlation between the two regions.
    Perfectly correlated and perfectly anti-correlated regions therefore
    both score 1.0.

    Args:
        region1: CuPy array.
        region2: CuPy array; must have the same shape as ``region1``.

    Returns:
        A Python float, clamped to at most 1.0. It is 0.0 when:

        * the shapes differ;
        * the regions have at most one element;
        * either region is (near-)constant, i.e. its mean-centred L2 norm is
          not greater than ``1000 * float32 eps``.
    """
    # Mismatched shapes cannot be compared element-wise.
    if region1.shape != region2.shape:
        return 0.0

    # Zero or one pixel has no variance to correlate.
    if region1.size <= 1:
        return 0.0

    # Centre both regions (GPU).
    r1_norm = region1.ravel() - cp.mean(region1)
    r2_norm = region2.ravel() - cp.mean(region2)

    # Correlation terms (GPU).
    numerator = cp.sum(r1_norm * r2_norm)
    denom1 = cp.sqrt(cp.sum(r1_norm ** 2))
    denom2 = cp.sqrt(cp.sum(r2_norm ** 2))

    # Fetch the three scalars in one device->host transfer and decide on the
    # host. Dividing first and masking afterwards (as cp.where would) still
    # evaluates 0/0 for constant regions. That yields NaN/inf intermediates
    # and RuntimeWarnings under NumPy.
    num, d1, d2 = cp.stack([numerator, denom1, denom2]).tolist()

    eps = float(cp.finfo(cp.float32).eps) * 1000.0
    # Written as "not (... > eps)" so NaN norms also map to 0.0.
    if not (d1 > eps and d2 > eps):
        return 0.0

    # Rounding can push |r| a hair above 1; clamp so callers see a valid score.
    return min(abs(num) / (d1 * d2), 1.0)

58 

59 

def _overlap_bounds(len1: int, len2: int, shift: int) -> Tuple[int, int, int, int]:
    """Return ``(start1, end1, start2, end2)`` of the 1-D overlap for an integer shift.

    Pixel ``i`` of region2 is paired with pixel ``i + shift`` of region1. A
    positive shift places region2 right/below region1; a negative shift
    places it left/above. The same formula holds for both signs. The two
    spans always have equal length, and they are empty (``end <= start``)
    when the regions do not overlap.
    """
    start1 = max(0, shift)
    end1 = min(len1, len2 + shift)
    start2 = max(0, -shift)
    end2 = min(len2, len1 - shift)
    return start1, end1, start2, end2


def compute_correlation_quality_gpu_aligned(region1: "cp.ndarray", region2: "cp.ndarray", dx: float, dy: float) -> float:  # type: ignore
    """
    GPU-only normalized cross-correlation quality metric after applying computed shift.

    This measures how well the regions align after applying the phase
    correlation shift. The shift is rounded to whole pixels. Pixel (y, x) of
    region2 is then compared with pixel (y + round(dy), x + round(dx)) of
    region1, for both positive and negative shifts.

    Args:
        region1: 2-D CuPy array (the shape is unpacked as ``(h, w)``).
        region2: 2-D CuPy array; may differ in size from ``region1``.
        dx: Horizontal shift of region2 relative to region1, in pixels.
        dy: Vertical shift of region2 relative to region1, in pixels.

    Returns:
        ``compute_correlation_quality_gpu`` evaluated on the overlapping parts,
        or 0.0 when the shifted regions do not overlap.
    """
    # Convert shifts to integer pixels for alignment.
    shift_x = int(round(dx))
    shift_y = int(round(dy))

    h1, w1 = region1.shape
    h2, w2 = region2.shape

    # One sign-agnostic formula per axis. A negative shift is a genuine
    # left/up offset, not a mirror of the positive case.
    x1_start, x1_end, x2_start, x2_end = _overlap_bounds(w1, w2, shift_x)
    y1_start, y1_end, y2_start, y2_end = _overlap_bounds(h1, h2, shift_y)

    # The spans for region1 and region2 have equal length, so checking
    # region1's spans is enough.
    # This is the exact no-overlap test; it also covers unequal region sizes.
    if x1_end <= x1_start or y1_end <= y1_start:
        return 0.0

    aligned_region1 = region1[y1_start:y1_end, x1_start:x1_end]
    aligned_region2 = region2[y2_start:y2_end, x2_start:x2_end]

    # Compute normalized cross-correlation on the aligned overlap.
    return compute_correlation_quality_gpu(aligned_region1, aligned_region2)

129 

130 

def compute_adaptive_threshold(correlations: "cp.ndarray") -> float:  # type: ignore
    """
    Return the 99th percentile of ``correlations`` as an adaptive threshold.

    The 99th-percentile rule follows the ASHLAR approach. This function
    performs only the percentile step. Building a null distribution (e.g.
    from random non-adjacent tile pairs) is the caller's responsibility;
    whatever values are passed in are used as-is.

    Args:
        correlations: CuPy array of correlation values. Any shape is accepted;
            values are flattened.

    Returns:
        The 99th percentile as a Python float, or 0.0 if the array is empty.

    Raises:
        TypeError: If ``correlations`` is not a CuPy array.
    """
    _validate_cupy_array(correlations, "correlations")

    # Flatten so 0-d and N-d inputs are handled uniformly. len() would fail
    # on 0-d arrays and only count rows on N-d arrays.
    values = correlations.ravel()
    if values.size == 0:
        return 0.0

    # Use every value instead of a random subset drawn from the global RNG.
    # This makes the threshold exact and reproducible between runs.
    return float(cp.percentile(values, 99.0))

161 

162 

def estimate_stage_parameters(
    displacements: "cp.ndarray",  # type: ignore
    expected_spacing: float
) -> tuple[float, float]:
    """
    Estimate stage repeatability and backlash from measured displacements.

    This is the stage-model estimation step of the MIST algorithm.

    * Repeatability is the median absolute deviation (MAD) of the
      displacements about their median. It is unscaled (no 1.4826 factor).
    * Backlash is the systematic offset: the mean displacement minus
      ``expected_spacing``.

    Empty input is not special-cased; the result is whatever
    ``cp.median``/``cp.mean`` produce for an empty array.

    Args:
        displacements: CuPy array of measured displacements.
        expected_spacing: Nominal spacing between adjacent tiles.

    Returns:
        ``(repeatability, backlash)`` as Python floats.

    Raises:
        TypeError: If ``displacements`` is not a CuPy array.
    """
    _validate_cupy_array(displacements, "displacements")

    center = cp.median(displacements)
    deviations = cp.abs(displacements - center)
    mad = cp.median(deviations)

    offset = cp.mean(displacements) - expected_spacing

    return float(mad), float(offset)

189 

190 

def compute_adaptive_quality_threshold(
    all_qualities: List[float],
    base_threshold: float = 0.3,
    percentile_threshold: float = 0.25
) -> float:
    """
    Compute adaptive quality threshold based on distribution of correlation values.

    Based on NIST stage model validation approach. The result is
    ``max(base_threshold, P)``, where ``P`` is the ``percentile_threshold``
    quantile of the valid qualities. Valid means non-negative; NaN and
    negative entries are dropped.

    ``P`` uses linear interpolation between closest ranks, the default method
    of NumPy/CuPy ``percentile``. The input is a host-side sequence, so the
    computation stays on the host. This avoids a GPU round trip and a CuPy
    dependency for a handful of scalars.

    Args:
        all_qualities: Correlation quality values (any iterable of numbers).
        base_threshold: Lower bound for the result. It is also returned
            unchanged when there are no valid qualities.
        percentile_threshold: Quantile on a 0-1 scale (0.25 = 25th percentile).

    Returns:
        The adaptive threshold.

    Raises:
        ValueError: If there are valid qualities and ``percentile_threshold``
            is outside [0, 1].
    """
    # Remove invalid correlations. The comparison is False for NaN as well.
    valid = sorted(q for q in map(float, all_qualities) if q >= 0)
    if not valid:
        return base_threshold

    if not 0.0 <= percentile_threshold <= 1.0:
        raise ValueError(
            f"percentile_threshold must be in [0, 1], got {percentile_threshold}"
        )

    # Linear interpolation between the two ranks bracketing the quantile.
    position = percentile_threshold * (len(valid) - 1)
    lower = int(position)
    upper = min(lower + 1, len(valid) - 1)
    fraction = position - lower
    percentile_value = valid[lower] + (valid[upper] - valid[lower]) * fraction

    # Ensure minimum threshold.
    return max(base_threshold, percentile_value)

219 

220 

def validate_translation_consistency(
    translations: List[Tuple[float, float, float]],
    expected_spacing: Tuple[float, float],
    tolerance_factor: float = 0.2,
    min_quality: float = 0.3
) -> List[bool]:
    """
    Validate translation consistency against expected grid spacing.

    Based on NIST stage model validation. A translation is valid when all of
    these hold:

    * dx is within ``tolerance_factor * |expected_dx|`` of ``expected_dx``;
    * dy is within ``tolerance_factor * |expected_dy|`` of ``expected_dy``;
    * quality is at least ``min_quality``.

    Args:
        translations: ``(dy, dx, quality)`` triples.
        expected_spacing: ``(expected_dx, expected_dy)``. Note that the order
            is the reverse of the displacement order in ``translations``.
        tolerance_factor: Allowed deviation as a fraction of each expected
            component's magnitude.
        min_quality: Minimum acceptable quality.

    Returns:
        One boolean per translation, in input order.
    """
    expected_dx, expected_dy = expected_spacing
    # abs(): a negative expected component (e.g. a neighbour to the left or
    # above) would otherwise give a negative tolerance that no displacement
    # could ever satisfy.
    tolerance_dx = abs(expected_dx) * tolerance_factor
    tolerance_dy = abs(expected_dy) * tolerance_factor
    # NOTE(review): a zero expected component yields zero tolerance, so that
    # axis must match exactly -- confirm this is intended by callers.

    return [
        abs(dx - expected_dx) <= tolerance_dx
        and abs(dy - expected_dy) <= tolerance_dy
        and quality >= min_quality
        for dy, dx, quality in translations
    ]

248 

249 

def debug_phase_correlation_matrix(
    correlation_matrix: "cp.ndarray",
    peaks: List[Tuple[int, int, float]],
    save_path: str | None = None
) -> None:
    """
    Create visualization of phase correlation matrix with detected peaks.

    If matplotlib is not installed, a warning is logged and nothing is drawn.

    Args:
        correlation_matrix: CuPy array. It is copied to the host with
            ``cp.asnumpy`` and drawn as a heat map.
        peaks: ``(y, x, value)`` triples. Each is marked with a blue dot and
            a legend entry.
        save_path: If given, the figure is saved there at 150 dpi. Otherwise
            it is displayed with ``plt.show()``.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logging.warning("matplotlib not available, skipping correlation matrix visualization")
        return

    # Convert to CPU for visualization.
    corr_cpu = cp.asnumpy(correlation_matrix)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        image = ax.imshow(corr_cpu, cmap='hot', interpolation='nearest')
        fig.colorbar(image, ax=ax, label='Correlation Value')

        # Mark detected peaks.
        for i, (y, x, value) in enumerate(peaks):
            ax.plot(x, y, 'bo', markersize=8, label=f'Peak {i+1}: {value:.3f}')

        # With no labelled artists, legend() only emits a warning.
        if peaks:
            ax.legend()
        ax.set_title('Phase Correlation Matrix with Detected Peaks')
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        else:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

286 

287 

def log_coordinate_transformation(
    original_dy: float, original_dx: float,
    tile_dy: float, tile_dx: float,
    direction: str,
    tile_index: Tuple[int, int]
) -> None:
    """
    Log one overlap-to-tile coordinate transformation for debugging.

    Emits four INFO records through the root ``logging`` module, in order:

    1. a header with the tile index and direction;
    2. the original (overlap) coordinates;
    3. the transformed (tile) coordinates;
    4. their differences.

    Numbers are formatted to two decimals.
    """
    def _report_lines():
        # Lazy generator: each line is formatted only after the previous one
        # has been logged.
        yield f"Coordinate Transform - Tile {tile_index}, Direction: {direction}"
        yield f" Original (overlap coords): dy={original_dy:.2f}, dx={original_dx:.2f}"
        yield f" Transformed (tile coords): dy={tile_dy:.2f}, dx={tile_dx:.2f}"
        yield f" Delta: dy_delta={tile_dy-original_dy:.2f}, dx_delta={tile_dx-original_dx:.2f}"

    for report_line in _report_lines():
        logging.info(report_line)

301 

302 

def benchmark_phase_correlation_methods(
    test_images: List[Tuple["cp.ndarray", "cp.ndarray"]],
    methods: Dict[str, callable],
    num_iterations: int = 10
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark different phase correlation methods for performance and accuracy.

    Each method is called as ``method(img1, img2) -> (dy, dx)`` on every
    image pair, and the full pass is repeated ``num_iterations`` times.

    * A failing call is printed and skipped, so one bad pair does not abort
      the benchmark.
    * The "accuracy" metric is a placeholder: the mean of ``|dy| + |dx|``
      over successful calls, not an error against ground truth.

    NOTE(review): timings measure wall-clock time around the Python calls. If
    a method launches GPU work asynchronously, that work is included only up
    to whatever synchronization the method itself performs.

    Args:
        test_images: ``(img1, img2)`` pairs passed to every method.
        methods: Mapping from display name to method callable.
        num_iterations: Number of timed passes per method; must be >= 1.

    Returns:
        Per method, a dict with the keys ``'avg_time'``, ``'std_time'``,
        ``'avg_accuracy'`` and ``'std_accuracy'``. All values are Python
        floats; standard deviations are population (ddof=0) values.
        ``'avg_accuracy'`` is ``inf`` if no call ever succeeded.

    Raises:
        ValueError: If ``num_iterations`` < 1 (no timings to average).
    """
    import statistics
    import time

    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

    results = {}

    for method_name, method_func in methods.items():
        print(f"Benchmarking {method_name}...")

        times = []
        accuracies = []

        for _ in range(num_iterations):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()

            total_error = 0.0
            num_pairs = 0

            for img1, img2 in test_images:
                try:
                    dy, dx = method_func(img1, img2)
                    # Placeholder metric until ground truth is available.
                    total_error += abs(dy) + abs(dx)
                    num_pairs += 1
                except Exception as e:
                    # Deliberately best-effort: report and keep benchmarking.
                    print(f"Error in {method_name}: {e}")
                    continue

            times.append(time.perf_counter() - start_time)

            if num_pairs > 0:
                accuracies.append(total_error / num_pairs)

        # Host-side statistics return plain floats, matching the annotated
        # return type (cp.std would return 0-d device arrays).
        results[method_name] = {
            'avg_time': statistics.fmean(times),
            'std_time': statistics.pstdev(times),
            'avg_accuracy': statistics.fmean(accuracies) if accuracies else float('inf'),
            'std_accuracy': statistics.pstdev(accuracies) if len(accuracies) > 1 else 0.0,
        }

    return results