Coverage for openhcs/processing/backends/pos_gen/ashlar_main_gpu.py: 7.8%

507 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2OpenHCS Interface for Ashlar GPU Stitching Algorithm 

3 

4Array-based EdgeAligner implementation that works directly with CuPy arrays 

5instead of file-based readers. This is the complete Ashlar algorithm modified 

6to accept arrays directly and run on GPU. 

7""" 

8from __future__ import annotations 

9import logging 

10import sys 

11from typing import TYPE_CHECKING, Tuple, List 

12import numpy as np 

13import networkx as nx 

14import scipy.spatial.distance 

15import sklearn.linear_model 

16import pandas as pd 

17 

18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs 

19from openhcs.core.memory.decorators import cupy as cupy_func 

20from openhcs.core.utils import optional_import 

21 

22# Import CuPy using the established optional import pattern 

23cp = optional_import("cupy") 

24 

25import warnings 

26 

# Module-level logger used by the GPU registration helpers in this module.
logger = logging.getLogger(__name__)

31 

32 

class DataWarning(Warning):
    """Warning category for problems detected in user-supplied image data."""

36 

37 

def warn_data(message):
    """Emit ``message`` as a :class:`DataWarning`.

    ``stacklevel=2`` attributes the warning to the code that called
    ``warn_data``. The reported file and line then point at the check that
    detected the data problem instead of at this helper.
    """
    warnings.warn(message, DataWarning, stacklevel=2)

41 

42 

class IntersectionGPU:
    """Overlap window between two axis-aligned tiles, computed with CuPy.

    This is a CuPy port of ashlar's ``Intersection`` class. ``corners1`` and
    ``corners2`` hold the lower and upper corners of the tiles, one row per
    tile (presumably (y, x) per row, as built by
    ``ArrayEdgeAlignerGPU.intersection`` — confirm for other callers).

    After construction the instance exposes:

    * ``shape`` – integer window size. It is the nominal overlap, grown to at
      least ``min_size``, where ``min_size`` is first clipped to ``[1, extent]``
      per axis.
    * ``padding`` – how much ``shape`` exceeds the nominal overlap.
    * ``offsets`` – per-tile start of the window inside each tile, floored at 0.
    * ``offset_diff_frac`` – sub-pixel part of ``offsets[1] - offsets[0]``.
    """

    def __init__(self, corners1, corners2, min_size=0):
        if not cp:
            raise ImportError("CuPy is required for GPU intersection calculations")
        # Normalise min_size to a CuPy vector: scalars are broadcast to both
        # axes, other array-likes are copied to the device.
        if isinstance(min_size, (int, float)):
            min_size = cp.full(2, min_size)
        elif not isinstance(min_size, cp.ndarray):
            min_size = cp.array(min_size)
        self._calculate(corners1, corners2, min_size)

    def _calculate(self, corners1, corners2, min_size):
        """Fill in shape/padding/offsets using ashlar's intersection formulas."""
        # The largest per-axis tile extent caps how far the window may grow.
        extent = (corners2 - corners1).max(axis=0)
        floor_size = cp.clip(min_size, 1, extent)

        # The overlap begins at the larger lower corner and ends at the smaller
        # upper corner.
        origin = corners1.max(axis=0)
        nominal = cp.floor(corners2.min(axis=0) - origin).astype(int)

        self.shape = cp.ceil(cp.maximum(nominal, floor_size)).astype(int)
        self.padding = self.shape - nominal
        self.offsets = cp.maximum(origin - corners1 - self.padding, 0)

        delta = self.offsets[1] - self.offsets[0]
        self.offset_diff_frac = delta - cp.round(delta)

68 

69 

def _get_window(shape):
    """Return a separable 2D Hann window of ``shape`` as a float32 CuPy array.

    Port of ``ashlar.utils.get_window``. The window is the outer product of a
    1-D Hann taper along the first axis and one along the second axis.

    Raises:
        ImportError: if CuPy is not available.
    """
    if cp is None:
        raise ImportError("CuPy is required for GPU window functions")
    rows, cols = shape[0], shape[1]
    taper_rows = cp.hanning(rows).astype(cp.float32)
    taper_cols = cp.hanning(cols).astype(cp.float32)
    return cp.outer(taper_rows, taper_cols)

79 

80 

# 3x3 discrete Laplacian used by the sigma == 0 whitening path (equivalent to
# skimage.restoration.uft.laplacian). It stays None when CuPy is unavailable.
_laplace_kernel_gpu = (
    cp.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=cp.float32)
    if cp
    else None
)

85 

86 

def whiten_gpu(img, sigma):
    """
    Apply ashlar's whitening (high-pass) filter to one 2D image on the GPU.

    The input is converted to a CuPy array and cast to float32 with ``astype``
    (values are not rescaled).

    * ``sigma == 0``: convolve with the 3x3 discrete Laplacian, using
      ``mode='reflect'`` at the borders.
    * otherwise: Laplacian-of-Gaussian with standard deviation ``sigma``.

    Args:
        img: 2D image (CuPy array or anything ``cp.asarray`` accepts).
        sigma: Gaussian standard deviation; 0 selects the pure Laplacian.

    Returns:
        CuPy float32 array with the filtered image.
    """
    if not isinstance(img, cp.ndarray):
        img = cp.asarray(img)
    img = img.astype(cp.float32)

    from cupyx.scipy import ndimage as cp_ndimage

    if sigma == 0:
        return cp_ndimage.convolve(img, _laplace_kernel_gpu, mode='reflect')
    return cp_ndimage.gaussian_laplace(img, sigma)

119 

120 

