Coverage for openhcs/processing/backends/pos_gen/ashlar_main_cpu.py: 77.5%

401 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2OpenHCS Interface for Ashlar CPU Stitching Algorithm 

3 

4Array-based EdgeAligner implementation that works directly with numpy arrays 

5instead of file-based readers. This is the complete Ashlar algorithm modified 

6to accept arrays directly. 

7""" 

8from __future__ import annotations 

9import logging 

10import sys 

11from typing import TYPE_CHECKING, Tuple, List 

12import numpy as np 

13import networkx as nx 

14import scipy.spatial.distance 

15import sklearn.linear_model 

16import pandas as pd 

17 

18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs 

19from openhcs.core.memory.decorators import numpy as numpy_func 

20 

21import warnings 

22 

23if TYPE_CHECKING: 23 ↛ 24line 23 didn't jump to line 24 because the condition on line 23 was never true

24 pass 

25 

26logger = logging.getLogger(__name__) 

27 

28 

class DataWarning(Warning):
    """Warning category for problems found in the content of user-provided image data."""

32 

33 

def warn_data(message):
    """Emit ``message`` through the ``warnings`` machinery, categorised as DataWarning."""
    warnings.warn(message, category=DataWarning)

37 

38 

class Intersection:
    """Calculate intersection region between two tiles (extracted from Ashlar).

    After construction the instance exposes:

    * ``shape`` -- int array ``[rows, cols]`` of the overlap window. It is grown
      to at least ``min_size`` (clipped to the tile extent), then shrunk so the
      window fits inside both tiles, and is always >= 1 per axis.
    * ``padding`` -- how much ``shape`` was grown beyond the true overlap to
      satisfy ``min_size`` (computed before the boundary clamping).
    * ``offsets`` -- int array of shape (2, 2); row ``i`` is the window's
      upper-left corner inside tile ``i`` (float offsets truncated by
      ``astype(int)``).
    * ``offset_diff_frac`` -- fractional (sub-pixel) part of the difference
      between the two tiles' float window offsets.
    """

    def __init__(self, corners1, corners2, min_size=0):
        """
        Args:
            corners1: (2, 2) array; row ``i`` is the upper-left corner of tile ``i``.
            corners2: (2, 2) array; row ``i`` is the lower-right corner of tile ``i``.
            min_size: Scalar or length-2 array-like minimum window size.
        """
        if np.isscalar(min_size):
            min_size = np.repeat(min_size, 2)
        # asarray so list/tuple min_size values support ``.clip`` below.
        self._calculate(corners1, corners2, np.asarray(min_size))

    def _calculate(self, corners1, corners2, min_size):
        """Calculate intersection parameters with robust boundary validation."""
        max_shape = (corners2 - corners1).max(axis=0)
        min_size = min_size.clip(1, max_shape)
        position = corners1.max(axis=0)
        initial_shape = np.floor(corners2.min(axis=0) - position).astype(int)
        clipped_shape = np.maximum(initial_shape, min_size)
        self.shape = np.ceil(clipped_shape).astype(int)
        self.padding = self.shape - initial_shape

        # Calculate offsets with boundary validation
        raw_offsets = np.maximum(position - corners1 - self.padding, 0)

        # Validate that offsets + shape don't exceed tile boundaries
        tile_sizes = corners2 - corners1
        for i in range(2):
            # Ensure offset + shape <= tile_size for each tile
            max_offset = tile_sizes[i] - self.shape
            raw_offsets[i] = np.minimum(raw_offsets[i], np.maximum(max_offset, 0))

            # Ensure shape doesn't exceed available space
            available_space = tile_sizes[i] - raw_offsets[i]
            self.shape = np.minimum(self.shape, available_space.astype(int))

        # Final validation - ensure shape is positive
        self.shape = np.maximum(self.shape, 1)

        # The sub-pixel fraction must be taken from the float offsets: once
        # they are cast to int the difference is integral and the fraction
        # would always be zero.
        offset_diff = raw_offsets[1] - raw_offsets[0]
        self.offset_diff_frac = offset_diff - offset_diff.round()

        self.offsets = raw_offsets.astype(int)

81 

82 

83def _get_window(shape): 

84 """Build a 2D Hann window (from Ashlar utils.get_window).""" 

85 # Build a 2D Hann window by taking the outer product of two 1-D windows. 

86 wy = np.hanning(shape[0]).astype(np.float32) 

87 wx = np.hanning(shape[1]).astype(np.float32) 

88 window = np.outer(wy, wx) 

89 return window 

90 

91 

def ashlar_register_no_preprocessing(img1, img2, upsample=10):
    """
    Robust Ashlar register function with comprehensive input validation.

    This is based on ashlar.utils.register() but adds validation to handle
    edge cases that can occur with real microscopy data. No whitening or
    filtering is applied: the images are only cast to float32 and multiplied
    by a 2D Hann window before phase correlation.

    Args:
        img1: 2D reference image.
        img2: 2D moving image; must have the same shape as ``img1``.
        upsample: ``upsample_factor`` forwarded to
            ``skimage.registration.phase_cross_correlation``.

    Returns:
        Tuple ``(shift, error)``. ``shift`` is always a length-2 float ndarray.
        ``error`` is ``-log(correlation / total_amplitude)`` for the chosen
        shift. On invalid input or failed registration the result is
        ``(array([0., 0.]), np.inf)``.
    """
    import itertools
    import scipy.ndimage
    import skimage.registration

    # Input validation. ``size == 0`` also covers any zero-length axis.
    if (
        img1 is None
        or img2 is None
        or img1.size == 0
        or img2.size == 0
        or img1.shape != img2.shape
        or len(img1.shape) != 2
    ):
        return np.array([0.0, 0.0]), np.inf

    # Convert to float32 (equivalent to what whiten() does) - match GPU version,
    # then apply the windowing function (from original Ashlar).
    img1w = img1.astype(np.float32) * _get_window(img1.shape)
    img2w = img2.astype(np.float32) * _get_window(img2.shape)

    try:
        shift = skimage.registration.phase_cross_correlation(
            img1w,
            img2w,
            upsample_factor=upsample,
            normalization=None,
        )[0]
    except Exception as e:
        # If phase correlation fails, return large error
        logger.error("Ashlar CPU: PHASE CORRELATION FAILED - Exception: %s", e)
        logger.error(" Returning infinite error")
        return np.array([0.0, 0.0]), np.inf

    # The FFT assumes a periodic signal, so the shift may be in the wrong
    # quadrant. Test all four possibilities and keep the one with the highest
    # direct correlation (sum of products).
    shape = np.array(img1.shape)
    shift_pos = (shift + shape) % shape
    shift_neg = shift_pos - shape
    shifts = list(itertools.product(*zip(shift_pos, shift_neg)))
    correlations = []
    for s in shifts:
        try:
            shifted_img = scipy.ndimage.shift(img2w, s, order=0)
            correlations.append(np.abs(np.sum(img1w * shifted_img)))
        except Exception:
            # Best-effort: a candidate that cannot be evaluated never wins.
            correlations.append(0.0)

    if not correlations or max(correlations) == 0:
        logger.warning("Ashlar CPU: NO VALID CORRELATIONS - All correlations failed or zero")
        return np.array([0.0, 0.0]), np.inf

    idx = int(np.argmax(correlations))
    # itertools.product yields tuples; convert so the success path returns the
    # same type (float ndarray) as every failure path above.
    shift = np.array(shifts[idx], dtype=float)
    correlation = correlations[idx]
    total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w)
    if correlation > 0 and total_amplitude > 0:
        error = -np.log(correlation / total_amplitude)
    else:
        error = np.inf

    # Lazy %-formatting: this runs once per permutation sample and per edge.
    if error > 1.0:  # High error threshold for Ashlar
        logger.warning(
            "Ashlar CPU: HIGH CORRELATION ERROR - Error=%.4f, Shift=(%.2f, %.2f)",
            error, shift[0], shift[1],
        )
        logger.warning(" This indicates poor overlap or image quality between tiles")
    else:
        logger.info(
            "Ashlar CPU: Correlation - Error=%.4f, Shift=(%.2f, %.2f)",
            error, shift[0], shift[1],
        )

    return shift, error

178 

179 

def ashlar_nccw_no_preprocessing(img1, img2, rel_tol=1e-3):
    """
    Robust Ashlar nccw function with comprehensive input validation.

    This is based on ashlar.utils.nccw() but adds validation to handle
    edge cases that can occur with real microscopy data. The error is
    ``-log(|sum(img1 * img2)| / (||img1|| * ||img2||))``: 0 for perfectly
    correlated images, growing as the correlation drops.

    Args:
        img1: 2D image (cast to float32).
        img2: 2D image of the same shape as ``img1`` (cast to float32).
        rel_tol: Relative amount, as a fraction of the total amplitude, by
            which the correlation may exceed the total amplitude before the
            result is treated as a numerical failure. Mathematically the
            correlation cannot exceed it (Cauchy-Schwarz), but float32
            rounding can push it slightly above on near-identical images.

    Returns:
        The error described above.

        * ``0`` when the excess is within ``rel_tol``.
        * ``100.0`` when the excess is larger than that.
        * ``np.inf`` for invalid inputs or a zero correlation/amplitude.
    """
    # Input validation. ``size == 0`` also covers any zero-length axis.
    if (
        img1 is None
        or img2 is None
        or img1.size == 0
        or img2.size == 0
        or img1.shape != img2.shape
        or len(img1.shape) != 2
    ):
        return np.inf

    # Convert to float32 (equivalent to what whiten() does) - match GPU version
    img1w = img1.astype(np.float32)
    img2w = img2.astype(np.float32)

    correlation = np.abs(np.sum(img1w * img2w))
    total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w)
    if correlation > 0 and total_amplitude > 0:
        diff = correlation - total_amplitude
        if diff <= 0:
            error = -np.log(correlation / total_amplitude)
        elif diff < rel_tol * total_amplitude:
            # Near-identical images: the excess is float32 rounding noise.
            # The tolerance must scale with the magnitude of the sums. With
            # raw (unwhitened) intensities the rounding error is far larger
            # than any fixed absolute tolerance such as 1e-3.
            error = 0
        else:
            # Instead of raising error, return large but finite error
            logger.warning(
                "Ashlar CPU: NCCW numerical precision issue - diff=%.6f, using error=100.0",
                diff,
            )
            error = 100.0
    else:
        logger.warning(
            "Ashlar CPU: NCCW invalid correlation - correlation=%.6f, total_amplitude=%.6f",
            correlation, total_amplitude,
        )
        error = np.inf

    if error > 10.0:  # High NCCW error threshold
        logger.warning("Ashlar CPU: HIGH NCCW ERROR - Error=%.4f", error)
        logger.warning(" This indicates poor image correlation between tiles")
    else:
        logger.info("Ashlar CPU: NCCW - Error=%.4f", error)

    return error

234 

235 

def ashlar_crop(img, offset, shape):
    """
    Robust Ashlar crop function with comprehensive boundary validation.

    This is based on ashlar.utils.crop() but clamps the requested window to
    the image instead of failing, so it always returns a non-empty region.

    Args:
        img: Image array; only the first two axes are cropped.
        offset: Array-like ``(row, col)`` upper-left corner. It is rounded,
            made non-negative, and clamped to the last valid row/column.
        shape: Array-like ``(rows, cols)`` window size. It is rounded and
            forced to be at least 1 per axis.

    Returns:
        The slice ``img[r0:r1, c0:c1]``. It can be smaller than ``shape`` when
        the window runs past the image edge, but it is never empty.

    Raises:
        ValueError: If ``img`` is None or has no elements.
    """
    if img is None or img.size == 0:
        raise ValueError("Cannot crop from empty or None image")

    # np.asarray so both arguments accept lists/tuples as well as ndarrays
    # (previously ``offset`` had to be an ndarray for ``.round()``).
    start = np.maximum(np.round(np.asarray(offset)).astype(int), 0)
    size = np.maximum(np.round(np.asarray(shape)).astype(int), 1)
    end = start + size

    # Clamp to image boundaries
    img_height, img_width = img.shape[:2]
    start[0] = min(start[0], img_height - 1)
    start[1] = min(start[1], img_width - 1)
    end[0] = min(end[0], img_height)
    end[1] = min(end[1], img_width)

    if end[0] <= start[0] or end[1] <= start[1]:
        # Return minimum valid region if bounds are invalid
        return img[start[0]:start[0] + 1, start[1]:start[1] + 1]

    return img[start[0]:end[0], start[1]:end[1]]

273 

274 

class ArrayEdgeAligner:
    """
    Array-based EdgeAligner that implements the complete Ashlar algorithm
    but works directly with numpy arrays instead of file readers.

    ``run`` executes these stages in order:

    1. Overlap sanity check.
    2. Permutation test for the error threshold.
    3. Registration of every neighbouring tile pair.
    4. Shortest-path spanning tree per connected component.
    5. Shift accumulation.
    6. A linear model that places disconnected components.

    Results are left in ``final_positions`` (plus ``shifts``, ``lr`` and
    ``origin``).
    """

    def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
                 max_shift=15, alpha=0.01, max_error=None,
                 randomize=False, verbose=False, *,
                 upsample_factor=10, permutation_upsample=1,
                 permutation_samples=1000, min_permutation_samples=10,
                 max_permutation_tries=100, window_size_factor=0.1):
        """
        Initialize array-based EdgeAligner for pure position calculation.

        Args:
            image_stack: 3D numpy array (num_tiles, height, width) - preprocessed grayscale
            positions: (num_tiles, 2) tile positions in pixels, in the same
                axis order as ``tile_size``. Copied to a new float array.
            tile_size: Array [height, width] of tile dimensions
            pixel_size: Pixel size in micrometers (for max_shift conversion)
            max_shift: Maximum allowed shift in micrometers
            alpha: Percentile (as a fraction) of the permutation-test errors
                used as the automatic threshold (lower = stricter)
            max_error: Explicit error threshold (None = auto-compute)
            randomize: If False, permutation testing uses a fixed seed (0)
            verbose: Enable verbose progress output on stdout
            upsample_factor: Phase-correlation upsampling used when
                registering neighbouring tiles.
            permutation_upsample: Phase-correlation upsampling used during
                permutation testing.
            permutation_samples: Number of permutation samples when there are
                more than 8 non-neighbouring tile pairs.
            min_permutation_samples: Samples per (non-neighbouring pair + 1)
                when there are 8 or fewer such pairs.
            max_permutation_tries: Attempts to find a valid non-overlapping
                strip pair for each permutation sample.
            window_size_factor: Fraction of the tile size that the largest
                registration window must reach in ``register_pair``.
        """
        self.image_stack = image_stack
        # np.array always copies, so the caller's array is never aliased and
        # list inputs are accepted too.
        self.positions = np.array(positions, dtype=float)
        self.tile_size = np.array(tile_size)
        self.pixel_size = pixel_size
        self.max_shift = max_shift
        self.max_shift_pixels = self.max_shift / self.pixel_size
        self.alpha = alpha
        self.max_error = max_error
        self.randomize = randomize
        self.verbose = verbose
        self.upsample_factor = upsample_factor
        self.permutation_upsample = permutation_upsample
        self.permutation_samples = permutation_samples
        self.min_permutation_samples = min_permutation_samples
        self.max_permutation_tries = max_permutation_tries
        self.window_size_factor = window_size_factor
        # Sorted (t1, t2) tile pair -> (shift, error), filled by register_pair.
        self._cache = {}
        self.errors_negative_sampled = np.empty(0)

        # Build neighbors graph
        self.neighbors_graph = self._build_neighbors_graph()

    def _build_neighbors_graph(self):
        """Build graph of neighboring (overlapping) tiles.

        Two tiles are neighbours when the city-block distance between their
        positions is positive and below the largest tile dimension + 1. Every
        tile is a node, even if it has no neighbours.
        """
        pdist = scipy.spatial.distance.pdist(self.positions, metric='cityblock')
        sp = scipy.spatial.distance.squareform(pdist)
        max_distance = self.tile_size.max() + 1
        edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
        graph = nx.from_edgelist(edges)
        graph.add_nodes_from(range(len(self.positions)))
        return graph

    def run(self):
        """Run the complete Ashlar algorithm."""
        self.check_overlaps()
        self.compute_threshold()
        self.register_all()
        self.build_spanning_tree()
        self.calculate_positions()
        self.fit_model()

    def check_overlaps(self):
        """Warn (DataWarning) when neighbouring tiles have no nominal overlap."""
        overlaps = np.array([
            self.tile_size - abs(self.positions[t1] - self.positions[t2])
            for t1, t2 in self.neighbors_graph.edges
        ])
        failures = np.any(overlaps < 1, axis=1) if len(overlaps) else []
        if len(failures) and all(failures):
            warn_data("No tiles overlap, attempting alignment anyway.")
        elif any(failures):
            warn_data("Some neighboring tiles have zero overlap.")

    def compute_threshold(self):
        """Compute the error threshold ``max_error`` using permutation testing.

        Random pairs of full-width horizontal strips that should NOT overlap
        are registered. ``max_error`` becomes the ``alpha`` percentile of their
        errors, and the raw errors are kept in ``errors_negative_sampled``.
        """
        if self.max_error is not None:
            if self.verbose:
                print(" using explicit error threshold")
            return

        edges = self.neighbors_graph.edges
        num_tiles = len(self.image_stack)

        # If not enough tiles overlap to matter, skip this whole thing
        if len(edges) <= 1:
            self.max_error = np.inf
            return

        widths = np.array([
            self.intersection(t1, t2).shape.min()
            for t1, t2 in edges
        ])
        w = widths.max()
        max_offset = self.tile_size[0] - w

        # Number of possible pairs minus number of actual neighbor pairs
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)

        # Reduce permutation count for small datasets
        if num_distant_pairs > 8:
            n = self.permutation_samples
        else:
            n = (num_distant_pairs + 1) * self.min_permutation_samples
        pairs = np.empty((n, 2), dtype=int)
        offsets = np.empty((n, 2), dtype=int)

        # Generate n random non-overlapping image strips
        max_tries = self.max_permutation_tries
        if self.randomize is False:
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()

        for i in range(n):
            # Limit tries to avoid infinite loop in pathological cases
            for _ in range(max_tries):
                t1, t2 = random_state.randint(num_tiles, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)

                # Check for non-overlapping strips and abort the retry loop
                if t1 != t2 and (t1, t2) not in edges:
                    # Different, non-neighboring tiles -- always OK
                    break
                elif t1 == t2 and abs(o1 - o2) > w:
                    # Same tile OK if strips don't overlap within the image
                    break
                elif (t1, t2) in edges:
                    # Neighbors OK if either strip is entirely outside the
                    # expected overlap region (based on nominal positions)
                    its = self.intersection(t1, t2, np.repeat(w, 2))
                    ioff1, ioff2 = its.offsets[:, 0]
                    if (
                        its.shape[0] > its.shape[1]
                        or o1 < ioff1 - w or o1 > ioff1 + w
                        or o2 < ioff2 - w or o2 > ioff2 + w
                    ):
                        break
            else:
                # Retries exhausted. This should be very rare.
                warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
            pairs[i] = t1, t2
            offsets[i] = o1, o2

        errors = np.empty(n)
        for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
            img1 = self.image_stack[t1][offset1:offset1 + w, :]
            img2 = self.image_stack[t2][offset2:offset2 + w, :]
            _, errors[i] = ashlar_register_no_preprocessing(
                img1, img2, upsample=self.permutation_upsample
            )
        self.errors_negative_sampled = errors
        self.max_error = np.percentile(errors, self.alpha * 100)

    def register_all(self):
        """Register all neighboring tile pairs, then reject unreliable ones.

        Raw errors are stored in ``all_errors``. Afterwards, any cached edge
        whose error exceeds ``max_error``, or whose shift exceeds
        ``max_shift_pixels`` on either axis, has its error set to ``np.inf``.
        ``build_spanning_tree`` then ignores those edges.
        """
        n = self.neighbors_graph.size()
        for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
            if self.verbose:
                sys.stdout.write(f'\r aligning edge {i}/{n}')
                sys.stdout.flush()
            self.register_pair(t1, t2)
        if self.verbose:
            print()
        self.all_errors = np.array([x[1] for x in self._cache.values()])

        # Set error values above the threshold to infinity
        for k, v in self._cache.items():
            if v[1] > self.max_error or any(np.abs(v[0]) > self.max_shift_pixels):
                self._cache[k] = (v[0], np.inf)

    def register_pair(self, t1, t2):
        """Return relative shift between images and the alignment error.

        Results are cached under the sorted tile pair. The shift is negated
        when ``t1 > t2``, and a copy is returned so callers cannot corrupt the
        cache.
        """
        key = tuple(sorted((t1, t2)))
        try:
            shift, error = self._cache[key]
        except KeyError:
            # Test a series of increasing overlap window sizes to help avoid
            # missing alignments when the stage position error is large
            # relative to the tile overlap. Simply using a large overlap in all
            # cases limits the maximum achievable correlation, which increases
            # the error metric and gives worse overall results. The window size
            # starts at the nominal size and doubles until it reaches at least
            # ``window_size_factor`` of the tile size.
            smin = self.intersection(key[0], key[1]).shape
            smax = np.round(self.tile_size * self.window_size_factor)
            sizes = [smin]
            while any(sizes[-1] < smax):
                sizes.append(sizes[-1] * 2)
            results = []
            for s in sizes:
                try:
                    results.append(self._register(key[0], key[1], s))
                except Exception as exc:
                    # Best-effort: a failed window size just can't win.
                    logger.debug(
                        "Ashlar CPU: registration of tiles %s failed for window %s: %s",
                        key, s, exc,
                    )
                    results.append((np.array([0.0, 0.0]), np.inf))
            # Use the shift from the window size that gave the lowest error
            shift, _ = min(results, key=lambda r: r[1])
            # Re-score the chosen shift on the nominal overlap window, with
            # the shift applied to the second tile's position.
            _, o1, o2 = self.overlap(key[0], key[1], shift=shift)
            error = ashlar_nccw_no_preprocessing(o1, o2)
            self._cache[key] = (shift, error)
        if t1 > t2:
            shift = -shift
        # Return copy of shift to prevent corruption of cached values
        return shift.copy(), error

    def _register(self, t1, t2, min_size=0):
        """Register a single tile pair with given minimum size."""
        its, img1, img2 = self.overlap(t1, t2, min_size)
        # Account for padding, flipping the sign depending on the direction
        # between the tiles
        p1, p2 = self.positions[[t1, t2]]
        sx = 1 if p1[1] >= p2[1] else -1
        sy = 1 if p1[0] >= p2[0] else -1
        padding = its.padding * [sy, sx]
        shift, error = ashlar_register_no_preprocessing(
            img1, img2, upsample=self.upsample_factor
        )
        # Build a fresh float array: the registration may hand back a tuple.
        shift = np.asarray(shift, dtype=float) + padding
        return shift, error

    def intersection(self, t1, t2, min_size=0, shift=None):
        """Calculate intersection region between two tiles.

        ``shift``, if given, is added to the second tile's position. Fancy
        indexing makes ``corners1`` a copy, so ``self.positions`` is unchanged.
        """
        corners1 = self.positions[[t1, t2]]
        if shift is not None:
            corners1[1] += shift
        corners2 = corners1 + self.tile_size
        return Intersection(corners1, corners2, min_size)

    def crop(self, tile_id, offset, shape):
        """Crop image from tile at given offset and shape."""
        img = self.image_stack[tile_id]
        return ashlar_crop(img, offset, shape)

    def overlap(self, t1, t2, min_size=0, shift=None):
        """Extract overlapping regions between two tiles.

        Returns ``(intersection, img1, img2)``.
        """
        its = self.intersection(t1, t2, min_size, shift)
        img1 = self.crop(t1, its.offsets[0], its.shape)
        img2 = self.crop(t2, its.offsets[1], its.shape)
        return its, img1, img2

    def build_spanning_tree(self):
        """Build a shortest-path spanning tree from the registered edges.

        Only edges with finite error are kept, weighted by that error. For
        each connected component, the union of Dijkstra shortest paths from
        the component's center forms the tree.
        """
        g = nx.Graph()
        g.add_nodes_from(self.neighbors_graph)
        g.add_weighted_edges_from(
            (t1, t2, error)
            for (t1, t2), (_, error) in self._cache.items()
            if np.isfinite(error)
        )
        spanning_tree = nx.Graph()
        spanning_tree.add_nodes_from(g)
        for c in nx.connected_components(g):
            cc = g.subgraph(c)
            center = nx.center(cc)[0]
            paths = nx.single_source_dijkstra_path(cc, center).values()
            for path in paths:
                nx.add_path(spanning_tree, path)
        self.spanning_tree = spanning_tree

    def calculate_positions(self):
        """Calculate final positions from spanning tree.

        Shifts are accumulated breadth-first from each component's center,
        whose shift is fixed at zero.

        Raises:
            NotImplementedError: If there are no tiles.
        """
        shifts = {}
        for c in nx.connected_components(self.spanning_tree):
            cc = self.spanning_tree.subgraph(c)
            center = nx.center(cc)[0]
            shifts[center] = np.array([0, 0])
            for edge in nx.traversal.bfs_edges(cc, center):
                source, dest = edge
                if source not in shifts:
                    source, dest = dest, source
                shift = self.register_pair(source, dest)[0]
                shifts[dest] = shifts[source] + shift
        if shifts:
            self.shifts = np.array([s for _, s in sorted(shifts.items())])
            self.final_positions = self.positions + self.shifts
        else:
            # TODO: fill in shifts and positions with 0x2 arrays
            raise NotImplementedError("No images")

    def fit_model(self):
        """Fit linear model to handle disconnected components.

        Afterwards, positions are translated so the minimum is at (0, 0), and
        the model intercept is shifted to match.
        """
        components = sorted(
            nx.connected_components(self.spanning_tree),
            key=len, reverse=True
        )
        # Fit LR model on positions of largest connected component
        cc0 = list(components[0])
        self.lr = sklearn.linear_model.LinearRegression()
        self.lr.fit(self.positions[cc0], self.final_positions[cc0])

        # Fix up degenerate transform matrix. This happens when the spanning
        # tree is completely edgeless or cc0's metadata positions fall in a
        # straight line. In this case we fall back to the identity transform.
        if np.linalg.det(self.lr.coef_) < 1e-3:
            warn_data(
                "Could not align enough edges, proceeding anyway with original"
                " stage positions."
            )
            self.lr.coef_ = np.diag(np.ones(2))
            self.lr.intercept_ = np.zeros(2)

        # Adjust position of remaining components so their centroids match
        # the predictions of the model
        for cc in components[1:]:
            nodes = list(cc)
            centroid_m = np.mean(self.positions[nodes], axis=0)
            centroid_f = np.mean(self.final_positions[nodes], axis=0)
            shift = self.lr.predict([centroid_m])[0] - centroid_f
            self.final_positions[nodes] += shift

        # Adjust positions and model intercept to put origin at 0,0
        self.origin = self.final_positions.min(axis=0)
        self.final_positions -= self.origin
        self.lr.intercept_ -= self.origin

599 

600 

601def _calculate_initial_positions(image_stack: np.ndarray, grid_dims: tuple, overlap_ratio: float) -> np.ndarray: 

602 """Calculate initial grid positions based on overlap ratio.""" 

603 grid_rows, grid_cols = grid_dims 

604 tile_height, tile_width = image_stack.shape[1:3] 

605 spacing_factor = 1.0 - overlap_ratio 

606 

607 positions = [] 

608 for tile_idx in range(len(image_stack)): 

609 r = tile_idx // grid_cols 

610 c = tile_idx % grid_cols 

611 

612 y_pos = r * tile_height * spacing_factor 

613 x_pos = c * tile_width * spacing_factor 

614 positions.append([y_pos, x_pos]) 

615 

616 return np.array(positions, dtype=float) 

617 

618 

619def _convert_ashlar_positions_to_openhcs(ashlar_positions: np.ndarray) -> List[Tuple[float, float]]: 

620 """Convert Ashlar positions to OpenHCS format.""" 

621 positions = [] 

622 for tile_idx in range(len(ashlar_positions)): 

623 y, x = ashlar_positions[tile_idx] 

624 positions.append((float(x), float(y))) # OpenHCS uses (x, y) format 

625 return positions 

626 

627 

628@special_inputs("grid_dimensions") 

629@special_outputs("positions") 

630@numpy_func 

631def ashlar_compute_tile_positions_cpu( 

632 image_stack: np.ndarray, 

633 grid_dimensions: Tuple[int, int], 

634 overlap_ratio: float = 0.1, 

635 max_shift: float = 15.0, 

636 stitch_alpha: float = 0.01, 

637 max_error: float = None, 

638 randomize: bool = False, 

639 verbose: bool = False, 

640 upsample_factor: int = 10, 

641 permutation_upsample: int = 1, 

642 permutation_samples: int = 1000, 

643 min_permutation_samples: int = 10, 

644 max_permutation_tries: int = 100, 

645 window_size_factor: float = 0.1, 

646 **kwargs 

647) -> Tuple[np.ndarray, List[Tuple[float, float]]]: 

648 """ 

