Coverage for openhcs/processing/backends/pos_gen/ashlar_main_cpu.py: 79.1%

402 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2OpenHCS Interface for Ashlar CPU Stitching Algorithm 

3 

4Array-based EdgeAligner implementation that works directly with numpy arrays 

5instead of file-based readers. This is the complete Ashlar algorithm modified 

6to accept arrays directly. 

7""" 

8from __future__ import annotations 

9import logging 

10import sys 

11from typing import TYPE_CHECKING, Tuple, List 

12import numpy as np 

13import networkx as nx 

14import scipy.spatial.distance 

15import sklearn.linear_model 

16import pandas as pd 

17 

18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs 

19from openhcs.core.memory.decorators import numpy as numpy_func 

20from openhcs.core.utils import optional_import 

21 

22import warnings 

23 

if TYPE_CHECKING:
    # Nothing is imported for type checking only; the branch body is a no-op.
    pass

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

28 

29 

class DataWarning(Warning):
    """Warning category for problems found in user-provided image data."""

33 

34 

def warn_data(message):
    """Emit *message* as a ``DataWarning`` through the ``warnings`` module."""
    warnings.warn(message, category=DataWarning)

38 

39 

class Intersection:
    """Calculate intersection region between two tiles (adapted from Ashlar).

    Attributes set by the constructor:
        shape: Integer (2,) array, size of the region to extract from each tile.
        padding: Integer (2,) array, how much ``shape`` was enlarged beyond the
            nominal overlap in order to honour ``min_size``.
        offsets: Integer (2, 2) array; row ``i`` is the offset of the region
            inside tile ``i``, truncated to whole pixels.
        offset_diff_frac: Fractional part of the difference between the two
            tiles' un-truncated offsets.
    """

    def __init__(self, corners1, corners2, min_size=0):
        """
        Args:
            corners1: (2, 2) array, upper-left corner of each of the two tiles.
            corners2: (2, 2) array, lower-right corner of each of the two tiles.
            min_size: Minimum region size, either a scalar applied to both axes
                or a length-2 sequence/array.
        """
        if np.isscalar(min_size):
            min_size = np.repeat(min_size, 2)
        else:
            # Accept lists/tuples as well as arrays; ``.clip`` below needs an ndarray.
            min_size = np.asarray(min_size)
        self._calculate(corners1, corners2, min_size)

    def _calculate(self, corners1, corners2, min_size):
        """Calculate intersection parameters with robust boundary validation."""
        # corners1 and corners2 are arrays of shape (2, 2) containing
        # the upper-left and lower-right corners of the two tiles
        max_shape = (corners2 - corners1).max(axis=0)
        min_size = min_size.clip(1, max_shape)
        position = corners1.max(axis=0)
        initial_shape = np.floor(corners2.min(axis=0) - position).astype(int)
        clipped_shape = np.maximum(initial_shape, min_size)
        self.shape = np.ceil(clipped_shape).astype(int)
        self.padding = self.shape - initial_shape

        # Calculate offsets with boundary validation
        raw_offsets = np.maximum(position - corners1 - self.padding, 0)

        # Validate that offsets + shape don't exceed tile boundaries
        tile_sizes = corners2 - corners1
        for i in range(2):
            # Ensure offset + shape <= tile_size for each tile
            max_offset = tile_sizes[i] - self.shape
            raw_offsets[i] = np.minimum(raw_offsets[i], np.maximum(max_offset, 0))

            # Ensure shape doesn't exceed available space
            available_space = tile_sizes[i] - raw_offsets[i]
            self.shape = np.minimum(self.shape, available_space.astype(int))

        # Final validation - ensure shape is positive
        self.shape = np.maximum(self.shape, 1)

        self.offsets = raw_offsets.astype(int)

        # Fractional offset difference for subpixel accuracy. This must be
        # derived from the un-truncated offsets: the integer ``self.offsets``
        # always differ by a whole number, which would make the result zero.
        offset_diff = raw_offsets[1] - raw_offsets[0]
        self.offset_diff_frac = offset_diff - offset_diff.round()

82 

83 

84def _get_window(shape): 

85 """Build a 2D Hann window (from Ashlar utils.get_window).""" 

86 # Build a 2D Hann window by taking the outer product of two 1-D windows. 

87 wy = np.hanning(shape[0]).astype(np.float32) 

88 wx = np.hanning(shape[1]).astype(np.float32) 

89 window = np.outer(wy, wx) 

90 return window 

91 

92 

def ashlar_register_no_preprocessing(img1, img2, upsample=10):
    """
    Robust Ashlar register function with comprehensive input validation.

    This is based on ashlar.utils.register() but adds validation to handle
    edge cases that can occur with real microscopy data.

    Args:
        img1: 2D array (reference image).
        img2: 2D array with the same shape as ``img1``.
        upsample: Upsampling factor forwarded to scikit-image's
            ``phase_cross_correlation``.

    Returns:
        Tuple ``(shift, error)``. ``shift`` is always a freshly allocated
        length-2 float ndarray in array axis order; ``error`` is
        ``-log(correlation / (||img1|| * ||img2||))`` computed on the windowed
        images. Invalid input or a failed correlation yields a zero shift and
        ``np.inf``.
    """
    # Input validation. Every failure path builds a new array because callers
    # (e.g. ArrayEdgeAligner._register) add to the returned shift in place.
    if img1 is None or img2 is None:
        return np.array([0.0, 0.0]), np.inf

    if img1.size == 0 or img2.size == 0:
        return np.array([0.0, 0.0]), np.inf

    if img1.shape != img2.shape:
        return np.array([0.0, 0.0]), np.inf

    # A non-empty 2D array necessarily has both dimensions >= 1, so no
    # separate per-axis size check is needed after this.
    if len(img1.shape) != 2:
        return np.array([0.0, 0.0]), np.inf

    # Imported lazily, and only once the input is known to be usable, so that
    # rejecting bad input does not require these packages to be importable.
    import itertools
    import scipy.ndimage
    import skimage.registration

    # Convert to float32 (equivalent to what whiten() does) - match GPU version
    img1w = img1.astype(np.float32)
    img2w = img2.astype(np.float32)

    # Apply windowing function (from original Ashlar)
    img1w = img1w * _get_window(img1w.shape)
    img2w = img2w * _get_window(img2w.shape)

    # Use skimage's phase cross correlation with error handling
    try:
        shift = skimage.registration.phase_cross_correlation(
            img1w,
            img2w,
            upsample_factor=upsample,
            normalization=None
        )[0]
    except Exception as e:
        # If phase correlation fails, return large error
        logger.error(f"Ashlar CPU: PHASE CORRELATION FAILED - Exception: {e}")
        logger.error(" Returning infinite error")
        return np.array([0.0, 0.0]), np.inf

    # At this point we may have a shift in the wrong quadrant since the FFT
    # assumes the signal is periodic. We test all four possibilities and return
    # the shift that gives the highest direct correlation (sum of products).
    shape = np.array(img1.shape)
    shift_pos = (shift + shape) % shape
    shift_neg = shift_pos - shape
    shifts = list(itertools.product(*zip(shift_pos, shift_neg)))
    correlations = []
    for s in shifts:
        try:
            shifted_img = scipy.ndimage.shift(img2w, s, order=0)
            corr = np.abs(np.sum(img1w * shifted_img))
        except Exception as exc:
            # Best effort: a failing candidate simply cannot win, but leave a
            # trace instead of discarding the exception silently.
            logger.debug("Ashlar CPU: candidate shift %s could not be evaluated: %s", s, exc)
            corr = 0.0
        correlations.append(corr)

    if not correlations or max(correlations) == 0:
        logger.warning("Ashlar CPU: NO VALID CORRELATIONS - All correlations failed or zero")
        return np.array([0.0, 0.0]), np.inf

    idx = np.argmax(correlations)
    # Return an ndarray (not the tuple produced by itertools.product) so the
    # success path has the same return type as the failure paths.
    shift = np.asarray(shifts[idx], dtype=float)
    correlation = correlations[idx]
    total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w)
    if correlation > 0 and total_amplitude > 0:
        error = -np.log(correlation / total_amplitude)
    else:
        error = np.inf

    # Report the result: WARNING when the error is high, INFO otherwise
    if error > 1.0:  # High error threshold for Ashlar
        logger.warning(f"Ashlar CPU: HIGH CORRELATION ERROR - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
        logger.warning(" This indicates poor overlap or image quality between tiles")
    else:
        logger.info(f"Ashlar CPU: Correlation - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")

    return shift, error

179 

180 

def ashlar_nccw_no_preprocessing(img1, img2):
    """
    Error metric between two equally-shaped 2D images, with input validation.

    Modelled on ashlar.utils.nccw(): the images are cast to float32 and the
    error is the negative log of their correlation (absolute sum of products)
    divided by the product of their norms. Unusable input returns ``np.inf``
    instead of raising.
    """
    # Reject anything that cannot be compared: missing, empty, mismatched or
    # non-2D input.
    if img1 is None or img2 is None:
        return np.inf
    unusable = (
        img1.size == 0
        or img2.size == 0
        or img1.shape != img2.shape
        or len(img1.shape) != 2
        or img1.shape[0] < 1
        or img1.shape[1] < 1
    )
    if unusable:
        return np.inf

    # float32 conversion stands in for whiten() - matches the GPU version
    first = img1.astype(np.float32)
    second = img2.astype(np.float32)

    correlation = np.abs(np.sum(first * second))
    total_amplitude = np.linalg.norm(first) * np.linalg.norm(second)

    if not (correlation > 0 and total_amplitude > 0):
        logger.warning(f"Ashlar CPU: NCCW invalid correlation - correlation={correlation:.6f}, total_amplitude={total_amplitude:.6f}")
        error = np.inf
    else:
        excess = correlation - total_amplitude
        if excess <= 0:
            error = -np.log(correlation / total_amplitude)
        elif excess < 1e-3:
            # Rounding can push the correlation marginally above the amplitude
            # product when the images are (nearly) identical; treat that as a
            # perfect match.
            error = 0
        else:
            # Larger overshoot: report a large but finite error rather than raising
            logger.warning(f"Ashlar CPU: NCCW numerical precision issue - diff={excess:.6f}, using error=100.0")
            error = 100.0

    # WARNING for a high error, INFO otherwise
    if error > 10.0:
        logger.warning(f"Ashlar CPU: HIGH NCCW ERROR - Error={error:.4f}")
        logger.warning(f" This indicates poor image correlation between tiles")
    else:
        logger.info(f"Ashlar CPU: NCCW - Error={error:.4f}")

    return error

235 

236 

def ashlar_crop(img, offset, shape):
    """
    Cut a rectangular region out of ``img``, clamped to the image bounds.

    Modelled on ashlar.utils.crop(). ``offset`` and ``shape`` are rounded to
    integers; negative offsets are raised to 0, sizes below 1 are raised to 1,
    and the region is clipped so that it stays inside the image.

    Raises:
        ValueError: If ``img`` is None or has no elements.
    """
    if img is None or img.size == 0:
        raise ValueError("Cannot crop from empty or None image")

    # Integer origin (never negative) and extent (never below one pixel)
    origin = np.maximum(offset.round().astype(int), 0)
    extent = np.maximum(np.round(shape).astype(int), 1)

    n_rows, n_cols = img.shape[:2]
    stop = origin + extent

    # Clamp both corners of the region to the image
    row_start = min(origin[0], n_rows - 1)
    col_start = min(origin[1], n_cols - 1)
    row_stop = min(stop[0], n_rows)
    col_stop = min(stop[1], n_cols)

    if row_stop <= row_start or col_stop <= col_start:
        # Degenerate bounds: fall back to a single pixel at the start corner
        return img[row_start:row_start + 1, col_start:col_start + 1]

    return img[row_start:row_stop, col_start:col_stop]

274 

275 

class ArrayEdgeAligner:
    """
    Array-based EdgeAligner that implements the complete Ashlar algorithm
    but works directly with numpy arrays instead of file readers.

    Typical use is to construct the aligner and call :meth:`run`, after which
    ``final_positions`` holds the adjusted tile positions (shifted so their
    minimum is at 0, 0), ``shifts`` the per-tile corrections and ``lr`` the
    fitted linear model.

    Pairwise registration results are memoised in ``self._cache``, keyed by
    the sorted ``(t1, t2)`` tile index pair, as ``(shift, error)`` tuples.
    """

    def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
                 max_shift=15, alpha=0.01, max_error=None,
                 randomize=False, verbose=False):
        """
        Initialize array-based EdgeAligner for pure position calculation.

        Args:
            image_stack: 3D numpy array (num_tiles, height, width) - preprocessed grayscale
            positions: 2D array of tile positions (num_tiles, 2) in pixels
            tile_size: Array [height, width] of tile dimensions
            pixel_size: Pixel size in micrometers (for max_shift conversion)
            max_shift: Maximum allowed shift in micrometers
            alpha: Alpha value for error threshold (lower = stricter)
            max_error: Explicit error threshold (None = auto-compute)
            randomize: Use random seed for permutation testing
            verbose: Enable verbose logging
        """
        self.image_stack = image_stack
        # astype() returns a new array, so the caller's positions are not modified
        self.positions = positions.astype(float)
        self.tile_size = np.array(tile_size)
        self.pixel_size = pixel_size
        self.max_shift = max_shift
        # max_shift converted to pixels; used by register_all() to reject shifts
        self.max_shift_pixels = self.max_shift / self.pixel_size
        self.alpha = alpha
        self.max_error = max_error
        self.randomize = randomize
        self.verbose = verbose
        # (min tile, max tile) -> (shift, error); filled by register_pair()
        self._cache = {}
        self.errors_negative_sampled = np.empty(0)

        # Build neighbors graph
        self.neighbors_graph = self._build_neighbors_graph()

    def _build_neighbors_graph(self):
        """Build graph of neighboring (overlapping) tiles.

        Two tiles are connected when the city-block distance between their
        positions is greater than 0 and less than the larger tile dimension
        plus one. Every tile index is added as a node, so tiles without
        neighbors still appear in the graph.
        """
        pdist = scipy.spatial.distance.pdist(self.positions, metric='cityblock')
        sp = scipy.spatial.distance.squareform(pdist)
        max_distance = self.tile_size.max() + 1
        edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
        graph = nx.from_edgelist(edges)
        graph.add_nodes_from(range(len(self.positions)))
        return graph


    def run(self):
        """Run the complete Ashlar algorithm.

        The steps are order-dependent: the threshold must exist before
        register_all() filters the cache, and the spanning tree must exist
        before positions are computed and the model is fitted.
        """
        self.check_overlaps()
        self.compute_threshold()
        self.register_all()
        self.build_spanning_tree()
        self.calculate_positions()
        self.fit_model()

    def check_overlaps(self):
        """Check if tiles actually overlap based on positions.

        Only issues DataWarnings (via warn_data); it never raises and does not
        change any state.
        """
        # Per neighbor pair: tile size minus position distance along each axis
        overlaps = np.array([
            self.tile_size - abs(self.positions[t1] - self.positions[t2])
            for t1, t2 in self.neighbors_graph.edges
        ])
        # A pair "fails" when the overlap is below one pixel along either axis
        failures = np.any(overlaps < 1, axis=1) if len(overlaps) else []
        if len(failures) and all(failures):
            warn_data("No tiles overlap, attempting alignment anyway.")
        elif any(failures):
            warn_data("Some neighboring tiles have zero overlap.")

    def compute_threshold(self):
        """Compute error threshold using permutation testing.

        Registers randomly chosen image strips that are not expected to
        overlap and sets ``self.max_error`` to the ``alpha * 100`` percentile
        of the resulting errors. Does nothing if ``max_error`` was given
        explicitly, and sets it to ``np.inf`` when the neighbors graph has at
        most one edge. The sampled errors are kept in
        ``self.errors_negative_sampled``.
        """
        if self.max_error is not None:
            if self.verbose:
                print(" using explicit error threshold")
            return

        edges = self.neighbors_graph.edges
        num_tiles = len(self.image_stack)

        # If not enough tiles overlap to matter, skip this whole thing
        if len(edges) <= 1:
            self.max_error = np.inf
            return

        # Strip height: the largest of the per-edge minimum overlap dimensions
        widths = np.array([
            self.intersection(t1, t2).shape.min()
            for t1, t2 in edges
        ])
        w = widths.max()
        # Exclusive upper bound for a strip's starting row
        max_offset = self.tile_size[0] - w

        # Number of possible pairs minus number of actual neighbor pairs
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)

        # Reduce permutation count for small datasets
        n = 1000 if num_distant_pairs > 8 else (num_distant_pairs + 1) * 10
        pairs = np.empty((n, 2), dtype=int)
        offsets = np.empty((n, 2), dtype=int)

        # Generate n random non-overlapping image strips
        max_tries = 100
        if self.randomize is False:
            # Fixed seed, so repeated runs draw the same strips
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()

        for i in range(n):
            # Limit tries to avoid infinite loop in pathological cases
            for current_try in range(max_tries):
                t1, t2 = random_state.randint(num_tiles, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)

                # Check for non-overlapping strips and abort the retry loop
                if t1 != t2 and (t1, t2) not in edges:
                    # Different, non-neighboring tiles -- always OK
                    break
                elif t1 == t2 and abs(o1 - o2) > w:
                    # Same tile OK if strips don't overlap within the image
                    break
                elif (t1, t2) in edges:
                    # Neighbors OK if either strip is entirely outside the
                    # expected overlap region (based on nominal positions)
                    its = self.intersection(t1, t2, np.repeat(w, 2))
                    ioff1, ioff2 = its.offsets[:, 0]
                    if (
                        its.shape[0] > its.shape[1]
                        or o1 < ioff1 - w or o1 > ioff1 + w
                        or o2 < ioff2 - w or o2 > ioff2 + w
                    ):
                        break
            else:
                # Retries exhausted. This should be very rare.
                # The last drawn (t1, t2, o1, o2) is still recorded below.
                warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
            pairs[i] = t1, t2
            offsets[i] = o1, o2

        errors = np.empty(n)
        for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
            # if self.verbose and (i % 10 == 9 or i == n - 1):
            #     sys.stdout.write(f'\r quantifying alignment error {i + 1}/{n}')
            #     sys.stdout.flush()
            # Full-width strips of w rows starting at the sampled row offsets
            img1 = self.image_stack[t1][offset1:offset1+w, :]
            img2 = self.image_stack[t2][offset2:offset2+w, :]
            _, errors[i] = ashlar_register_no_preprocessing(img1, img2, upsample=1)
        # if self.verbose:
        #     print()
        self.errors_negative_sampled = errors
        self.max_error = np.percentile(errors, self.alpha * 100)


    def register_all(self):
        """Register all neighboring tile pairs.

        Afterwards ``self.all_errors`` holds the error of every cached pair as
        measured, and cache entries whose error exceeds ``self.max_error`` or
        whose shift exceeds ``self.max_shift_pixels`` on any axis have their
        error replaced by ``np.inf`` (the shift is kept).
        """
        n = self.neighbors_graph.size()
        for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
            if self.verbose:
                sys.stdout.write(f'\r aligning edge {i}/{n}')
                sys.stdout.flush()
            self.register_pair(t1, t2)
        if self.verbose:
            print()
        self.all_errors = np.array([x[1] for x in self._cache.values()])

        # Set error values above the threshold to infinity
        for k, v in self._cache.items():
            if v[1] > self.max_error or any(np.abs(v[0]) > self.max_shift_pixels):
                self._cache[k] = (v[0], np.inf)

    def register_pair(self, t1, t2):
        """Return relative shift between images and the alignment error.

        Results are cached under the sorted tile pair. The shift is stored for
        the sorted order and negated when the caller passes ``t1 > t2``.
        """
        key = tuple(sorted((t1, t2)))
        try:
            shift, error = self._cache[key]
        except KeyError:
            # Test a series of increasing overlap window sizes to help avoid
            # missing alignments when the stage position error is large relative
            # to the tile overlap. Simply using a large overlap in all cases
            # limits the maximum achievable correlation thus increasing the
            # error metric, leading to worse overall results. The window size
            # starts at the nominal size and doubles until it's at least 10% of
            # the tile size. If the nominal overlap is already 10% or greater,
            # we only use that one size.
            smin = self.intersection(key[0], key[1]).shape
            smax = np.round(self.tile_size * 0.1)
            sizes = [smin]
            while any(sizes[-1] < smax):
                sizes.append(sizes[-1] * 2)
            # Test each window size with validation
            results = []
            for s in sizes:
                try:
                    result = self._register(key[0], key[1], s)
                    results.append(result)
                except Exception:
                    # If this window size fails, use infinite error
                    # NOTE(review): the exception is discarded without logging
                    results.append((np.array([0.0, 0.0]), np.inf))
            # Use the shift from the window size that gave the lowest error
            shift, _ = min(results, key=lambda r: r[1])
            # Extract the images from the nominal overlap window but with the
            # shift applied to the second tile's position, and compute the error
            # metric on these images. This should be even lower than the error
            # computed above.
            _, o1, o2 = self.overlap(key[0], key[1], shift=shift)
            error = ashlar_nccw_no_preprocessing(o1, o2)
            self._cache[key] = (shift, error)
        if t1 > t2:
            shift = -shift
        # Return copy of shift to prevent corruption of cached values
        return shift.copy(), error

    def _register(self, t1, t2, min_size=0):
        """Register a single tile pair with given minimum size.

        Returns ``(shift, error)`` where the shift measured on the overlap
        crops has the intersection padding added, signed by the relative
        position of the two tiles.
        """
        its, img1, img2 = self.overlap(t1, t2, min_size)
        # Account for padding, flipping the sign depending on the direction
        # between the tiles
        p1, p2 = self.positions[[t1, t2]]
        sx = 1 if p1[1] >= p2[1] else -1
        sy = 1 if p1[0] >= p2[0] else -1
        padding = its.padding * [sy, sx]
        shift, error = ashlar_register_no_preprocessing(img1, img2)
        shift += padding
        return shift, error


    def intersection(self, t1, t2, min_size=0, shift=None):
        """Calculate intersection region between two tiles.

        If ``shift`` is given it is added to the second tile's position before
        the intersection is computed.
        """
        # Fancy indexing yields a copy, so the shift below does not alter
        # self.positions
        corners1 = self.positions[[t1, t2]]
        if shift is not None:
            corners1[1] += shift
        corners2 = corners1 + self.tile_size
        return Intersection(corners1, corners2, min_size)

    def crop(self, tile_id, offset, shape):
        """Crop image from tile at given offset and shape."""
        img = self.image_stack[tile_id]
        return ashlar_crop(img, offset, shape)

    def overlap(self, t1, t2, min_size=0, shift=None):
        """Extract overlapping regions between two tiles.

        Returns the Intersection together with the crop taken from each tile
        at that intersection's offsets and shape.
        """
        its = self.intersection(t1, t2, min_size, shift)
        img1 = self.crop(t1, its.offsets[0], its.shape)
        img2 = self.crop(t2, its.offsets[1], its.shape)
        return its, img1, img2





    def build_spanning_tree(self):
        """Build a spanning forest from the registered edges.

        Only cached edges with a finite error are used, weighted by that
        error. For each connected component, the Dijkstra paths from one of
        its center nodes (``nx.center``) to every other node are merged into
        ``self.spanning_tree``. All tiles are present as nodes, including
        ones left without any edge.
        """
        g = nx.Graph()
        g.add_nodes_from(self.neighbors_graph)
        g.add_weighted_edges_from(
            (t1, t2, error)
            for (t1, t2), (_, error) in self._cache.items()
            if np.isfinite(error)
        )
        spanning_tree = nx.Graph()
        spanning_tree.add_nodes_from(g)
        for c in nx.connected_components(g):
            cc = g.subgraph(c)
            center = nx.center(cc)[0]
            paths = nx.single_source_dijkstra_path(cc, center).values()
            for path in paths:
                nx.add_path(spanning_tree, path)
        self.spanning_tree = spanning_tree

    def calculate_positions(self):
        """Calculate final positions from spanning tree.

        Within each connected component the center tile gets a zero shift and
        every other tile accumulates the pairwise shifts along the
        breadth-first traversal from that center. Sets ``self.shifts`` and
        ``self.final_positions``.

        Raises:
            NotImplementedError: If the spanning tree contains no tiles.
        """
        shifts = {}
        for c in nx.connected_components(self.spanning_tree):
            cc = self.spanning_tree.subgraph(c)
            center = nx.center(cc)[0]
            shifts[center] = np.array([0, 0])
            for edge in nx.traversal.bfs_edges(cc, center):
                source, dest = edge
                if source not in shifts:
                    source, dest = dest, source
                # Served from the cache filled by register_all()
                shift = self.register_pair(source, dest)[0]
                shifts[dest] = shifts[source] + shift
        if shifts:
            # Sorting by tile index lines the shifts up with self.positions
            self.shifts = np.array([s for _, s in sorted(shifts.items())])
            self.final_positions = self.positions + self.shifts
        else:
            # TODO: fill in shifts and positions with 0x2 arrays
            raise NotImplementedError("No images")


    def fit_model(self):
        """Fit linear model to handle disconnected components.

        A linear regression from nominal to final positions is fitted on the
        largest connected component and used to place the remaining
        components. Finally all positions (and the model intercept) are
        shifted so that the minimum final position is at 0, 0; the amount
        removed is stored in ``self.origin``.
        """
        components = sorted(
            nx.connected_components(self.spanning_tree),
            key=len, reverse=True
        )
        # Fit LR model on positions of largest connected component
        cc0 = list(components[0])
        self.lr = sklearn.linear_model.LinearRegression()
        self.lr.fit(self.positions[cc0], self.final_positions[cc0])

        # Fix up degenerate transform matrix. This happens when the spanning
        # tree is completely edgeless or cc0's metadata positions fall in a
        # straight line. In this case we fall back to the identity transform.
        if np.linalg.det(self.lr.coef_) < 1e-3:
            warn_data(
                "Could not align enough edges, proceeding anyway with original"
                " stage positions."
            )
            self.lr.coef_ = np.diag(np.ones(2))
            self.lr.intercept_ = np.zeros(2)

        # Adjust position of remaining components so their centroids match
        # the predictions of the model
        for cc in components[1:]:
            nodes = list(cc)
            centroid_m = np.mean(self.positions[nodes], axis=0)
            centroid_f = np.mean(self.final_positions[nodes], axis=0)
            shift = self.lr.predict([centroid_m])[0] - centroid_f
            self.final_positions[nodes] += shift

        # Adjust positions and model intercept to put origin at 0,0
        self.origin = self.final_positions.min(axis=0)
        self.final_positions -= self.origin
        self.lr.intercept_ -= self.origin

600 

601 

602def _calculate_initial_positions(image_stack: np.ndarray, grid_dims: tuple, overlap_ratio: float) -> np.ndarray: 

603 """Calculate initial grid positions based on overlap ratio.""" 

604 grid_rows, grid_cols = grid_dims 

605 tile_height, tile_width = image_stack.shape[1:3] 

606 spacing_factor = 1.0 - overlap_ratio 

607 

608 positions = [] 

609 for tile_idx in range(len(image_stack)): 

610 r = tile_idx // grid_cols 

611 c = tile_idx % grid_cols 

612 

613 y_pos = r * tile_height * spacing_factor 

614 x_pos = c * tile_width * spacing_factor 

615 positions.append([y_pos, x_pos]) 

616 

617 return np.array(positions, dtype=float) 

618 

619 

620def _convert_ashlar_positions_to_openhcs(ashlar_positions: np.ndarray) -> List[Tuple[float, float]]: 

621 """Convert Ashlar positions to OpenHCS format.""" 

622 positions = [] 

623 for tile_idx in range(len(ashlar_positions)): 

624 y, x = ashlar_positions[tile_idx] 

625 positions.append((float(x), float(y))) # OpenHCS uses (x, y) format 

626 return positions 

627 

628 

629@special_inputs("grid_dimensions") 

630@special_outputs("positions") 

631@numpy_func 

632def ashlar_compute_tile_positions_cpu( 

633 image_stack: np.ndarray, 

634 grid_dimensions: Tuple[int, int], 

635 overlap_ratio: float = 0.1, 

636 max_shift: float = 15.0, 

637 stitch_alpha: float = 0.01, 

638 max_error: float = None, 

639 randomize: bool = False, 

640 verbose: bool = False, 

641 upsample_factor: int = 10, 

642 permutation_upsample: int = 1, 

643 permutation_samples: int = 1000, 

644 min_permutation_samples: int = 10, 

645 max_permutation_tries: int = 100, 

646 window_size_factor: float = 0.1, 

647 **kwargs 

648) -> Tuple[np.ndarray, List[Tuple[float, float]]]: 

649 """ 

