Coverage for openhcs/processing/backends/pos_gen/ashlar_main_gpu.py: 7.8%

508 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2OpenHCS Interface for Ashlar GPU Stitching Algorithm 

3 

4Array-based EdgeAligner implementation that works directly with CuPy arrays 

5instead of file-based readers. This is the complete Ashlar algorithm modified 

6to accept arrays directly and run on GPU. 

7""" 

8from __future__ import annotations 

9import logging 

10import sys 

11from typing import TYPE_CHECKING, Tuple, List 

12import numpy as np 

13import networkx as nx 

14import scipy.spatial.distance 

15import sklearn.linear_model 

16import pandas as pd 

17 

18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs 

19from openhcs.core.memory.decorators import cupy as cupy_func 

20from openhcs.core.utils import optional_import 

21 

22# Import CuPy using the established optional import pattern 

23cp = optional_import("cupy") 

24 

25import warnings 

26 

27if TYPE_CHECKING: 27 ↛ 28line 27 didn't jump to line 28 because the condition on line 27 was never true

28 pass 

29 

30logger = logging.getLogger(__name__) 

31 

32 

class DataWarning(Warning):
    """Warning category for problems found in user-supplied image data."""

36 

37 

def warn_data(message):
    """Issue a :class:`DataWarning` about the content of the image data.

    Args:
        message: Human-readable description of the problem.
    """
    # stacklevel=2 attributes the warning to the code that called warn_data()
    # rather than to this helper, so the reported file/line identifies which
    # check actually fired.
    warnings.warn(message, DataWarning, stacklevel=2)

41 

42 

class IntersectionGPU:
    """Intersection region between two tiles, computed with CuPy.

    After construction the instance exposes:

    * ``shape``  - integer size of the intersection window (after enforcing
      the requested minimum size),
    * ``padding`` - how much the window was grown beyond the nominal overlap,
    * ``offsets`` - per-tile offset of the window inside each tile
      (row 0 for the first tile, row 1 for the second),
    * ``offset_diff_frac`` - fractional part of the difference between the
      two tiles' offsets.
    """

    def __init__(self, corners1, corners2, min_size=0):
        """Build the intersection from the tiles' corner coordinates.

        Args:
            corners1: Array with one row per tile holding the upper-left
                corner of each tile.
            corners2: Array with one row per tile holding the lower-right
                corner of each tile.
            min_size: Minimum window size; a Python scalar is applied to both
                axes, anything else is converted to a CuPy array.

        Raises:
            ImportError: If CuPy is not available.
        """
        if not cp:
            raise ImportError("CuPy is required for GPU intersection calculations")

        if isinstance(min_size, (int, float)):
            # Scalar minimum: use the same value along both axes.
            lower_bound = cp.full(2, min_size)
        elif isinstance(min_size, cp.ndarray):
            lower_bound = min_size
        else:
            lower_bound = cp.array(min_size)

        self._calculate(corners1, corners2, lower_bound)

    def _calculate(self, corners1, corners2, min_size):
        """Derive shape, padding and offsets of the intersection window."""
        # The window may never be smaller than one pixel nor larger than a tile.
        tile_extent = (corners2 - corners1).max(axis=0)
        floor_size = cp.clip(min_size, 1, tile_extent)

        # Nominal overlap: from the larger upper-left corner to the smaller
        # lower-right corner.
        top_left = corners1.max(axis=0)
        bottom_right = corners2.min(axis=0)
        nominal_shape = cp.floor(bottom_right - top_left).astype(int)

        # Grow the window to the minimum size and remember by how much.
        self.shape = cp.ceil(cp.maximum(nominal_shape, floor_size)).astype(int)
        self.padding = self.shape - nominal_shape

        # Offsets are relative to each tile's own corner and clamped at zero.
        self.offsets = cp.maximum(top_left - corners1 - self.padding, 0)

        delta = self.offsets[1] - self.offsets[0]
        self.offset_diff_frac = delta - cp.round(delta)

68 

69 

70def _get_window(shape): 

71 """Build a 2D Hann window (from Ashlar utils.get_window) on GPU.""" 

72 if cp is None: 

73 raise ImportError("CuPy is required for GPU window functions") 

74 # Build a 2D Hann window by taking the outer product of two 1-D windows. 

75 wy = cp.hanning(shape[0]).astype(cp.float32) 

76 wx = cp.hanning(shape[1]).astype(cp.float32) 

77 window = cp.outer(wy, wx) 

78 return window 

79 

80 

81# Precompute Laplacian kernel for whitening (equivalent to skimage.restoration.uft.laplacian) 

82_laplace_kernel_gpu = None 

83if cp: 83 ↛ 84line 83 didn't jump to line 84 because the condition on line 83 was never true

84 _laplace_kernel_gpu = cp.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=cp.float32) 

85 

86 

def whiten_gpu(img, sigma):
    """
    Apply a whitening (high-pass) filter to a single image on the GPU.

    The input is converted to a float32 CuPy array and then filtered:

    - ``sigma == 0``: convolution with the module-level 3x3 Laplacian kernel
      using ``mode='reflect'``.
    - otherwise: ``cupyx.scipy.ndimage.gaussian_laplace`` with the given sigma.

    Args:
        img: Image to filter; converted with ``cp.asarray`` if it is not
            already a CuPy array.
        sigma: Standard deviation for the Gaussian-Laplacian
            (0 selects the pure Laplacian kernel).

    Returns:
        CuPy array: The filtered image.
    """
    pixels = img if isinstance(img, cp.ndarray) else cp.asarray(img)
    pixels = pixels.astype(cp.float32)

    # Imported lazily so the module can be imported without cupyx installed.
    from cupyx.scipy import ndimage as gpu_ndimage

    if sigma != 0:
        # Laplacian of Gaussian.
        return gpu_ndimage.gaussian_laplace(pixels, sigma)

    # Pure Laplacian high-pass filter.
    return gpu_ndimage.convolve(pixels, _laplace_kernel_gpu, mode='reflect')

119 

120 

def whiten_gpu_vectorized(img_stack, sigma):
    """
    Whiten every image of a stack on the GPU.

    The stack is converted to a float32 CuPy array and each plane
    ``img_stack[i]`` is filtered independently, one after the other, with the
    same filter that :func:`whiten_gpu` would use for the given ``sigma``.

    Args:
        img_stack: Stack of images indexed along the first axis
            (documented by the original author as shape (N, H, W)).
        sigma: Standard deviation for the Gaussian-Laplacian
            (0 selects the pure 3x3 Laplacian kernel).

    Returns:
        CuPy array: float32 stack of filtered images with the same shape as
        the input.
    """
    stack = img_stack if isinstance(img_stack, cp.ndarray) else cp.asarray(img_stack)
    stack = stack.astype(cp.float32)

    from cupyx.scipy import ndimage as gpu_ndimage

    # Pick the per-plane filter once instead of duplicating the loop.
    if sigma == 0:
        def filter_plane(plane):
            return gpu_ndimage.convolve(plane, _laplace_kernel_gpu, mode='reflect')
    else:
        def filter_plane(plane):
            return gpu_ndimage.gaussian_laplace(plane, sigma)

    whitened = cp.empty_like(stack)
    for index in range(stack.shape[0]):
        whitened[index] = filter_plane(stack[index])
    return whitened

153 

154 

def ashlar_register_gpu(img1, img2, upsample=10):
    """
    Estimate the translation between two images with windowed phase correlation.

    Both images are converted to float32 CuPy arrays, multiplied by a 2D Hann
    window and passed to ``cucim.skimage.registration.phase_cross_correlation``.
    No whitening filter is applied.

    Args:
        img1, img2: 2D images of identical shape (NumPy or CuPy arrays).
        upsample: Upsampling factor forwarded to the phase correlation as
            ``upsample_factor``.

    Returns:
        tuple: ``(shift, error)`` where ``shift`` is a NumPy array of length 2
        and ``error`` is a Python float. When the inputs cannot be registered
        (``None``, empty, mismatching shapes, not 2D) or the correlation
        raises, the result is a zero shift with an infinite error. Every code
        path returns the shift as a host (NumPy) array.

    Raises:
        ImportError: If cuCIM is not installed and the inputs are valid.
    """
    def _no_registration():
        # Zero shift with infinite error, so callers that threshold on the
        # error discard this pair.
        return np.zeros(2, dtype=np.float64), float("inf")

    # Reject inputs that cannot be correlated before touching the GPU.
    if img1 is None or img2 is None:
        return _no_registration()
    if img1.size == 0 or img2.size == 0:
        return _no_registration()
    if img1.shape != img2.shape or len(img1.shape) != 2:
        return _no_registration()

    # Imported lazily (and outside the try block below) so that a missing
    # cuCIM installation surfaces as ImportError instead of being reported as
    # a failed correlation.
    import cucim.skimage.registration

    if not isinstance(img1, cp.ndarray):
        img1 = cp.asarray(img1)
    if not isinstance(img2, cp.ndarray):
        img2 = cp.asarray(img2)

    # The shapes are identical at this point, so one window serves both images.
    window = _get_window(img1.shape)
    img1w = img1.astype(cp.float32) * window
    img2w = img2.astype(cp.float32) * window

    try:
        shift, error, _ = cucim.skimage.registration.phase_cross_correlation(
            img1w, img2w, upsample_factor=upsample
        )

        # Bring the results to the host.
        shift = cp.asnumpy(shift)
        error = float(error)

        # Only log high errors to avoid spam
        if error > 1.0:  # High error threshold for Ashlar
            logger.warning(f"Ashlar GPU: HIGH CORRELATION ERROR - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
            logger.warning(" This indicates poor overlap or image quality between tiles")

    except Exception as e:
        # Deliberate best-effort: a failed correlation is reported and turned
        # into an "infinitely bad" registration instead of aborting the run.
        logger.error(f"Ashlar GPU: CORRELATION FAILED - Exception: {e}")
        logger.error(" Returning infinite error")
        return _no_registration()

    return shift, error

219 

220 

221 

222 

223 

def ashlar_nccw_no_preprocessing_gpu(img1, img2):
    """
    Compute the correlation-based alignment error between two images.

    Both inputs are converted to float32 CuPy arrays (no filtering is
    applied). The error is ``-log(correlation / total_amplitude)`` where
    ``correlation`` is ``|sum(img1 * img2)|`` and ``total_amplitude`` is the
    product of the two images' norms.

    Special cases:
        - correlation or amplitude not positive: the error is infinite.
        - correlation exceeds the amplitude by less than 1e-3 (rounding
          noise for near-identical images): the error is 0.
        - correlation exceeds the amplitude by more than that: a warning is
          logged and the error is set to 100.0.

    Returns:
        float: The alignment error (lower is better).
    """
    first = img1 if isinstance(img1, cp.ndarray) else cp.asarray(img1)
    second = img2 if isinstance(img2, cp.ndarray) else cp.asarray(img2)

    first = first.astype(cp.float32)
    second = second.astype(cp.float32)

    correlation = float(cp.abs(cp.sum(first * second)))
    total_amplitude = float(cp.linalg.norm(first) * cp.linalg.norm(second))

    if not (correlation > 0 and total_amplitude > 0):
        logger.warning(f"Ashlar GPU: NCCW invalid correlation - correlation={correlation:.6f}, total_amplitude={total_amplitude:.6f}")
        error = cp.inf
    else:
        diff = correlation - total_amplitude
        if diff <= 0:
            error = -cp.log(correlation / total_amplitude)
        elif diff < 1e-3:
            # Rounding can push the correlation marginally above the
            # amplitude when the images are (nearly) identical; treat that
            # as a perfect match.
            error = 0
        else:
            # Report a large but finite error instead of raising.
            logger.warning(f"Ashlar GPU: NCCW numerical precision issue - diff={diff:.6f}, using error=100.0")
            error = 100.0

    error_float = float(error)
    if error_float > 10.0:
        # High NCCW error threshold
        logger.warning(f"Ashlar GPU: HIGH NCCW ERROR - Error={error_float:.4f}")
        logger.warning(f" This indicates poor image correlation between tiles")
    else:
        logger.info(f"Ashlar GPU: NCCW - Error={error_float:.4f}")

    return error_float

268 

269 

270 

271 

272 

def ashlar_crop_gpu(img, offset, shape):
    """
    Crop a 2D window out of ``img``, clipping the window to the image bounds.

    The offset is rounded to the nearest whole pixel, so sub-pixel offsets are
    not honoured. If the requested window extends beyond the image it is
    clipped, in which case the returned crop is smaller than ``shape``.

    Args:
        img: 2D image to crop.
        offset: (row, col) position of the window's upper-left corner.
        shape: (rows, cols) size of the window.

    Returns:
        CuPy array: The cropped region of ``img``.

    Raises:
        ValueError: If ``shape`` has a non-positive entry, or if nothing is
            left of the window after clipping it to the image.
    """
    def _on_gpu(value):
        return value if isinstance(value, cp.ndarray) else cp.asarray(value)

    img = _on_gpu(img)
    offset = _on_gpu(offset)
    shape = _on_gpu(shape)

    # A zero-sized or negative window can never produce a usable crop.
    if cp.any(shape <= 0):
        raise ValueError(f"Invalid crop shape: {shape}. Shape must be positive.")

    # Whole-pixel window corners.
    start = cp.round(offset).astype(int)
    end = start + shape

    img_shape = cp.array(img.shape)
    if cp.any(start < 0) or cp.any(end > img_shape):
        # Window sticks out of the image: clip it and make sure something
        # is left.
        start = cp.maximum(start, 0)
        end = cp.minimum(end, img_shape)

        if cp.any(end - start <= 0):
            raise ValueError(f"Invalid crop region after bounds checking: start={start}, end={end}, img_shape={img_shape}")

    return img[start[0]:end[0], start[1]:end[1]]

309 

310 

311 

312 

313 

class ArrayEdgeAlignerGPU:
    """
    Array-based EdgeAligner that implements the complete Ashlar algorithm
    but works directly with CuPy arrays instead of file readers and runs on GPU.

    Construct the aligner with a stack of tiles and their nominal positions,
    then call :meth:`run`. Afterwards the results are available as:

    * ``final_positions`` - corrected tile positions, shifted so that the
      minimum along each axis is 0,
    * ``shifts`` - per-tile correction that was added to ``positions``,
    * ``spanning_tree`` - NetworkX graph of the tile pairs used for alignment,
    * ``lr`` - linear model mapping nominal to corrected positions,
    * ``origin`` - the offset that was subtracted from ``final_positions``.
    """

    def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
                 max_shift=15, alpha=0.01, max_error=None,
                 randomize=False, verbose=False, upsample_factor=10,
                 permutation_upsample=1, permutation_samples=1000,
                 min_permutation_samples=10, max_permutation_tries=100,
                 window_size_factor=0.1):
        """
        Initialize array-based EdgeAligner for position calculation on GPU.

        Args:
            image_stack: 3D numpy/cupy array (num_tiles, height, width) - preprocessed grayscale
            positions: 2D array of tile positions (num_tiles, 2) in pixels
            tile_size: Array [height, width] of tile dimensions
            pixel_size: Pixel size in micrometers (for max_shift conversion)
            max_shift: Maximum allowed shift in micrometers; divided by
                pixel_size to obtain the limit in pixels
            alpha: Alpha value for error threshold (lower = stricter); the
                threshold is the ``alpha * 100`` percentile of the errors
                measured on unrelated image strips
            max_error: Explicit error threshold (None = auto-compute)
            randomize: If False the permutation test uses a fixed random seed
                (0), otherwise an unseeded random state
            verbose: Print progress and failure details to stdout
            upsample_factor: Upsampling factor used when registering
                neighboring tiles
            permutation_upsample: Upsampling factor used when registering the
                random strips of the permutation test
            permutation_samples: Number of random strip pairs to evaluate when
                there are more than 8 non-neighboring tile pairs
            min_permutation_samples: For smaller datasets the number of strip
                pairs is ``(non-neighboring pairs + 1) * min_permutation_samples``
            max_permutation_tries: Maximum attempts to draw a valid
                (non-overlapping) strip pair for each sample
            window_size_factor: Fraction of the tile size that the
                registration window is doubled up to in :meth:`register_pair`
        """
        # Convert to CuPy arrays for GPU processing
        if not isinstance(image_stack, cp.ndarray):
            self.image_stack = cp.asarray(image_stack)
        else:
            self.image_stack = image_stack

        # Positions are always stored as float64 (astype also copies the
        # caller's array, so later in-place updates do not touch it).
        if not isinstance(positions, cp.ndarray):
            self.positions = cp.asarray(positions, dtype=cp.float64)
        else:
            self.positions = positions.astype(cp.float64)

        self.tile_size = cp.array(tile_size)
        self.pixel_size = pixel_size
        self.max_shift = max_shift
        self.max_shift_pixels = self.max_shift / self.pixel_size
        self.alpha = alpha
        self.max_error = max_error
        self.randomize = randomize
        self.verbose = verbose
        self.upsample_factor = upsample_factor
        self.permutation_upsample = permutation_upsample
        self.permutation_samples = permutation_samples
        self.min_permutation_samples = min_permutation_samples
        self.max_permutation_tries = max_permutation_tries
        self.window_size_factor = window_size_factor
        # Maps a sorted tile pair (t1, t2) to its (shift, error) result.
        self._cache = {}
        self.errors_negative_sampled = cp.empty(0)

        # Build neighbors graph (this uses CPU operations with NetworkX)
        self.neighbors_graph = self._build_neighbors_graph()

    def _build_neighbors_graph(self):
        """Build graph of neighboring (overlapping) tiles.

        Two tiles are neighbors when the city-block distance between their
        positions is positive and smaller than the largest tile dimension
        plus one. Every tile is a node, even if it has no neighbors.
        """
        # Convert to CPU for scipy operations
        positions_cpu = cp.asnumpy(self.positions)
        tile_size_cpu = cp.asnumpy(self.tile_size)

        pdist = scipy.spatial.distance.pdist(positions_cpu, metric='cityblock')
        sp = scipy.spatial.distance.squareform(pdist)
        max_distance = tile_size_cpu.max() + 1
        edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
        graph = nx.from_edgelist(edges)
        # Make sure isolated tiles are present as nodes too.
        graph.add_nodes_from(range(len(positions_cpu)))
        return graph


    def run(self):
        """Run the complete Ashlar algorithm.

        The steps depend on each other and must run in this order.
        """
        self.check_overlaps()
        self.compute_threshold()
        self.register_all()
        self.build_spanning_tree()
        self.calculate_positions()
        self.fit_model()

    def check_overlaps(self):
        """Check if tiles actually overlap based on positions.

        Issues a DataWarning when all, or some, neighboring pairs overlap by
        less than one pixel along either axis. Does not modify any state.
        """
        overlaps = []
        for t1, t2 in self.neighbors_graph.edges:
            overlap = self.tile_size - cp.abs(self.positions[t1] - self.positions[t2])
            overlaps.append(overlap)

        if overlaps:
            overlaps = cp.stack(overlaps)
            # A pair fails when the overlap is below one pixel on any axis.
            failures = cp.any(overlaps < 1, axis=1)
            failures_cpu = cp.asnumpy(failures)

            if len(failures_cpu) and all(failures_cpu):
                warn_data("No tiles overlap, attempting alignment anyway.")
            elif any(failures_cpu):
                warn_data("Some neighboring tiles have zero overlap.")

    def compute_threshold(self):
        """Compute error threshold using permutation testing.

        Random pairs of image strips that should not match are registered and
        ``max_error`` is set to the ``alpha * 100`` percentile of the resulting
        errors. Nothing is computed when ``max_error`` was given explicitly,
        and ``max_error`` becomes infinite when there is at most one
        neighboring pair.
        """
        if self.max_error is not None:
            if self.verbose:
                print(" using explicit error threshold")
            return

        edges = self.neighbors_graph.edges
        num_tiles = len(self.image_stack)

        # If not enough tiles overlap to matter, skip this whole thing
        if len(edges) <= 1:
            self.max_error = np.inf
            return

        # Strip height w: the largest, over all neighbor pairs, of the
        # smaller side of the pair's intersection.
        widths = []
        for t1, t2 in edges:
            shape = self.intersection(t1, t2).shape
            widths.append(cp.min(cp.array(shape)))

        widths = cp.array(widths)
        w = int(cp.max(widths))
        # Largest row offset at which a strip of height w can start.
        # NOTE(review): randint(max_offset) below raises ValueError when this
        # is not positive - confirm callers never hit that case.
        max_offset = int(self.tile_size[0]) - w

        # Number of possible pairs minus number of actual neighbor pairs
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)

        # Reduce permutation count for small datasets
        n = self.permutation_samples if num_distant_pairs > 8 else (num_distant_pairs + 1) * self.min_permutation_samples
        pairs = np.empty((n, 2), dtype=int)  # Keep on CPU for random generation
        offsets = np.empty((n, 2), dtype=int)  # Keep on CPU for random generation

        # Generate n random non-overlapping image strips
        max_tries = self.max_permutation_tries
        if self.randomize is False:
            # Fixed seed for reproducible thresholds.
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()

        for i in range(n):
            # Limit tries to avoid infinite loop in pathological cases
            for current_try in range(max_tries):
                t1, t2 = random_state.randint(num_tiles, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)

                # Check for non-overlapping strips and abort the retry loop
                if t1 != t2 and (t1, t2) not in edges:
                    # Different, non-neighboring tiles -- always OK
                    break
                elif t1 == t2 and abs(o1 - o2) > w:
                    # Same tile OK if strips don't overlap within the image
                    break
                elif (t1, t2) in edges:
                    # Neighbors OK if either strip is entirely outside the
                    # expected overlap region (based on nominal positions)
                    its = self.intersection(t1, t2, cp.full(2, w))
                    ioff1, ioff2 = its.offsets[:, 0]
                    if (
                        its.shape[0] > its.shape[1]
                        or o1 < ioff1 - w or o1 > ioff1 + w
                        or o2 < ioff2 - w or o2 > ioff2 + w
                    ):
                        break
            else:
                # Retries exhausted. This should be very rare.
                # The last (possibly overlapping) draw is still used below.
                warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
            pairs[i] = t1, t2
            offsets[i] = o1, o2

        # Register every sampled strip pair; only the error is kept.
        errors = cp.empty(n)
        for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
            if self.verbose and (i % 10 == 9 or i == n - 1):
                sys.stdout.write(f'\r quantifying alignment error {i + 1}/{n}')
                sys.stdout.flush()
            # Full-width horizontal strips of height w.
            img1 = self.image_stack[t1][offset1:offset1+w, :]
            img2 = self.image_stack[t2][offset2:offset2+w, :]
            _, errors[i] = ashlar_register_gpu(img1, img2, upsample=self.permutation_upsample)
        if self.verbose:
            print()
        self.errors_negative_sampled = errors
        self.max_error = float(cp.percentile(errors, self.alpha * 100))


    def register_all(self):
        """Register all neighboring tile pairs.

        Fills the result cache for every edge of the neighbors graph, stores
        all raw errors in ``all_errors`` and then marks every pair whose error
        exceeds ``max_error`` or whose shift exceeds ``max_shift_pixels`` as
        unusable by setting its cached error to infinity.
        """
        n = self.neighbors_graph.size()
        for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
            if self.verbose:
                sys.stdout.write(f'\r aligning edge {i}/{n}')
                sys.stdout.flush()
            self.register_pair(t1, t2)
        if self.verbose:
            print()
        # Snapshot of the errors before thresholding.
        self.all_errors = cp.array([x[1] for x in self._cache.values()])

        # Set error values above the threshold to infinity
        for k, v in self._cache.items():
            shift_array = cp.array(v[0]) if not isinstance(v[0], cp.ndarray) else v[0]
            if v[1] > self.max_error or cp.any(cp.abs(shift_array) > self.max_shift_pixels):
                self._cache[k] = (v[0], cp.inf)

    def register_pair(self, t1, t2):
        """Return relative shift between images and the alignment error.

        Results are cached per unordered tile pair. The returned shift is
        negated when ``t1 > t2`` so that it always describes the requested
        direction. Failures never raise: they are stored as a zero shift with
        an infinite error.

        NOTE(review): the cached shift is a NumPy array when registration
        succeeded and a CuPy array on the failure paths - confirm callers
        handle both.
        """
        key = tuple(sorted((t1, t2)))
        try:
            shift, error = self._cache[key]
        except KeyError:
            # Test a series of increasing overlap window sizes to help avoid
            # missing alignments when the stage position error is large relative
            # to the tile overlap. Simply using a large overlap in all cases
            # limits the maximum achievable correlation thus increasing the
            # error metric, leading to worse overall results. The window size
            # starts at the nominal size and doubles until it reaches
            # window_size_factor times the tile size (10% by default). If the
            # nominal overlap is already that large, we only use that one size.
            try:
                smin = self.intersection(key[0], key[1]).shape
                smax = cp.round(self.tile_size * self.window_size_factor)
                sizes = [smin]
                while any(cp.array(sizes[-1]) < smax):
                    sizes.append(cp.array(sizes[-1]) * 2)

                # Try each window size and collect results
                results = []
                for s in sizes:
                    try:
                        result = self._register(key[0], key[1], s)
                        if result is not None:
                            results.append(result)
                    except Exception as e:
                        if self.verbose:
                            print(f" window size {s} failed: {e}")
                        continue

                if not results:
                    # All window sizes failed, return large error
                    shift = cp.array([0.0, 0.0])
                    error = cp.inf
                else:
                    # Use the shift from the window size that gave the lowest error
                    shift, _ = min(results, key=lambda r: r[1])
                    # Extract the images from the nominal overlap window but with the
                    # shift applied to the second tile's position, and compute the error
                    # metric on these images. This should be even lower than the error
                    # computed above.
                    try:
                        _, o1, o2 = self.overlap(key[0], key[1], shift=shift)
                        error = ashlar_nccw_no_preprocessing_gpu(o1, o2)
                    except Exception as e:
                        if self.verbose:
                            print(f" final error computation failed: {e}")
                        error = cp.inf

            except Exception as e:
                if self.verbose:
                    print(f" registration failed for tiles {key}: {e}")
                shift = cp.array([0.0, 0.0])
                error = cp.inf

            self._cache[key] = (shift, error)
        # The cache stores the shift for the sorted pair; flip it when the
        # caller asked for the opposite direction.
        if t1 > t2:
            shift = -shift
        # Return copy of shift to prevent corruption of cached values
        return shift.copy(), error

    def _register(self, t1, t2, min_size=0):
        """Register a single tile pair with given minimum size.

        Returns:
            ``(shift, error)`` with the shift as a NumPy array, or ``None``
            when the overlap is empty or anything raised along the way.
        """
        try:
            its, img1, img2 = self.overlap(t1, t2, min_size)

            # Validate that we got valid images
            if img1.size == 0 or img2.size == 0:
                if self.verbose:
                    print(f" empty images for tiles {t1}, {t2} with min_size {min_size}")
                return None

            # Account for padding, flipping the sign depending on the direction
            # between the tiles
            p1, p2 = self.positions[[t1, t2]]
            sx = 1 if p1[1] >= p2[1] else -1
            sy = 1 if p1[0] >= p2[0] else -1
            padding = cp.array(its.padding) * cp.array([sy, sx])
            shift, error = ashlar_register_gpu(img1, img2, upsample=self.upsample_factor)
            shift = cp.array(shift) + padding
            # .get() copies the shift back to host memory.
            return shift.get(), error
        except Exception as e:
            if self.verbose:
                print(f" _register failed for tiles {t1}, {t2}: {e}")
            return None


    def intersection(self, t1, t2, min_size=0, shift=None):
        """Calculate intersection region between two tiles.

        Args:
            t1, t2: Tile indices.
            min_size: Minimum size of the intersection window.
            shift: Optional offset added to the second tile's position before
                intersecting.

        Returns:
            IntersectionGPU: The intersection of the two tiles.
        """
        # Fancy indexing plus copy(): the stored positions are never modified.
        corners1 = self.positions[[t1, t2]].copy()
        if shift is not None:
            if not isinstance(shift, cp.ndarray):
                shift = cp.array(shift)
            corners1[1] += shift
        corners2 = corners1 + self.tile_size
        return IntersectionGPU(corners1, corners2, min_size)

    def crop(self, tile_id, offset, shape):
        """Crop image from tile at given offset and shape."""
        img = self.image_stack[tile_id]
        return ashlar_crop_gpu(img, offset, shape)

    def overlap(self, t1, t2, min_size=0, shift=None):
        """Extract overlapping regions between two tiles.

        Returns:
            tuple: ``(intersection, img1, img2)`` where the images are the
            crops of tile ``t1`` and ``t2`` covering the intersection.

        Raises:
            ValueError: If the intersection has a non-positive dimension (or
                if cropping fails).
        """
        its = self.intersection(t1, t2, min_size, shift)

        # Validate intersection shape before cropping
        if cp.any(its.shape <= 0):
            raise ValueError(f"Invalid intersection shape {its.shape} for tiles {t1}, {t2}")

        img1 = self.crop(t1, its.offsets[0], its.shape)
        img2 = self.crop(t2, its.offsets[1], its.shape)
        return its, img1, img2

    def build_spanning_tree(self):
        """Build minimum spanning tree using GPU Boruvka algorithm.

        Only cached pairs with a finite error are considered. The result is
        stored in ``spanning_tree`` as a NetworkX graph containing every tile
        as a node. If the Boruvka implementation raises, the tree is built
        with NetworkX instead (shortest paths from the center of each
        connected component).
        """
        # Import the Boruvka MST implementation
        from openhcs.processing.backends.pos_gen.mist.boruvka_mst import build_mst_gpu_boruvka

        # Convert cache to Boruvka format
        valid_edges = [(t1, t2, shift, error) for (t1, t2), (shift, error) in self._cache.items() if cp.isfinite(error)]

        if len(valid_edges) == 0:
            # No valid edges - create empty graph with all nodes
            self.spanning_tree = nx.Graph()
            self.spanning_tree.add_nodes_from(range(len(self.positions)))
            return

        # Prepare arrays for Boruvka MST
        connection_from = cp.array([t1 for t1, t2, shift, error in valid_edges], dtype=cp.int32)
        connection_to = cp.array([t2 for t1, t2, shift, error in valid_edges], dtype=cp.int32)
        connection_dx = cp.array([shift[1] for t1, t2, shift, error in valid_edges], dtype=cp.float32)  # x shift
        connection_dy = cp.array([shift[0] for t1, t2, shift, error in valid_edges], dtype=cp.float32)  # y shift
        # Use negative error as quality (higher quality = lower error)
        connection_quality = cp.array([-error for t1, t2, shift, error in valid_edges], dtype=cp.float32)

        num_nodes = len(self.positions)

        try:
            # Run GPU Boruvka MST
            # NOTE(review): assumes the result is a dict with an 'edges' list
            # of dicts keyed 'from'/'to'/'quality' - confirm against
            # build_mst_gpu_boruvka.
            mst_result = build_mst_gpu_boruvka(
                connection_from, connection_to, connection_dx, connection_dy,
                connection_quality, num_nodes
            )

            # Convert back to NetworkX format for compatibility with rest of algorithm
            self.spanning_tree = nx.Graph()
            self.spanning_tree.add_nodes_from(range(num_nodes))

            for edge in mst_result['edges']:
                t1, t2 = edge['from'], edge['to']
                # Reconstruct error from quality
                error = -edge['quality'] if 'quality' in edge else 0.0
                self.spanning_tree.add_edge(t1, t2, weight=error)

        except Exception as e:
            # Fallback to NetworkX if Boruvka fails
            print(f"Boruvka MST failed, falling back to NetworkX: {e}")
            g = nx.Graph()
            g.add_nodes_from(self.neighbors_graph)
            g.add_weighted_edges_from(
                (t1, t2, error)
                for (t1, t2), (_, error) in self._cache.items()
                if cp.isfinite(error)
            )
            spanning_tree = nx.Graph()
            spanning_tree.add_nodes_from(g)
            # Per connected component: union of the shortest paths from the
            # component's center to every other node.
            for c in nx.connected_components(g):
                cc = g.subgraph(c)
                center = nx.center(cc)[0]
                paths = nx.single_source_dijkstra_path(cc, center).values()
                for path in paths:
                    nx.add_path(spanning_tree, path)
            self.spanning_tree = spanning_tree

    def calculate_positions(self):
        """Calculate final positions from spanning tree.

        Within each connected component the center tile keeps its position
        (zero shift) and the pairwise shifts are accumulated outwards along
        the tree. Sets ``shifts`` and ``final_positions``.

        Raises:
            NotImplementedError: If there are no tiles at all.
        """
        shifts = {}
        for c in nx.connected_components(self.spanning_tree):
            cc = self.spanning_tree.subgraph(c)
            center = nx.center(cc)[0]
            shifts[center] = cp.array([0, 0])
            for edge in nx.traversal.bfs_edges(cc, center):
                source, dest = edge
                # Make sure 'source' is the end whose shift is already known.
                if source not in shifts:
                    source, dest = dest, source
                # Served from the cache filled by register_all().
                shift = self.register_pair(source, dest)[0]
                shifts[dest] = shifts[source] + cp.array(shift)
        if shifts:
            # Order the per-tile shifts by tile index.
            self.shifts = cp.array([s for _, s in sorted(shifts.items())])
            self.final_positions = self.positions + self.shifts
        else:
            # TODO: fill in shifts and positions with 0x2 arrays
            raise NotImplementedError("No images")


    def fit_model(self):
        """Fit linear model to handle disconnected components.

        A linear regression from nominal to corrected positions is fitted on
        the largest connected component of the spanning tree. Every other
        component is translated so that its centroid matches the model's
        prediction. Finally all positions are shifted so that the minimum
        along each axis is 0; the applied offset is stored in ``origin``.
        """
        components = sorted(
            nx.connected_components(self.spanning_tree),
            key=len, reverse=True
        )
        # Fit LR model on positions of largest connected component
        cc0 = list(components[0])
        self.lr = sklearn.linear_model.LinearRegression()

        # Convert to CPU for sklearn operations
        positions_cpu = cp.asnumpy(self.positions[cc0])
        final_positions_cpu = cp.asnumpy(self.final_positions[cc0])
        self.lr.fit(positions_cpu, final_positions_cpu)

        # Fix up degenerate transform matrix. This happens when the spanning
        # tree is completely edgeless or cc0's metadata positions fall in a
        # straight line. In this case we fall back to the identity transform.
        if np.linalg.det(self.lr.coef_) < 1e-3:
            warn_data(
                "Could not align enough edges, proceeding anyway with original"
                " stage positions."
            )
            self.lr.coef_ = np.diag(np.ones(2))
            self.lr.intercept_ = np.zeros(2)

        # Adjust position of remaining components so their centroids match
        # the predictions of the model
        for cc in components[1:]:
            nodes = list(cc)
            centroid_m = cp.mean(self.positions[nodes], axis=0)
            centroid_f = cp.mean(self.final_positions[nodes], axis=0)

            # Convert to CPU for prediction, then back to GPU
            centroid_m_cpu = cp.asnumpy(centroid_m).reshape(1, -1)
            shift_cpu = self.lr.predict(centroid_m_cpu)[0] - cp.asnumpy(centroid_f)
            shift = cp.array(shift_cpu)

            self.final_positions[nodes] += shift

        # Adjust positions and model intercept to put origin at 0,0
        self.origin = cp.min(self.final_positions, axis=0)
        self.final_positions -= self.origin
        self.lr.intercept_ -= cp.asnumpy(self.origin)

760 

761 

762def _calculate_initial_positions_gpu(image_stack, grid_dims: tuple, overlap_ratio: float): 

763 """Calculate initial grid positions based on overlap ratio (GPU version).""" 

764 grid_rows, grid_cols = grid_dims 

765 

766 # Handle both numpy and cupy arrays 

767 if isinstance(image_stack, cp.ndarray): 

768 tile_height, tile_width = image_stack.shape[1:3] 

769 else: 

770 tile_height, tile_width = image_stack.shape[1:3] 

771 

772 spacing_factor = 1.0 - overlap_ratio 

773 

774 positions = [] 

775 for tile_idx in range(len(image_stack)): 

776 r = tile_idx // grid_cols 

777 c = tile_idx % grid_cols 

778 

779 y_pos = r * tile_height * spacing_factor 

780 x_pos = c * tile_width * spacing_factor 

781 positions.append([y_pos, x_pos]) 

782 

783 return cp.array(positions, dtype=cp.float64) 

784 

785 

786def _convert_ashlar_positions_to_openhcs_gpu(ashlar_positions) -> List[Tuple[float, float]]: 

787 """Convert Ashlar positions to OpenHCS format (GPU version).""" 

788 # Convert to CPU if needed 

789 if isinstance(ashlar_positions, cp.ndarray): 

790 ashlar_positions = cp.asnumpy(ashlar_positions) 

791 

792 positions = [] 

793 for tile_idx in range(len(ashlar_positions)): 

794 y, x = ashlar_positions[tile_idx] 

795 positions.append((float(x), float(y))) # OpenHCS uses (x, y) format 

796 return positions 

797 

798 

@special_inputs("grid_dimensions")
@special_outputs("positions")
@cupy_func
def ashlar_compute_tile_positions_gpu(
    image_stack,
    grid_dimensions: Tuple[int, int],
    overlap_ratio: float = 0.1,
    max_shift: float = 15.0,
    stitch_alpha: float = 0.01,
    max_error: float = None,
    randomize: bool = False,
    verbose: bool = False,
    upsample_factor: int = 10,
    permutation_upsample: int = 1,
    permutation_samples: int = 1000,
    min_permutation_samples: int = 10,
    max_permutation_tries: int = 100,
    window_size_factor: float = 0.1,
    **kwargs
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Compute tile positions using the Ashlar algorithm on GPU.

    The stack is moved to the GPU if necessary, a regular grid of initial
    positions is derived from ``grid_dimensions`` and ``overlap_ratio``, and
    ``ArrayEdgeAlignerGPU`` refines those positions. If anything in that path
    raises, the error is logged with its traceback and plain grid positions
    are returned instead.

    NOTE(review): earlier versions of this docstring stated both that no
    whitening is applied and that a whitening filter is applied during
    correlation. Preprocessing is delegated to ``ArrayEdgeAlignerGPU``;
    confirm the actual behaviour there.

    Args:
        image_stack: 3D numpy/cupy array of shape (num_tiles, height, width) containing
                    preprocessed grayscale images. Each slice [i] should be a single-channel
                    2D image ready for correlation analysis.

        grid_dimensions: Tuple of (grid_rows, grid_cols) specifying the logical arrangement of
                        tiles. For example, (2, 3) means 2 rows and 3 columns of tiles, for a
                        total of 6 tiles. Tiles are assumed to be in row-major order. Should
                        match the number of images in image_stack; a mismatch is logged as a
                        warning.

        overlap_ratio: Expected fractional overlap between adjacent tiles (0.0-1.0). Default 0.1
                      means 10% overlap. This is used to calculate initial grid positions (and
                      the fallback positions) and should match the actual overlap in your
                      microscopy data. Typical values:
                      - 0.05-0.15 for well-controlled microscopes
                      - 0.15-0.25 for less precise stages

        max_shift: Maximum allowed shift correction in micrometers. Default 15.0. This limits
                  how far tiles can be moved from their initial grid positions during alignment.
                  The aligner is constructed with pixel_size=1.0, i.e. one micrometer per pixel
                  is assumed. Should be set based on your microscope's stage accuracy:
                  - 5-15 μm for high-precision stages
                  - 15-50 μm for standard stages
                  - 50+ μm for low-precision or manual stages

        stitch_alpha: Alpha value for statistical error threshold computation (0.0-1.0). Default
                     0.01 means 1% false positive rate. Lower values are stricter and reject more
                     alignments, higher values are more permissive:
                     - 0.001-0.01: Very strict, high quality alignments only
                     - 0.01-0.05: Balanced (recommended for most data)
                     - 0.05-0.1: Permissive, accepts lower quality alignments

        max_error: Explicit error threshold for rejecting alignments (None = auto-compute).
                  When None (default), the threshold is computed automatically using permutation
                  testing. Set to a specific value to override automatic computation. Higher
                  values accept more alignments, lower values are stricter.

        randomize: Whether to use random seed for permutation testing (bool). Default False uses
                  a fixed seed for reproducible results.

        verbose: Enable detailed progress logging (bool). Default False. Also forwarded to the
                aligner.

        upsample_factor: Sub-pixel accuracy factor for phase cross correlation (int). Default 10.
                        Higher values provide better sub-pixel accuracy but increase computation
                        time:
                        - 1: Pixel-level accuracy (fastest)
                        - 10: 0.1 pixel accuracy (balanced)
                        - 50: 0.02 pixel accuracy (high precision)

        permutation_upsample: Upsample factor for permutation testing (int). Default 1.
                             Usually kept at 1 since permutation testing doesn't need sub-pixel
                             accuracy.

        permutation_samples: Number of random samples for error threshold computation (int).
                            Default 1000. Higher values give more accurate thresholds but slower
                            computation.

        min_permutation_samples: Minimum permutation samples for small datasets (int). Default 10.

        max_permutation_tries: Maximum attempts to find non-overlapping strips (int). Default 100.

        window_size_factor: Fraction of tile size for maximum window size (float). Default 0.1.
                           Controls the largest overlap window tested during progressive sizing.
                           Range: 0.05-0.2 typical.

        **kwargs: Additional parameters (ignored). Allows compatibility with other stitching
                 algorithms that may have different parameter sets. In particular
                 ``filter_sigma`` is NOT a parameter of this function: if supplied it is
                 absorbed here and is not forwarded to the aligner.

    Returns:
        Tuple of (image_stack, positions) where:
        - image_stack: The input image data, unmodified. A CuPy input is returned as the same
                      CuPy array; a host input is returned as a NumPy array without a
                      device-to-host copy.
        - positions: List of (x, y) position tuples in OpenHCS format, one per tile.
                    Positions are in pixel coordinates with (0, 0) at the top-left.

    Raises:
        Nothing from the alignment itself: any ``Exception`` raised while converting the stack
        or running the aligner is caught, logged, and replaced by grid-based positioning.
        Errors in the fallback itself (for example ``grid_cols == 0`` or an ``image_stack``
        without ``.shape``) propagate to the caller.
    """
    grid_rows, grid_cols = grid_dimensions
    num_tiles = len(image_stack)

    if verbose:
        logger.info(f"Ashlar GPU: Processing {grid_rows}x{grid_cols} grid with {len(image_stack)} tiles")

    # The docstring contract requires rows * cols == tile count. Warn instead
    # of raising so existing callers keep working.
    if grid_rows * grid_cols != num_tiles:
        logger.warning(
            f"Ashlar GPU: grid {grid_rows}x{grid_cols} does not match {num_tiles} tiles"
        )

    try:
        # Convert to CuPy array if needed
        input_on_gpu = isinstance(image_stack, cp.ndarray)
        image_stack_gpu = image_stack if input_on_gpu else cp.asarray(image_stack)

        # Calculate initial grid positions
        initial_positions = _calculate_initial_positions_gpu(image_stack_gpu, grid_dimensions, overlap_ratio)
        tile_size = cp.array(image_stack_gpu.shape[1:3])  # (height, width)

        # Create and run ArrayEdgeAlignerGPU with complete Ashlar algorithm
        logger.info("Running complete Ashlar edge-based stitching algorithm on GPU")
        aligner = ArrayEdgeAlignerGPU(
            image_stack=image_stack_gpu,
            positions=initial_positions,
            tile_size=tile_size,
            pixel_size=1.0,  # Assume 1 micrometer per pixel if not specified
            max_shift=max_shift,
            alpha=stitch_alpha,
            max_error=max_error,
            randomize=randomize,
            verbose=verbose,
            upsample_factor=upsample_factor,
            permutation_upsample=permutation_upsample,
            permutation_samples=permutation_samples,
            min_permutation_samples=min_permutation_samples,
            max_permutation_tries=max_permutation_tries,
            window_size_factor=window_size_factor,
        )

        # Run the complete algorithm
        aligner.run()

        # Convert to OpenHCS format
        positions = _convert_ashlar_positions_to_openhcs_gpu(aligner.final_positions)

        # The stack is never modified, so a host input is handed back directly
        # rather than being copied device -> host again. np.asarray keeps the
        # "NumPy array out" behaviour for non-ndarray host inputs.
        result_image_stack = image_stack_gpu if input_on_gpu else np.asarray(image_stack)

        logger.info("Ashlar GPU algorithm completed successfully")

    except Exception as e:
        # Deliberate best-effort: keep the pipeline going with grid positions,
        # but record the traceback so the real failure can be diagnosed.
        logger.exception("Ashlar GPU algorithm failed: %s", e)
        logger.warning("Falling back to grid-based positioning")

        # The fallback must not touch `cp`: CuPy being unavailable or broken
        # may be the very reason we ended up here.
        tile_height, tile_width = image_stack.shape[1:3]
        spacing_factor = 1.0 - overlap_ratio

        # Row-major layout, emitted as (x, y) to match the OpenHCS format.
        positions = [
            (
                float((tile_idx % grid_cols) * tile_width * spacing_factor),
                float((tile_idx // grid_cols) * tile_height * spacing_factor),
            )
            for tile_idx in range(num_tiles)
        ]

        result_image_stack = image_stack

    logger.info(f"Ashlar GPU: Completed processing {len(positions)} tile positions")

    return result_image_stack, positions

1009 

1010 

def _cluster_axis_positions(values) -> List[float]:
    """Collapse nearly-equal coordinates along one axis into cluster centres.

    Sorted coordinates are split wherever the gap to the previous value is
    larger than half of the largest gap; each group is replaced by its mean.
    On an exact, evenly spaced grid every value stays its own cluster, so the
    result equals the sorted unique values.

    Limitation: this is a heuristic. A single row/column whose only variation
    is jitter, or a grid with strongly uneven spacing, can still be miscounted.
    """
    ordered = sorted(values)
    if len(ordered) < 2:
        return [float(v) for v in ordered]

    gaps = [nxt - prev for prev, nxt in zip(ordered, ordered[1:])]
    threshold = max(gaps) / 2.0

    clusters = [[ordered[0]]]
    for value, gap in zip(ordered[1:], gaps):
        if gap > threshold:
            clusters.append([value])
        else:
            clusters[-1].append(value)

    return [float(sum(cluster) / len(cluster)) for cluster in clusters]


def materialize_ashlar_gpu_positions(data: List[Tuple[float, float]], path: str, filemanager) -> str:
    """Materialize Ashlar GPU tile positions as scientific CSV with grid metadata.

    Args:
        data: One ``(x, y)`` position per tile, in tile order.
        path: Target path; every ``'.pkl'`` occurrence is replaced by
            ``'_ashlar_positions_gpu.csv'`` to form the CSV path.
        filemanager: Object exposing ``save(content, path, backend)``; the CSV
            text is saved through it with backend ``"disk"``.

    Returns:
        The path the CSV was saved to.

    The grid size is estimated from the positions themselves. Aligned positions
    are sub-pixel refined, so tiles of one column rarely share the exact same x
    (likewise rows and y); coordinates are therefore clustered by gap instead
    of being counted as unique values. ``grid_row``/``grid_col`` assume tiles
    are listed in row-major order.

    NOTE(review): columns are labelled ``_um`` but the values are written
    through unchanged; confirm the unit of ``data`` with the producer.
    """
    csv_path = path.replace('.pkl', '_ashlar_positions_gpu.csv')

    df = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
    df['tile_id'] = range(len(df))

    # Estimate grid dimensions from position layout (jitter-tolerant).
    x_centers = _cluster_axis_positions(df['x_position_um'].unique())
    y_centers = _cluster_axis_positions(df['y_position_um'].unique())

    grid_cols = len(x_centers)
    grid_rows = len(y_centers)

    # Add grid coordinates; guard the divisor so empty input cannot divide by zero.
    cols_divisor = max(grid_cols, 1)
    df['grid_row'] = df.index // cols_divisor
    df['grid_col'] = df.index % cols_divisor

    # Add spacing information (distance between the first two cluster centres).
    df['x_spacing_um'] = x_centers[1] - x_centers[0] if grid_cols > 1 else 0
    df['y_spacing_um'] = y_centers[1] - y_centers[0] if grid_rows > 1 else 0

    # Add metadata
    df['algorithm'] = 'ashlar_gpu'
    df['grid_dimensions'] = f"{grid_rows}x{grid_cols}"

    csv_content = df.to_csv(index=False)
    filemanager.save(csv_content, csv_path, "disk")
    return csv_path