649 Compute tile positions using the complete Ashlar algorithm - pure position calculation only. 

650 

651 This function implements the full Ashlar edge-based stitching algorithm but works directly 

652 on preprocessed grayscale image arrays. It performs ONLY position calculation without any 

653 file I/O, channel selection, or image preprocessing. All the mathematical sophistication 

654 and robustness of the original Ashlar algorithm is preserved. 

655 

656 Args: 

657 image_stack: 3D numpy array of shape (num_tiles, height, width) containing preprocessed 

658 grayscale images. Each slice [i] should be a single-channel 2D image ready 

659 for correlation analysis. No further preprocessing will be applied. 

660 

661 grid_dimensions: Tuple of (grid_rows, grid_cols) specifying the logical arrangement of 

662 tiles. For example, (2, 3) means 2 rows and 3 columns of tiles, for a 

663 total of 6 tiles. Must match the number of images in image_stack. 

664 

665 overlap_ratio: Expected fractional overlap between adjacent tiles (0.0-1.0). Default 0.1 

666 means 10% overlap. This is used to calculate initial grid positions and 

667 should match the actual overlap in your microscopy data. Typical values: 

668 - 0.05-0.15 for well-controlled microscopes 

669 - 0.15-0.25 for less precise stages 

670 

671 max_shift: Maximum allowed shift correction in micrometers. Default 15.0. This limits 