650 Compute tile positions using the complete Ashlar algorithm - pure position calculation only. 

651 

652 This function implements the full Ashlar edge-based stitching algorithm but works directly 

653 on preprocessed grayscale image arrays. It performs ONLY position calculation without any 

654 file I/O, channel selection, or image preprocessing. All the mathematical sophistication 

655 and robustness of the original Ashlar algorithm is preserved. 

656 

657 Args: 

658 image_stack: 3D numpy array of shape (num_tiles, height, width) containing preprocessed 

659 grayscale images. Each slice [i] should be a single-channel 2D image ready 

660 for correlation analysis. No further preprocessing will be applied. 

661 

662 grid_dimensions: Tuple of (grid_rows, grid_cols) specifying the logical arrangement of 

663 tiles. For example, (2, 3) means 2 rows and 3 columns of tiles, for a 

664 total of 6 tiles. Must match the number of images in image_stack. 

665 

666 overlap_ratio: Expected fractional overlap between adjacent tiles (0.0-1.0). Default 0.1 

667 means 10% overlap. This is used to calculate initial grid positions and 

668 should match the actual overlap in your microscopy data. Typical values: 

669 - 0.05-0.15 for well-controlled microscopes 

670 - 0.15-0.25 for less precise stages 

671 

672 max_shift: Maximum allowed shift correction in micrometers. Default 15.0. This limits 