def whiten_gpu_vectorized(img_stack, sigma):
    """
    Whiten every image of an (N, H, W) stack with the same filter as ``whiten_gpu``.

    Slices are filtered one at a time in a Python loop (one filter call per
    image) and written into a preallocated output. Slices do not influence
    each other.

    Args:
        img_stack: Stack of images (CuPy array or anything ``cp.asarray``
            accepts); cast to float32.
        sigma: Gaussian standard deviation; 0 selects the 3x3 Laplacian
            (``mode='reflect'``), otherwise a Laplacian-of-Gaussian.

    Returns:
        CuPy float32 array with the same shape as the input.
    """
    if not isinstance(img_stack, cp.ndarray):
        img_stack = cp.asarray(img_stack)
    img_stack = img_stack.astype(cp.float32)

    from cupyx.scipy import ndimage as cp_ndimage

    # Choose the per-slice filter once, then apply it in a single loop.
    if sigma == 0:
        def _filter(plane):
            return cp_ndimage.convolve(plane, _laplace_kernel_gpu, mode='reflect')
    else:
        def _filter(plane):
            return cp_ndimage.gaussian_laplace(plane, sigma)

    whitened = cp.empty_like(img_stack)
    for idx in range(img_stack.shape[0]):
        whitened[idx] = _filter(img_stack[idx])
    return whitened

153 

154 