672 how far tiles can be moved from their initial grid positions during alignment. 

673 Should be set based on your microscope's stage accuracy: 

674 - 5-15 μm for high-precision stages 

675 - 15-50 μm for standard stages 

676 - 50+ μm for low-precision or manual stages 

677 

678 stitch_alpha: Alpha value for statistical error threshold computation (0.0-1.0). Default 

679 0.01 means 1% false positive rate. Lower values are stricter and reject more 

680 alignments, higher values are more permissive. This controls the trade-off 

681 between alignment quality and success rate: 

682 - 0.001-0.01: Very strict, high quality alignments only 

683 - 0.01-0.05: Balanced (recommended for most data) 

684 - 0.05-0.1: Permissive, accepts lower quality alignments 

685 

686 max_error: Explicit error threshold for rejecting alignments (None = auto-compute). 

687 When None (default), the threshold is computed automatically using permutation 

688 testing. Set to a specific value to override automatic computation. Higher 

689 values accept more alignments, lower values are stricter. 

690 

691 randomize: Whether to use random seed for permutation testing (bool). Default False uses 

692 a fixed seed for reproducible results. Set True for different random sampling 

693 in each run. Generally should be False for consistent results. 

694 

695 verbose: Enable detailed progress logging (bool). Default False. When True, prints 