673 how far tiles can be moved from their initial grid positions during alignment. 

674 Should be set based on your microscope's stage accuracy: 

675 - 5-15 μm for high-precision stages 

676 - 15-50 μm for standard stages 

677 - 50+ μm for low-precision or manual stages 

678 

679 stitch_alpha: Alpha value for statistical error threshold computation (0.0-1.0). Default 

680 0.01 means 1% false positive rate. Lower values are stricter and reject more 

681 alignments, higher values are more permissive. This controls the trade-off 

682 between alignment quality and success rate: 

683 - 0.001-0.01: Very strict, high quality alignments only 

684 - 0.01-0.05: Balanced (recommended for most data) 

685 - 0.05-0.1: Permissive, accepts lower quality alignments 

686 

687 max_error: Explicit error threshold for rejecting alignments (None = auto-compute). 

688 When None (default), the threshold is computed automatically using permutation 

689 testing. Set to a specific value to override automatic computation. Higher 

690 values accept more alignments, lower values are stricter. 

691 

692 randomize: Whether to use random seed for permutation testing (bool). Default False uses 

693 a fixed seed for reproducible results. Set True for different random sampling 

694 in each run. Generally should be False for consistent results. 

695 

696 verbose: Enable detailed progress logging (bool). Default False. When True, prints 