def ashlar_register_gpu(img1, img2, upsample=10):
    """
    Estimate the translation between two equally sized 2D images on the GPU.

    Both images are cast to float32 and multiplied by a 2D Hann window. They
    are then passed to cuCIM's ``phase_cross_correlation``, the GPU counterpart
    of ``skimage.registration.phase_cross_correlation``. No whitening is
    applied.

    Args:
        img1, img2: 2D images (CuPy arrays or anything ``cp.asarray`` accepts).
        upsample: Sub-pixel upsampling factor, passed as ``upsample_factor``.

    Returns:
        ``(shift, error)``. ``shift`` is always a length-2 NumPy array (one
        entry per image axis) and ``error`` is a Python float. The following
        all yield ``(np.zeros(2), inf)``: missing, empty, non-2D or
        mismatched inputs, and a correlation that raises.
    """
    import cucim.skimage.registration

    def _failure():
        # A zero shift with infinite error marks "no usable registration".
        # NumPy is used on every path so callers always get one array type.
        return np.zeros(2), np.inf

    if img1 is None or img2 is None:
        return _failure()

    # Convert before validating so any array-like input can be checked.
    img1 = cp.asarray(img1)
    img2 = cp.asarray(img2)
    if img1.ndim != 2 or img1.shape != img2.shape or img1.size == 0:
        return _failure()

    # Shapes are equal, so a single window serves both images.
    window = _get_window(img1.shape)
    img1w = img1.astype(cp.float32) * window
    img2w = img2.astype(cp.float32) * window

    try:
        shift, error, _ = cucim.skimage.registration.phase_cross_correlation(
            img1w, img2w, upsample_factor=upsample
        )
        shift = cp.asnumpy(shift)
        error = float(error)
    except Exception as e:
        # Best effort: a failed correlation is reported as an unusable pair.
        logger.error(f"Ashlar GPU: CORRELATION FAILED - Exception: {e}")
        logger.error(" Returning infinite error")
        return _failure()

    # Only log high errors to avoid spam
    if error > 1.0:  # High error threshold for Ashlar
        logger.warning(f"Ashlar GPU: HIGH CORRELATION ERROR - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
        logger.warning(" This indicates poor overlap or image quality between tiles")

    return shift, error

218 

219 

220 

221 

222 

def ashlar_nccw_no_preprocessing_gpu(img1, img2):
    """
    Normalized cross-correlation error of two aligned images, computed on GPU.

    The error is ``-log(|sum(img1 * img2)| / (||img1|| * ||img2||))`` on the
    raw pixels (no whitening). It is 0 for perfectly correlated images, and
    larger values are worse.

    Args:
        img1, img2: Same-shaped images (CuPy arrays or array-likes).

    Returns:
        Python float. It is ``inf`` when the correlation or the amplitude
        product is not positive. It is ``100.0`` (with a warning) when the
        correlation exceeds the amplitude product by more than rounding can
        explain.
    """
    img1 = cp.asarray(img1)
    img2 = cp.asarray(img2)

    # Reduce in float64. In float32, sums over an overlap window pick up
    # rounding error proportional to their magnitude. Even identical tiles
    # could then exceed the amplitude product by more than a fixed absolute
    # slack and be scored as a failure.
    img1w = img1.astype(cp.float64)
    img2w = img2.astype(cp.float64)

    correlation = float(cp.abs(cp.sum(img1w * img2w)))
    total_amplitude = float(cp.linalg.norm(img1w) * cp.linalg.norm(img2w))

    # Cauchy-Schwarz bounds correlation / total_amplitude by 1, so any excess
    # is rounding noise. The slack is relative so that it does not depend on
    # image brightness or window size.
    rel_tol = 1e-6

    if correlation > 0 and total_amplitude > 0:
        ratio = correlation / total_amplitude
        if ratio < 1.0:
            error = float(-np.log(ratio))
        elif ratio - 1.0 <= rel_tol:
            # Nearly or exactly identical images: treat as a perfect match.
            error = 0.0
        else:
            diff = correlation - total_amplitude
            # Return a large but finite error instead of raising.
            logger.warning(f"Ashlar GPU: NCCW numerical precision issue - diff={diff:.6f}, using error=100.0")
            error = 100.0
    else:
        logger.warning(f"Ashlar GPU: NCCW invalid correlation - correlation={correlation:.6f}, total_amplitude={total_amplitude:.6f}")
        error = np.inf

    # Log all NCCW results for user visibility (high errors as warnings).
    if error > 10.0:  # High NCCW error threshold
        logger.warning(f"Ashlar GPU: HIGH NCCW ERROR - Error={error:.4f}")
        logger.warning(" This indicates poor image correlation between tiles")
    else:
        logger.info(f"Ashlar GPU: NCCW - Error={error:.4f}")

    return float(error)

267 

268 

269 

270 

271 

def ashlar_crop_gpu(img, offset, shape):
    """
    Crop a ``shape``-sized window from ``img`` at ``offset`` (port of ashlar.utils.crop).

    ``offset`` is rounded to the nearest whole pixel (round-half-to-even), so
    sub-pixel offsets are not interpolated. A window that extends past the
    image border is clipped to the image, so the result may be smaller than
    ``shape``.

    Args:
        img: 2D image (CuPy array or array-like); moved to the GPU if needed.
        offset: (row, col) start of the window; CuPy/NumPy array or sequence.
        shape: (rows, cols) size of the window; CuPy/NumPy array or sequence.

    Returns:
        CuPy view of the cropped region.

    Raises:
        ValueError: if ``shape`` has a non-positive entry, or the window has
            no pixels inside the image after clipping.
    """
    img = cp.asarray(img)
    # offset and shape are 2-vectors. Doing the bounds arithmetic on the host
    # avoids a device round-trip for every comparison and yields plain ints
    # for slicing. (cp.asnumpy also accepts NumPy arrays and sequences.)
    offset = cp.asnumpy(offset)
    shape = cp.asnumpy(shape)

    # Validate inputs to prevent zero-sized arrays
    if np.any(shape <= 0):
        raise ValueError(f"Invalid crop shape: {shape}. Shape must be positive.")

    # Note that this only crops to the nearest whole-pixel offset.
    start = np.round(offset).astype(int)
    end = start + shape

    img_shape = np.array(img.shape)
    if np.any(start < 0) or np.any(end > img_shape):
        # Clip to valid bounds
        start = np.maximum(start, 0)
        end = np.minimum(end, img_shape)
        if np.any(end - start <= 0):
            raise ValueError(f"Invalid crop region after bounds checking: start={start}, end={end}, img_shape={img_shape}")

    return img[int(start[0]):int(end[0]), int(start[1]):int(end[1])]

308 

309 

310 

311 

312 

class ArrayEdgeAlignerGPU:
    """
    Array-based port of ashlar's EdgeAligner that runs on CuPy arrays.

    Instead of reading tiles through ashlar's file readers, the aligner takes
    an in-memory (num_tiles, height, width) stack plus nominal tile positions.
    :meth:`run` executes the full pipeline:

    1. ``check_overlaps`` – warn when neighbouring tiles do not overlap.
    2. ``compute_threshold`` – permutation test on strips that should not
       match, to pick ``max_error`` (unless one was given).
    3. ``register_all`` – phase-correlate every neighbouring pair and reject
       pairs above the threshold or beyond ``max_shift``.
    4. ``build_spanning_tree`` – keep a tree of the pairs with finite error.
    5. ``calculate_positions`` – accumulate pairwise shifts along the tree.
    6. ``fit_model`` – place disconnected components with a linear model and
       move the origin to (0, 0).

    ``_cache`` maps a sorted tile-index pair to ``(shift, error)``. ``shift``
    is a host (NumPy) array and ``error`` is a Python float.
    """

    def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
                 max_shift=15, alpha=0.01, max_error=None,
                 randomize=False, verbose=False, upsample_factor=10,
                 permutation_upsample=1, permutation_samples=1000,
                 min_permutation_samples=10, max_permutation_tries=100,
                 window_size_factor=0.1):
        """
        Initialize array-based EdgeAligner for position calculation on GPU.

        Args:
            image_stack: 3D NumPy/CuPy array (num_tiles, height, width) of
                preprocessed grayscale tiles; copied to the GPU if needed.
            positions: (num_tiles, 2) nominal tile positions in pixels, in
                (y, x) order; stored as a float64 CuPy array.
            tile_size: [height, width] of one tile.
            pixel_size: Pixel size in micrometers, used to convert ``max_shift``.
            max_shift: Maximum allowed shift in micrometers. Pairs whose shift
                exceeds ``max_shift / pixel_size`` pixels on either axis are
                rejected in ``register_all``.
            alpha: Percentile (as a fraction) of the permutation-test errors
                used as the automatic error threshold (lower = stricter).
            max_error: Explicit error threshold; ``None`` computes it in
                ``compute_threshold``.
            randomize: ``False`` seeds the permutation test with 0 for
                reproducibility; any other value uses a fresh random seed.
            verbose: Print progress to stdout.
            upsample_factor: Phase-correlation upsampling for neighbour pairs.
            permutation_upsample: Phase-correlation upsampling for the
                permutation test.
            permutation_samples: Number of random strip pairs drawn when there
                are more than 8 non-neighbouring tile pairs.
            min_permutation_samples: With 8 or fewer non-neighbouring pairs,
                ``(num_distant_pairs + 1) * min_permutation_samples`` strip
                pairs are drawn instead.
            max_permutation_tries: Attempts allowed to draw each acceptable
                random strip pair.
            window_size_factor: Fraction of the tile size up to which the
                registration window is doubled in ``register_pair``.
        """
        # Convert to CuPy arrays for GPU processing
        if not isinstance(image_stack, cp.ndarray):
            self.image_stack = cp.asarray(image_stack)
        else:
            self.image_stack = image_stack

        if not isinstance(positions, cp.ndarray):
            self.positions = cp.asarray(positions, dtype=cp.float64)
        else:
            self.positions = positions.astype(cp.float64)

        self.tile_size = cp.array(tile_size)
        self.pixel_size = pixel_size
        self.max_shift = max_shift
        self.max_shift_pixels = self.max_shift / self.pixel_size
        self.alpha = alpha
        self.max_error = max_error
        self.randomize = randomize
        self.verbose = verbose
        self.upsample_factor = upsample_factor
        self.permutation_upsample = permutation_upsample
        self.permutation_samples = permutation_samples
        self.min_permutation_samples = min_permutation_samples
        self.max_permutation_tries = max_permutation_tries
        self.window_size_factor = window_size_factor
        self._cache = {}
        self.errors_negative_sampled = cp.empty(0)

        # Build neighbors graph (this uses CPU operations with NetworkX)
        self.neighbors_graph = self._build_neighbors_graph()

    def _build_neighbors_graph(self):
        """Return a NetworkX graph linking tiles whose nominal positions are close.

        Two tiles are connected when their city-block distance is positive and
        smaller than the largest tile dimension + 1. Every tile index is added
        as a node, even if it has no neighbours.
        """
        # Convert to CPU for scipy operations
        positions_cpu = cp.asnumpy(self.positions)
        tile_size_cpu = cp.asnumpy(self.tile_size)

        pdist = scipy.spatial.distance.pdist(positions_cpu, metric='cityblock')
        sp = scipy.spatial.distance.squareform(pdist)
        max_distance = tile_size_cpu.max() + 1
        edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
        graph = nx.from_edgelist(edges)
        graph.add_nodes_from(range(len(positions_cpu)))
        return graph

    def run(self):
        """Run the complete Ashlar algorithm (see the class docstring for steps)."""
        self.check_overlaps()
        self.compute_threshold()
        self.register_all()
        self.build_spanning_tree()
        self.calculate_positions()
        self.fit_model()

    def check_overlaps(self):
        """Warn (DataWarning) when neighbouring tiles have less than 1 px overlap."""
        overlaps = []
        for t1, t2 in self.neighbors_graph.edges:
            overlap = self.tile_size - cp.abs(self.positions[t1] - self.positions[t2])
            overlaps.append(overlap)

        if overlaps:
            overlaps = cp.stack(overlaps)
            failures = cp.any(overlaps < 1, axis=1)
            failures_cpu = cp.asnumpy(failures)

            if len(failures_cpu) and all(failures_cpu):
                warn_data("No tiles overlap, attempting alignment anyway.")
            elif any(failures_cpu):
                warn_data("Some neighboring tiles have zero overlap.")

    def compute_threshold(self):
        """Set ``self.max_error`` from a permutation test (unless already set).

        Random pairs of full-width strips that should NOT match are registered:
        - strips from different, non-neighbouring tiles;
        - disjoint strips of one tile;
        - strips of neighbours lying outside their nominal overlap.
        The ``alpha`` percentile of the resulting errors becomes the threshold.
        With at most one neighbour edge the threshold is set to infinity.
        """
        if self.max_error is not None:
            if self.verbose:
                print(" using explicit error threshold")
            return

        edges = self.neighbors_graph.edges
        num_tiles = len(self.image_stack)

        # If not enough tiles overlap to matter, skip this whole thing
        if len(edges) <= 1:
            self.max_error = np.inf
            return

        # Strip height: the largest "narrow side" among all nominal overlaps.
        w = max(int(self.intersection(t1, t2).shape.min()) for t1, t2 in edges)
        max_offset = int(self.tile_size[0]) - w

        # Number of possible pairs minus number of actual neighbor pairs
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)

        # Reduce permutation count for small datasets
        if num_distant_pairs > 8:
            n = self.permutation_samples
        else:
            n = (num_distant_pairs + 1) * self.min_permutation_samples
        pairs = np.empty((n, 2), dtype=int)  # Keep on CPU for random generation
        offsets = np.empty((n, 2), dtype=int)  # Keep on CPU for random generation

        # Generate n random non-overlapping image strips
        max_tries = self.max_permutation_tries
        if self.randomize is False:
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()

        for i in range(n):
            # Limit tries to avoid infinite loop in pathological cases
            for _ in range(max_tries):
                t1, t2 = random_state.randint(num_tiles, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)

                # Check for non-overlapping strips and abort the retry loop
                if t1 != t2 and (t1, t2) not in edges:
                    # Different, non-neighboring tiles -- always OK
                    break
                elif t1 == t2 and abs(o1 - o2) > w:
                    # Same tile OK if strips don't overlap within the image
                    break
                elif (t1, t2) in edges:
                    # Neighbors OK if either strip is entirely outside the
                    # expected overlap region (based on nominal positions)
                    its = self.intersection(t1, t2, cp.full(2, w))
                    ioff1, ioff2 = cp.asnumpy(its.offsets[:, 0])
                    rows, cols = cp.asnumpy(its.shape)
                    if (
                        rows > cols
                        or o1 < ioff1 - w or o1 > ioff1 + w
                        or o2 < ioff2 - w or o2 > ioff2 + w
                    ):
                        break
            else:
                # Retries exhausted. This should be very rare.
                warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
            pairs[i] = t1, t2
            offsets[i] = o1, o2

        errors = []
        for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
            if self.verbose and (i % 10 == 9 or i == n - 1):
                sys.stdout.write(f'\r quantifying alignment error {i + 1}/{n}')
                sys.stdout.flush()
            img1 = self.image_stack[t1][offset1:offset1 + w, :]
            img2 = self.image_stack[t2][offset2:offset2 + w, :]
            _, error = ashlar_register_gpu(img1, img2, upsample=self.permutation_upsample)
            errors.append(error)
        if self.verbose:
            print()
        errors = cp.asarray(errors, dtype=cp.float64)
        self.errors_negative_sampled = errors
        self.max_error = float(cp.percentile(errors, self.alpha * 100))

    def register_all(self):
        """Register every neighbour pair, then reject unreliable alignments.

        ``self.all_errors`` records the raw errors. Afterwards, any pair whose
        error exceeds ``max_error``, or whose shift exceeds
        ``max_shift_pixels`` on either axis, has its cached error set to inf.
        """
        n = self.neighbors_graph.size()
        for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
            if self.verbose:
                sys.stdout.write(f'\r aligning edge {i}/{n}')
                sys.stdout.flush()
            self.register_pair(t1, t2)
        if self.verbose:
            print()
        self.all_errors = cp.array([error for _, error in self._cache.values()])

        # Set error values above the threshold to infinity
        for key, (shift, error) in list(self._cache.items()):
            if error > self.max_error or np.any(np.abs(cp.asnumpy(shift)) > self.max_shift_pixels):
                self._cache[key] = (shift, np.inf)

    def register_pair(self, t1, t2):
        """Return ``(shift, error)`` aligning tile ``t2`` relative to tile ``t1``.

        Results are cached under the sorted pair. The shift is negated when
        ``t1 > t2``, and a copy is returned so callers cannot corrupt the
        cache.
        """
        key = tuple(sorted((t1, t2)))
        try:
            shift, error = self._cache[key]
        except KeyError:
            shift, error = self._align_uncached(key[0], key[1])
            self._cache[key] = (shift, error)
        if t1 > t2:
            shift = -shift
        # Return copy of shift to prevent corruption of cached values
        return shift.copy(), error

    def _align_uncached(self, t1, t2):
        """Multi-window registration of an ordered pair; returns NumPy shift and float error.

        A series of increasing overlap window sizes is tested, which helps avoid
        missing alignments when the stage position error is large relative to
        the tile overlap. Simply using a large overlap in all cases limits the
        maximum achievable correlation, which increases the error metric and
        gives worse overall results. The window starts at the nominal size and
        doubles until it is at least ``window_size_factor`` of the tile size.
        If the nominal overlap is already that large, only that one size is
        used.
        """
        try:
            # Window sizes are tiny 2-vectors: keep them on the host.
            smin = cp.asnumpy(self.intersection(t1, t2).shape)
            smax = np.round(cp.asnumpy(self.tile_size) * self.window_size_factor)
            sizes = [smin]
            while np.any(sizes[-1] < smax):
                sizes.append(sizes[-1] * 2)

            # Try each window size and collect results
            results = []
            for s in sizes:
                try:
                    result = self._register(t1, t2, s)
                except Exception as e:
                    if self.verbose:
                        print(f" window size {s} failed: {e}")
                    continue
                if result is not None:
                    results.append(result)

            if not results:
                # All window sizes failed, return large error
                return np.zeros(2), np.inf

            # Use the shift from the window size that gave the lowest error.
            shift, _ = min(results, key=lambda r: r[1])
            # Extract the images from the nominal overlap window, but with the
            # shift applied to the second tile's position, and compute the
            # error metric on these images. This should be even lower than the
            # error computed above.
            try:
                _, o1, o2 = self.overlap(t1, t2, shift=shift)
                error = ashlar_nccw_no_preprocessing_gpu(o1, o2)
            except Exception as e:
                if self.verbose:
                    print(f" final error computation failed: {e}")
                error = np.inf
            return shift, error

        except Exception as e:
            if self.verbose:
                print(f" registration failed for tiles {(t1, t2)}: {e}")
            return np.zeros(2), np.inf

    def _register(self, t1, t2, min_size=0):
        """Phase-correlate the overlap of ``t1``/``t2`` grown to ``min_size``.

        Returns ``(shift, error)``, where ``shift`` is a NumPy array corrected
        for the window padding. Returns ``None`` if the overlap is empty or any
        step raises.
        """
        try:
            its, img1, img2 = self.overlap(t1, t2, min_size)

            # Validate that we got valid images
            if img1.size == 0 or img2.size == 0:
                if self.verbose:
                    print(f" empty images for tiles {t1}, {t2} with min_size {min_size}")
                return None

            # Account for padding, flipping the sign depending on the direction
            # between the tiles
            p1, p2 = self.positions[[t1, t2]]
            sx = 1 if p1[1] >= p2[1] else -1
            sy = 1 if p1[0] >= p2[0] else -1
            padding = cp.array(its.padding) * cp.array([sy, sx])
            shift, error = ashlar_register_gpu(img1, img2, upsample=self.upsample_factor)
            shift = cp.array(shift) + padding
            return shift.get(), error
        except Exception as e:
            if self.verbose:
                print(f" _register failed for tiles {t1}, {t2}: {e}")
            return None

    def intersection(self, t1, t2, min_size=0, shift=None):
        """Return the IntersectionGPU of ``t1`` and ``t2``, optionally shifting ``t2``."""
        corners1 = self.positions[[t1, t2]].copy()
        if shift is not None:
            if not isinstance(shift, cp.ndarray):
                shift = cp.array(shift)
            corners1[1] += shift
        corners2 = corners1 + self.tile_size
        return IntersectionGPU(corners1, corners2, min_size)

    def crop(self, tile_id, offset, shape):
        """Crop ``shape`` pixels at ``offset`` from tile ``tile_id``."""
        img = self.image_stack[tile_id]
        return ashlar_crop_gpu(img, offset, shape)

    def overlap(self, t1, t2, min_size=0, shift=None):
        """Return ``(intersection, crop_of_t1, crop_of_t2)`` for the overlap window.

        Raises:
            ValueError: if the intersection has a non-positive dimension.
        """
        its = self.intersection(t1, t2, min_size, shift)

        # Validate intersection shape before cropping
        if cp.any(its.shape <= 0):
            raise ValueError(f"Invalid intersection shape {its.shape} for tiles {t1}, {t2}")

        img1 = self.crop(t1, its.offsets[0], its.shape)
        img2 = self.crop(t2, its.offsets[1], its.shape)
        return its, img1, img2

    def build_spanning_tree(self):
        """Build ``self.spanning_tree`` from the pairs with a finite error.

        First tries the GPU Boruvka MST, with edge quality = -error. If that
        raises, it falls back to ashlar's NetworkX approach: for each
        connected component, take the union of error-weighted shortest paths
        from the component's center. Tiles without a finite edge become
        isolated nodes.
        """
        # Import the Boruvka MST implementation
        from openhcs.processing.backends.pos_gen.mist.boruvka_mst import build_mst_gpu_boruvka

        # Errors are host floats, so test them on the host.
        valid_edges = [
            (t1, t2, shift, error)
            for (t1, t2), (shift, error) in self._cache.items()
            if np.isfinite(error)
        ]

        if not valid_edges:
            # No valid edges - create empty graph with all nodes
            self.spanning_tree = nx.Graph()
            self.spanning_tree.add_nodes_from(range(len(self.positions)))
            return

        # Prepare arrays for Boruvka MST. Shifts are (y, x). Converting each
        # component to float keeps this independent of host/device storage.
        connection_from = cp.array([e[0] for e in valid_edges], dtype=cp.int32)
        connection_to = cp.array([e[1] for e in valid_edges], dtype=cp.int32)
        connection_dx = cp.array([float(e[2][1]) for e in valid_edges], dtype=cp.float32)  # x shift
        connection_dy = cp.array([float(e[2][0]) for e in valid_edges], dtype=cp.float32)  # y shift
        # Use negative error as quality (higher quality = lower error)
        connection_quality = cp.array([-float(e[3]) for e in valid_edges], dtype=cp.float32)

        num_nodes = len(self.positions)

        try:
            # Run GPU Boruvka MST
            mst_result = build_mst_gpu_boruvka(
                connection_from, connection_to, connection_dx, connection_dy,
                connection_quality, num_nodes
            )

            # Convert back to NetworkX format for compatibility with rest of algorithm
            self.spanning_tree = nx.Graph()
            self.spanning_tree.add_nodes_from(range(num_nodes))

            for edge in mst_result['edges']:
                t1, t2 = edge['from'], edge['to']
                # Reconstruct error from quality
                error = -edge['quality'] if 'quality' in edge else 0.0
                self.spanning_tree.add_edge(t1, t2, weight=error)

        except Exception as e:
            # Fallback to NetworkX if Boruvka fails
            logger.warning(f"Boruvka MST failed, falling back to NetworkX: {e}")
            g = nx.Graph()
            g.add_nodes_from(self.neighbors_graph)
            g.add_weighted_edges_from(
                (t1, t2, error)
                for (t1, t2), (_, error) in self._cache.items()
                if np.isfinite(error)
            )
            spanning_tree = nx.Graph()
            spanning_tree.add_nodes_from(g)
            for c in nx.connected_components(g):
                cc = g.subgraph(c)
                center = nx.center(cc)[0]
                paths = nx.single_source_dijkstra_path(cc, center).values()
                for path in paths:
                    nx.add_path(spanning_tree, path)
            self.spanning_tree = spanning_tree

    def calculate_positions(self):
        """Accumulate pairwise shifts along the spanning tree into final positions.

        Within each connected component, the center tile keeps its nominal
        position. Every other tile is offset by the sum of the shifts on its
        BFS path from the center.

        Raises:
            NotImplementedError: if the spanning tree has no nodes.
        """
        shifts = {}
        for c in nx.connected_components(self.spanning_tree):
            cc = self.spanning_tree.subgraph(c)
            center = nx.center(cc)[0]
            shifts[center] = cp.array([0, 0])
            for source, dest in nx.traversal.bfs_edges(cc, center):
                if source not in shifts:
                    source, dest = dest, source
                shift = self.register_pair(source, dest)[0]
                shifts[dest] = shifts[source] + cp.asarray(shift)
        if not shifts:
            # TODO: fill in shifts and positions with 0x2 arrays
            raise NotImplementedError("No images")
        # Stack the per-tile device vectors in tile-index order.
        self.shifts = cp.stack([s for _, s in sorted(shifts.items())])
        self.final_positions = self.positions + self.shifts

    def fit_model(self):
        """Fit nominal -> final positions, place other components, and zero the origin.

        A LinearRegression is fitted on the largest connected component. A
        transform with determinant < 1e-3 is replaced by the identity (with a
        DataWarning). Each remaining component is translated so that its
        centroid matches the model's prediction. Finally, all positions and
        the model intercept are shifted so the minimum position is (0, 0).
        """
        components = sorted(
            nx.connected_components(self.spanning_tree),
            key=len, reverse=True
        )
        # Fit LR model on positions of largest connected component
        cc0 = list(components[0])
        self.lr = sklearn.linear_model.LinearRegression()

        # Convert to CPU for sklearn operations
        positions_cpu = cp.asnumpy(self.positions[cc0])
        final_positions_cpu = cp.asnumpy(self.final_positions[cc0])
        self.lr.fit(positions_cpu, final_positions_cpu)

        # Fix up degenerate transform matrix. This happens when the spanning
        # tree is completely edgeless or cc0's metadata positions fall in a
        # straight line. In this case we fall back to the identity transform.
        if np.linalg.det(self.lr.coef_) < 1e-3:
            warn_data(
                "Could not align enough edges, proceeding anyway with original"
                " stage positions."
            )
            self.lr.coef_ = np.diag(np.ones(2))
            self.lr.intercept_ = np.zeros(2)

        # Adjust position of remaining components so their centroids match
        # the predictions of the model
        for cc in components[1:]:
            nodes = list(cc)
            centroid_m = cp.mean(self.positions[nodes], axis=0)
            centroid_f = cp.mean(self.final_positions[nodes], axis=0)

            # Convert to CPU for prediction, then back to GPU
            centroid_m_cpu = cp.asnumpy(centroid_m).reshape(1, -1)
            shift_cpu = self.lr.predict(centroid_m_cpu)[0] - cp.asnumpy(centroid_f)
            shift = cp.array(shift_cpu)

            self.final_positions[nodes] += shift

        # Adjust positions and model intercept to put origin at 0,0
        self.origin = cp.min(self.final_positions, axis=0)
        self.final_positions -= self.origin
        self.lr.intercept_ -= cp.asnumpy(self.origin)