696 progress information including permutation testing, edge alignment, and 

697 spanning tree construction. Useful for debugging and monitoring progress 

698 on large datasets. 

699 

700 upsample_factor: Sub-pixel accuracy factor for phase cross correlation (int). Default 10. 

701 Higher values provide better sub-pixel accuracy but increase computation time. 

702 Range: 1-100+. Values of 10-50 are typical for high-accuracy stitching. 

703 - 1: Pixel-level accuracy (fastest) 

704 - 10: 0.1 pixel accuracy (balanced) 

705 - 50: 0.02 pixel accuracy (high precision) 

706 

707 permutation_upsample: Upsample factor for permutation testing (int). Default 1. 

708 Lower than upsample_factor for speed during threshold computation. 

709 Usually kept at 1 since permutation testing doesn't need sub-pixel accuracy. 

710 

711 permutation_samples: Number of random samples for error threshold computation (int). Default 1000. 

712 Higher values give more accurate thresholds but slower computation. 

713 Automatically reduced for small datasets to avoid infinite loops. 

714 

715 min_permutation_samples: Minimum permutation samples for small datasets (int). Default 10. 

716 When there are few non-overlapping pairs, this sets the minimum 

717 number of samples to ensure statistical validity. 

718 

719 max_permutation_tries: Maximum attempts to find non-overlapping strips (int). Default 100. 