697 progress information including permutation testing, edge alignment, and 

698 spanning tree construction. Useful for debugging and monitoring progress 

699 on large datasets. 

700 

701 upsample_factor: Sub-pixel accuracy factor for phase cross correlation (int). Default 10. 

702 Higher values provide better sub-pixel accuracy but increase computation time. 

703 Range: 1-100+. Values of 10-50 are typical for high-accuracy stitching. 

704 - 1: Pixel-level accuracy (fastest) 

705 - 10: 0.1 pixel accuracy (balanced) 

706 - 50: 0.02 pixel accuracy (high precision) 

707 

708 permutation_upsample: Upsample factor for permutation testing (int). Default 1. 

709 Lower than upsample_factor for speed during threshold computation. 

710 Usually kept at 1 since permutation testing doesn't need sub-pixel accuracy. 

711 

712 permutation_samples: Number of random samples for error threshold computation (int). Default 1000. 

713 Higher values give more accurate thresholds but slower computation. 

714 Automatically reduced for small datasets to avoid infinite loops. 

715 

716 min_permutation_samples: Minimum permutation samples for small datasets (int). Default 10. 

717 When there are few non-overlapping pairs, this sets the minimum 

718 number of samples to ensure statistical validity. 

719 

720 max_permutation_tries: Maximum attempts to find non-overlapping strips (int). Default 100. 