759 

760 

def _calculate_initial_positions_gpu(image_stack, grid_dims: tuple, overlap_ratio: float):
    """Lay tiles out on a regular row-major grid (GPU version).

    Tile ``i`` is placed at grid row ``i // cols`` and column ``i % cols``.
    Neighbouring tiles are spaced ``(1 - overlap_ratio)`` tile lengths apart.
    Only the column count of ``grid_dims`` affects the layout.

    Returns:
        float64 CuPy array with one (y, x) pixel position per tile.
    """
    _, grid_cols = grid_dims
    # Works for NumPy and CuPy stacks alike: both expose .shape and len().
    tile_height, tile_width = image_stack.shape[1:3]
    step = 1.0 - overlap_ratio

    coords = [
        [(idx // grid_cols) * tile_height * step, (idx % grid_cols) * tile_width * step]
        for idx in range(len(image_stack))
    ]
    return cp.array(coords, dtype=cp.float64)

783 

784 

def _convert_ashlar_positions_to_openhcs_gpu(ashlar_positions) -> List[Tuple[float, float]]:
    """Turn ashlar (y, x) position rows into OpenHCS-style (x, y) float tuples.

    CuPy input is copied to the host first. Any sequence of 2-element rows is
    accepted.
    """
    if isinstance(ashlar_positions, cp.ndarray):
        ashlar_positions = cp.asnumpy(ashlar_positions)
    # OpenHCS uses (x, y) ordering, so swap each row's components.
    return [(float(x), float(y)) for y, x in ashlar_positions]

796 

797 

def _grid_positions_fallback(
    num_tiles: int,
    tile_height: int,
    tile_width: int,
    grid_cols: int,
    overlap_ratio: float,
) -> List[Tuple[float, float]]:
    """Return plain row-major grid positions as OpenHCS ``(x, y)`` float tuples.

    Tile ``i`` is placed at row ``i // grid_cols`` and column ``i % grid_cols``.
    Neighbouring tiles are spaced by ``(1 - overlap_ratio)`` of the tile size.
    Uses no CuPy, so it still works when the GPU path failed because CuPy is
    unavailable.
    """
    spacing_factor = 1.0 - overlap_ratio
    positions = []
    for tile_idx in range(num_tiles):
        row, col = divmod(tile_idx, grid_cols)
        x_pos = col * tile_width * spacing_factor
        y_pos = row * tile_height * spacing_factor
        positions.append((float(x_pos), float(y_pos)))
    return positions


@special_inputs("grid_dimensions")
@special_outputs("positions")
@cupy_func
def ashlar_compute_tile_positions_gpu(
    image_stack,
    grid_dimensions: Tuple[int, int],
    overlap_ratio: float = 0.1,
    max_shift: float = 15.0,
    stitch_alpha: float = 0.01,
    max_error: float = None,
    randomize: bool = False,
    verbose: bool = False,
    upsample_factor: int = 10,
    permutation_upsample: int = 1,
    permutation_samples: int = 1000,
    min_permutation_samples: int = 10,
    max_permutation_tries: int = 100,
    window_size_factor: float = 0.1,
    **kwargs
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Compute tile positions using the Ashlar edge-alignment algorithm on GPU.

    Initial row-major grid positions are derived from ``grid_dimensions`` and
    ``overlap_ratio``, then refined by ``ArrayEdgeAlignerGPU``. If any step of
    that pipeline raises, the error is logged and plain grid positions are
    returned instead (see "Failure handling" below).

    Args:
        image_stack: 3D NumPy or CuPy array of shape (num_tiles, height, width),
            one single-channel 2D image per tile. It is never modified.

        grid_dimensions: ``(grid_rows, grid_cols)`` logical tile arrangement, e.g.
            (2, 3) for 2 rows and 3 columns. Tiles are taken in row-major order
            (``row = index // grid_cols``). ``grid_rows * grid_cols`` should equal
            the number of tiles; a mismatch is logged as a warning.

        overlap_ratio: Expected fractional overlap between adjacent tiles (0.0-1.0).
            Default 0.1 (10% overlap). Used for the initial grid positions and for
            the fallback grid. Typical values:
              - 0.05-0.15 for well-controlled microscopes
              - 0.15-0.25 for less precise stages

        max_shift: Maximum allowed shift correction, forwarded to the aligner.
            Default 15.0. Limits how far tiles may move from their initial grid
            positions. ``pixel_size`` is fixed at 1.0 here, so this is presumably
            in pixels in practice. Guidance by stage accuracy:
              - 5-15 for high-precision stages
              - 15-50 for standard stages
              - 50+ for low-precision or manual stages

        stitch_alpha: Alpha for the statistical error threshold (0.0-1.0), forwarded
            to the aligner as ``alpha``. Default 0.01. Lower is stricter:
              - 0.001-0.01: very strict, high-quality alignments only
              - 0.01-0.05: balanced (recommended for most data)
              - 0.05-0.1: permissive, accepts lower-quality alignments

        max_error: Explicit error threshold for rejecting alignments. ``None``
            (default) lets the aligner compute one via permutation testing.

        randomize: Whether permutation testing uses a random seed. Default False
            for reproducible results.

        verbose: Log the grid/tile summary before running; also forwarded to the
            aligner. Default False.

        upsample_factor: Sub-pixel accuracy factor for phase cross-correlation.
            Default 10. Higher is more precise but slower:
              - 1: pixel-level accuracy (fastest)
              - 10: ~0.1 pixel accuracy (balanced)
              - 50: ~0.02 pixel accuracy (high precision)

        permutation_upsample: Upsample factor used during permutation testing.
            Default 1; sub-pixel accuracy is rarely needed there.

        permutation_samples: Number of random samples for the error threshold.
            Default 1000.

        min_permutation_samples: Minimum permutation samples for small datasets.
            Default 10.

        max_permutation_tries: Maximum attempts to find non-overlapping strips
            during permutation testing. Default 100.

        window_size_factor: Fraction of the tile size used as the largest overlap
            window. Default 0.1; 0.05-0.2 is typical.

        **kwargs: Accepted for compatibility with other stitching functions and
            ignored. In particular, a ``filter_sigma`` keyword is not a parameter
            of this function and has no effect.

    Returns:
        Tuple of ``(image_stack, positions)``:
          - image_stack: the input array object itself, returned unchanged.
          - positions: list of ``(x, y)`` float tuples, one per tile in input order,
            in pixel coordinates with (0, 0) at the top-left.

    Failure handling:
        Exceptions are not propagated. Any exception raised while converting to
        CuPy, building initial positions, or running the aligner is logged with
        its traceback. The function then returns row-major grid positions spaced
        by ``(1 - overlap_ratio)`` of the tile size. The fallback uses no CuPy.

    Notes:
        No whitening or filter parameter is forwarded from this function. Any
        correlation preprocessing is whatever ``ArrayEdgeAlignerGPU`` applies
        internally.
    """
    grid_rows, grid_cols = grid_dimensions
    num_tiles = len(image_stack)

    if verbose:
        logger.info(f"Ashlar GPU: Processing {grid_rows}x{grid_cols} grid with {num_tiles} tiles")

    if grid_rows * grid_cols != num_tiles:
        logger.warning(
            f"Ashlar GPU: grid_dimensions {grid_rows}x{grid_cols} imply "
            f"{grid_rows * grid_cols} tiles but image_stack has {num_tiles}"
        )

    try:
        # Convert to CuPy array if needed
        if isinstance(image_stack, cp.ndarray):
            image_stack_gpu = image_stack
        else:
            image_stack_gpu = cp.asarray(image_stack)

        # Calculate initial grid positions
        initial_positions = _calculate_initial_positions_gpu(image_stack_gpu, grid_dimensions, overlap_ratio)
        tile_size = cp.array(image_stack_gpu.shape[1:3])  # (height, width)

        # Create and run ArrayEdgeAlignerGPU with complete Ashlar algorithm
        logger.info("Running complete Ashlar edge-based stitching algorithm on GPU")
        aligner = ArrayEdgeAlignerGPU(
            image_stack=image_stack_gpu,
            positions=initial_positions,
            tile_size=tile_size,
            pixel_size=1.0,  # Assume 1 micrometer per pixel if not specified
            max_shift=max_shift,
            alpha=stitch_alpha,
            max_error=max_error,
            randomize=randomize,
            verbose=verbose,
            upsample_factor=upsample_factor,
            permutation_upsample=permutation_upsample,
            permutation_samples=permutation_samples,
            min_permutation_samples=min_permutation_samples,
            max_permutation_tries=max_permutation_tries,
            window_size_factor=window_size_factor
        )
        aligner.run()

        # Convert Ashlar (y, x) positions to OpenHCS (x, y) format
        positions = _convert_ashlar_positions_to_openhcs_gpu(aligner.final_positions)

        logger.info("Ashlar GPU algorithm completed successfully")

    except Exception as e:
        # Deliberate best-effort: any failure degrades to plain grid positions.
        # logger.exception records the traceback so the root cause is not lost.
        logger.exception(f"Ashlar GPU algorithm failed: {e}")
        logger.warning("Falling back to grid-based positioning")

        # shape works for both NumPy and CuPy arrays, so no CuPy check is needed
        # here. That matters when the failure was CuPy itself being unavailable.
        tile_height, tile_width = image_stack.shape[1:3]
        positions = _grid_positions_fallback(num_tiles, tile_height, tile_width, grid_cols, overlap_ratio)

    logger.info(f"Ashlar GPU: Completed processing {len(positions)} tile positions")

    # The image data is never modified. Returning the caller's array as-is keeps
    # its original type and avoids a device-to-host copy for NumPy inputs.
    return image_stack, positions

1008 

1009 

def materialize_ashlar_gpu_positions(data: List[Tuple[float, float]], path: str, filemanager) -> str:
    """Materialize Ashlar GPU tile positions as a scientific CSV with grid metadata.

    Args:
        data: One ``(x, y)`` position per tile, in tile order.
        path: Path of the special output being materialized. A trailing ``.pkl``
            extension is replaced by ``_ashlar_positions_gpu.csv``. Any other path
            gets that suffix appended, so the CSV never overwrites ``path``.
        filemanager: Object exposing ``save(content, path, backend)``. The CSV text
            is saved with backend ``"disk"``.

    Returns:
        The path the CSV was written to.
    """
    csv_suffix = '_ashlar_positions_gpu.csv'
    # Rewrite only the file extension. str.replace would also rewrite a '.pkl'
    # inside a directory name, and would return `path` unchanged when it has no
    # '.pkl', which would save the CSV over the original output path.
    if path.endswith('.pkl'):
        csv_path = path[:-len('.pkl')] + csv_suffix
    else:
        csv_path = path + csv_suffix

    # The column names keep their historical `_um` suffix. The GPU position
    # function documents its output as pixel coordinates (it runs with
    # pixel_size=1.0).
    df = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
    df['tile_id'] = range(len(df))

    # Estimate grid dimensions from the distinct coordinate values.
    # NOTE(review): exact-value grouping only recovers the grid when positions lie
    # on an exact lattice (e.g. the grid fallback). Sub-pixel aligned positions
    # will usually give one distinct x per tile. Consider tolerance-based clustering.
    unique_x = sorted(df['x_position_um'].unique())
    unique_y = sorted(df['y_position_um'].unique())

    grid_cols = len(unique_x)
    grid_rows = len(unique_y)

    # Tiles are presumably in row-major order (row = index // cols), matching how
    # this module numbers tiles. max(..., 1) avoids dividing by zero for empty data.
    index_cols = max(grid_cols, 1)
    df['grid_row'] = df.index // index_cols
    df['grid_col'] = df.index % index_cols

    # Spacing = gap between the two smallest distinct coordinates (0 if only one).
    df['x_spacing_um'] = unique_x[1] - unique_x[0] if len(unique_x) > 1 else 0
    df['y_spacing_um'] = unique_y[1] - unique_y[0] if len(unique_y) > 1 else 0

    # Add metadata
    df['algorithm'] = 'ashlar_gpu'
    df['grid_dimensions'] = f"{grid_rows}x{grid_cols}"

    csv_content = df.to_csv(index=False)
    filemanager.save(csv_content, csv_path, "disk")
    return csv_path