720 Prevents infinite loops in pathological cases where valid strips 

721 are hard to find. Rarely needs adjustment. 

722 

723 window_size_factor: Fraction of tile size for maximum window size (float). Default 0.1. 

724 Controls the largest overlap window tested during progressive sizing. 

725 Larger values allow detection of bigger stage errors but may reduce 

726 correlation quality. Range: 0.05-0.2 typical. 

727 

728 **kwargs: Additional parameters (ignored). Allows compatibility with other stitching 

729 algorithms that may have different parameter sets. 

730 

731 Returns: 

732 Tuple of (image_stack, positions) where: 

733 - image_stack: The original input image array (unchanged) 

734 - positions: List of (x, y) position tuples in OpenHCS format, one per tile. 

735 Positions are in pixel coordinates with (0, 0) at the top-left. 

736 The positions represent the optimal tile placement after Ashlar 

737 alignment, accounting for stage errors and image correlation. 

738 

739 Raises: 

740 Exception: If the Ashlar algorithm fails (e.g., insufficient overlap, correlation 

741 errors), the function automatically falls back to grid-based positioning 

742 using the specified overlap_ratio. 

743 

744 Notes: 

745 - This implementation contains the complete Ashlar algorithm including permutation 

746 testing, progressive window sizing, minimum spanning tree construction, and 