721 Prevents infinite loops in pathological cases where valid strips 

722 are hard to find. Rarely needs adjustment. 

723 

724 window_size_factor: Fraction of tile size for maximum window size (float). Default 0.1. 

725 Controls the largest overlap window tested during progressive sizing. 

726 Larger values allow detection of bigger stage errors but may reduce 

727 correlation quality. Range: 0.05-0.2 typical. 

728 

729 **kwargs: Additional parameters (ignored). Allows compatibility with other stitching 

730 algorithms that may have different parameter sets. 

731 

732 Returns: 

733 Tuple of (image_stack, positions) where: 

734 - image_stack: The original input image array (unchanged) 

735 - positions: List of (x, y) position tuples in OpenHCS format, one per tile. 

736 Positions are in pixel coordinates with (0, 0) at the top-left. 

737 The positions represent the optimal tile placement after Ashlar 

738 alignment, accounting for stage errors and image correlation. 

739 

740 Raises: 

741 Exception: If the Ashlar algorithm fails (e.g., insufficient overlap, correlation 

742 errors), the function automatically falls back to grid-based positioning 

743 using the specified overlap_ratio. 

744 

745 Notes: 

746 - This implementation contains the complete Ashlar algorithm including permutation 

747 testing, progressive window sizing, minimum spanning tree construction, and 

748 linear model fitting for disconnected components. 

749 - The correlation functions are identical to original Ashlar but without image 

750 preprocessing (whitening/filtering), allowing OpenHCS to handle preprocessing 

751 in separate pipeline steps. 

752 - For best results, ensure your image_stack contains properly preprocessed, 

753 single-channel grayscale images with good contrast and minimal noise. 

754 """ 

755 grid_rows, grid_cols = grid_dimensions 

756 

757 logger.info(f"Ashlar CPU: Processing {grid_rows}x{grid_cols} grid with {len(image_stack)} tiles") 

758 

759 try: 

760 # Calculate initial grid positions 

761 initial_positions = _calculate_initial_positions(image_stack, grid_dimensions, overlap_ratio) 

762 tile_size = np.array(image_stack.shape[1:3]) # (height, width) 

763 

764 # Create and run ArrayEdgeAligner with complete Ashlar algorithm 

765 logger.info("Running complete Ashlar edge-based stitching algorithm") 

766 aligner = ArrayEdgeAligner( 

767 image_stack=image_stack, 

768 positions=initial_positions, 

769 tile_size=tile_size, 

770 pixel_size=1.0, # Assume 1 micrometer per pixel if not specified 

771 max_shift=max_shift, 

772 alpha=stitch_alpha, 

773 max_error=max_error, 

774 randomize=randomize, 

775 verbose=verbose 

776 ) 

777 

778 # Run the complete algorithm 

779 aligner.run() 

780 

781 # Convert to OpenHCS format 

782 positions = _convert_ashlar_positions_to_openhcs(aligner.final_positions) 

783 

784 logger.info("Ashlar algorithm completed successfully") 

785 

786 except Exception as e: 

787 logger.error(f"Ashlar algorithm failed: {e}") 

788 # Fallback to grid positions if Ashlar fails 

789 logger.warning("Falling back to grid-based positioning") 

790 positions = [] 

791 tile_height, tile_width = image_stack.shape[1:3] 

792 spacing_factor = 1.0 - overlap_ratio 

793 

794 for tile_idx in range(len(image_stack)): 

795 r = tile_idx // grid_cols 

796 c = tile_idx % grid_cols 

797 x_pos = c * tile_width * spacing_factor 

798 y_pos = r * tile_height * spacing_factor 

799 positions.append((float(x_pos), float(y_pos))) 

800 

801 logger.info(f"Ashlar CPU: Completed processing {len(positions)} tile positions") 

802 

803 return image_stack, positions 

804 

805 

def materialize_ashlar_cpu_positions(data: List[Tuple[float, float]], path: str, filemanager) -> str:
    """Write Ashlar CPU tile positions to a CSV file and return its path.

    The table has one row per entry of ``data`` and, in order, the columns
    ``x_position_um``, ``y_position_um``, ``tile_id``, ``grid_row``,
    ``grid_col``, ``x_spacing_um``, ``y_spacing_um``, ``algorithm`` and
    ``grid_dimensions``.

    Args:
        data: Sequence of ``(x, y)`` position pairs, one per tile.
        path: Base output path. Every ``'.pkl'`` substring is replaced by
            ``'_ashlar_positions.csv'`` to form the CSV path; a path that
            contains no ``'.pkl'`` is used unchanged.
        filemanager: Object exposing ``save(content, path, backend)``; it is
            called once with the CSV text, the CSV path and ``"disk"``.

    Returns:
        The path the CSV content was saved under.
    """
    out_path = path.replace('.pkl', '_ashlar_positions.csv')

    table = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
    table['tile_id'] = range(len(table))

    # The grid shape is estimated from the number of distinct coordinate
    # values along each axis, not taken from any caller-supplied dimensions.
    xs = sorted(table['x_position_um'].unique())
    ys = sorted(table['y_position_um'].unique())
    n_cols, n_rows = len(xs), len(ys)

    # Grid coordinates are derived from the row order and the estimated
    # column count (row-major numbering).
    table['grid_row'] = table.index // n_cols
    table['grid_col'] = table.index % n_cols

    # Spacing is the gap between the two smallest distinct coordinates,
    # or 0 when an axis has fewer than two distinct values.
    table['x_spacing_um'] = xs[1] - xs[0] if n_cols > 1 else 0
    table['y_spacing_um'] = ys[1] - ys[0] if n_rows > 1 else 0

    # Constant metadata columns.
    table['algorithm'] = 'ashlar_cpu'
    table['grid_dimensions'] = f"{n_rows}x{n_cols}"

    filemanager.save(table.to_csv(index=False), out_path, "disk")
    return out_path