747 linear model fitting for disconnected components. 

748 - The correlation functions are identical to original Ashlar but without image 

749 preprocessing (whitening/filtering), allowing OpenHCS to handle preprocessing 

750 in separate pipeline steps. 

751 - For best results, ensure your image_stack contains properly preprocessed, 

752 single-channel grayscale images with good contrast and minimal noise. 

753 """ 

754 grid_rows, grid_cols = grid_dimensions 

755 

756 logger.info(f"Ashlar CPU: Processing {grid_rows}x{grid_cols} grid with {len(image_stack)} tiles") 

757 

758 try: 

759 # Calculate initial grid positions 

760 initial_positions = _calculate_initial_positions(image_stack, grid_dimensions, overlap_ratio) 

761 tile_size = np.array(image_stack.shape[1:3]) # (height, width) 

762 

763 # Create and run ArrayEdgeAligner with complete Ashlar algorithm 

764 logger.info("Running complete Ashlar edge-based stitching algorithm") 

765 aligner = ArrayEdgeAligner( 

766 image_stack=image_stack, 

767 positions=initial_positions, 

768 tile_size=tile_size, 

769 pixel_size=1.0, # Assume 1 micrometer per pixel if not specified 

770 max_shift=max_shift, 

771 alpha=stitch_alpha, 

772 max_error=max_error, 

773 randomize=randomize, 

774 verbose=verbose 

775 ) 

776 

777 # Run the complete algorithm 

778 aligner.run() 

779 

780 # Convert to OpenHCS format 

781 positions = _convert_ashlar_positions_to_openhcs(aligner.final_positions) 

782 

783 logger.info("Ashlar algorithm completed successfully") 

784 

785 except Exception as e: 

786 logger.error(f"Ashlar algorithm failed: {e}") 

787 # Fallback to grid positions if Ashlar fails 

788 logger.warning("Falling back to grid-based positioning") 

789 positions = [] 

790 tile_height, tile_width = image_stack.shape[1:3] 

791 spacing_factor = 1.0 - overlap_ratio 

792 

793 for tile_idx in range(len(image_stack)): 

794 r = tile_idx // grid_cols 

795 c = tile_idx % grid_cols 

796 x_pos = c * tile_width * spacing_factor 

797 y_pos = r * tile_height * spacing_factor 

798 positions.append((float(x_pos), float(y_pos))) 

799 

800 logger.info(f"Ashlar CPU: Completed processing {len(positions)} tile positions") 

801 

802 return image_stack, positions 

803 

804 

def _ashlar_csv_grid_lines(values, tolerance: float) -> List[float]:
    """Group 1-D tile coordinates into grid lines and return each line's mean coordinate.

    Values are sorted and a new grid line starts wherever the gap to the
    previous value is strictly larger than ``tolerance``. With
    ``tolerance == 0`` this reduces to the distinct values (exact-grid case).
    """
    ordered = np.sort(np.asarray(values, dtype=float))
    if ordered.size == 0:
        return []
    # Positions (in the sorted array) where a new grid line begins.
    breaks = np.flatnonzero(np.diff(ordered) > tolerance) + 1
    return [float(group.mean()) for group in np.split(ordered, breaks)]


def _ashlar_csv_grid_tolerance(*axes, fraction: float = 0.25) -> float:
    """Return ``fraction`` of the largest gap between sorted coordinates on any axis.

    The largest gap approximates one tile pitch, so a quarter of it separates
    neighbouring grid lines while merging small alignment jitter. Both axes are
    considered so that a single-row or single-column layout (whose jitter-only
    axis has no real pitch) still gets a meaningful tolerance from the other axis.
    """
    largest_gap = 0.0
    for values in axes:
        ordered = np.sort(np.asarray(values, dtype=float))
        if ordered.size > 1:
            largest_gap = max(largest_gap, float(np.diff(ordered).max()))
    return fraction * largest_gap


def materialize_ashlar_cpu_positions(data: List[Tuple[float, float]], path: str, filemanager,
                                     *, grid_tolerance: "float | None" = None) -> str:
    """Materialize Ashlar CPU tile positions as scientific CSV with grid metadata.

    Args:
        data: One ``(x, y)`` position per tile. Values are written as received;
            the ``_um`` column suffix presumes 1 micrometer per pixel —
            NOTE(review): confirm the units produced upstream.
        path: Output path of the primary artifact. A trailing ``.pkl`` suffix is
            replaced by ``_ashlar_positions.csv``; otherwise that suffix is appended,
            so the CSV never overwrites ``path`` itself.
        filemanager: Object exposing ``save(content, path, backend)``; the CSV is
            saved with the ``"disk"`` backend.
        grid_tolerance: Maximum coordinate gap for two tiles to share a grid
            row/column. ``None`` (default) derives it from the data as a quarter
            of the largest gap between sorted coordinates. ``0`` reproduces
            exact-value grouping.

    Returns:
        The path the CSV was written to.
    """
    suffix = '.pkl'
    base = path[:-len(suffix)] if path.endswith(suffix) else path
    csv_path = f"{base}_ashlar_positions.csv"

    df = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
    df['tile_id'] = range(len(df))

    # Estimate grid dimensions from the position layout. Aligned positions are
    # not exact multiples of the pitch, so group nearby coordinates instead of
    # counting distinct float values.
    if grid_tolerance is None:
        grid_tolerance = _ashlar_csv_grid_tolerance(df['x_position_um'], df['y_position_um'])
    column_lines = _ashlar_csv_grid_lines(df['x_position_um'], grid_tolerance)
    row_lines = _ashlar_csv_grid_lines(df['y_position_um'], grid_tolerance)

    grid_cols = len(column_lines)
    grid_rows = len(row_lines)

    # Grid coordinates assume tiles are listed in row-major order
    # (NOTE(review): confirm against the producer's tile ordering).
    # max(..., 1) avoids dividing by zero for empty input.
    divisor = max(grid_cols, 1)
    df['grid_row'] = df.index // divisor
    df['grid_col'] = df.index % divisor

    # Spacing is the median pitch between neighbouring grid lines (0 if only one).
    if len(column_lines) > 1:
        df['x_spacing_um'] = float(np.median(np.diff(column_lines)))
    else:
        df['x_spacing_um'] = 0

    if len(row_lines) > 1:
        df['y_spacing_um'] = float(np.median(np.diff(row_lines)))
    else:
        df['y_spacing_um'] = 0

    # Add metadata
    df['algorithm'] = 'ashlar_cpu'
    df['grid_dimensions'] = f"{grid_rows}x{grid_cols}"

    csv_content = df.to_csv(index=False)
    filemanager.save(csv_content, csv_path, "disk")
    return csv_path