Coverage for openhcs/processing/backends/pos_gen/ashlar_main_gpu.py: 7.8%
508 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2OpenHCS Interface for Ashlar GPU Stitching Algorithm
4Array-based EdgeAligner implementation that works directly with CuPy arrays
5instead of file-based readers. This is the complete Ashlar algorithm modified
6to accept arrays directly and run on GPU.
7"""
8from __future__ import annotations
9import logging
10import sys
11from typing import TYPE_CHECKING, Tuple, List
12import numpy as np
13import networkx as nx
14import scipy.spatial.distance
15import sklearn.linear_model
16import pandas as pd
18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs
19from openhcs.core.memory.decorators import cupy as cupy_func
20from openhcs.core.utils import optional_import
22# Import CuPy using the established optional import pattern
23cp = optional_import("cupy")
25import warnings
if TYPE_CHECKING:  # pragma: no cover
    # Reserved for imports needed only by static type checkers; none yet.
    pass

# Module-level logger, named after this module so callers can tune its level.
logger: logging.Logger = logging.getLogger(name=__name__)
class DataWarning(Warning):
    """Warning category for problems detected in user-supplied image data."""
def warn_data(message, stacklevel=2):
    """
    Issue a DataWarning about image data.

    Args:
        message: Human-readable description of the data problem.
        stacklevel: Forwarded to ``warnings.warn``. The default of 2 attributes
            the warning to the code that called ``warn_data`` rather than to
            this helper, so the reported file/line points at the real source
            of the problem.
    """
    warnings.warn(message, DataWarning, stacklevel=stacklevel)
class IntersectionGPU:
    """
    Intersection region between two tiles, following Ashlar's Intersection
    class but computed with CuPy.

    After construction the instance exposes:
        shape: integer size of the (possibly padded) intersection window.
        padding: how much ``shape`` was enlarged beyond the raw overlap.
        offsets: per-tile offset of the window inside each tile.
        offset_diff_frac: fractional part of the offset difference between
            the two tiles.
    """

    def __init__(self, corners1, corners2, min_size=0):
        if not cp:
            raise ImportError("CuPy is required for GPU intersection calculations")
        self._calculate(corners1, corners2, self._as_size_vector(min_size))

    @staticmethod
    def _as_size_vector(min_size):
        """Normalize ``min_size`` to a CuPy array (a scalar becomes a 2-vector)."""
        if isinstance(min_size, (int, float)):
            return cp.full(2, min_size)
        if isinstance(min_size, cp.ndarray):
            return min_size
        return cp.array(min_size)

    def _calculate(self, corners1, corners2, min_size):
        """Compute shape/padding/offsets exactly as Ashlar's Intersection does."""
        # NOTE(review): corners are presumably (2, ndim) arrays with one row per
        # tile (upper-left in corners1, lower-right in corners2) -- confirm at callers.
        tile_extent = (corners2 - corners1).max(axis=0)
        # The requested minimum window is kept between 1 pixel and the tile extent.
        min_size = cp.clip(min_size, 1, tile_extent)
        origin = corners1.max(axis=0)
        raw_shape = cp.floor(corners2.min(axis=0) - origin).astype(int)
        self.shape = cp.ceil(cp.maximum(raw_shape, min_size)).astype(int)
        self.padding = self.shape - raw_shape
        self.offsets = cp.maximum(origin - corners1 - self.padding, 0)
        offset_delta = self.offsets[1] - self.offsets[0]
        self.offset_diff_frac = offset_delta - cp.round(offset_delta)
def _get_window(shape):
    """
    Return a float32 2D Hann window on the GPU (port of Ashlar utils.get_window).

    The window is the outer product of two 1D Hann windows whose lengths are
    ``shape[0]`` and ``shape[1]``.

    Raises:
        ImportError: if CuPy is not available.
    """
    if cp is None:
        raise ImportError("CuPy is required for GPU window functions")
    n_rows, n_cols = shape[0], shape[1]
    row_window = cp.hanning(n_rows).astype(cp.float32)
    col_window = cp.hanning(n_cols).astype(cp.float32)
    return cp.outer(row_window, col_window)
# 3x3 discrete Laplacian (4-neighbour stencil) used by the sigma == 0 whitening
# path; intended to match ashlar's skimage.restoration.uft.laplacian-derived
# kernel. Stays None when CuPy is unavailable.
_laplace_kernel_gpu = (
    cp.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=cp.float32)
    if cp
    else None
)
def whiten_gpu(img, sigma):
    """
    Whitening (high-pass) filter for a single 2D image on the GPU.

    GPU counterpart of ``ashlar.utils.whiten``:
    - sigma == 0: convolution with the 3x3 discrete Laplacian ('reflect' borders)
    - sigma > 0: Laplacian of Gaussian with the given sigma

    Note: the input is cast to float32 without rescaling. Ashlar converts with
    skimage's ``img_as_float32``, which rescales integer images to [0, 1], so
    for integer input the outputs differ by a constant factor.

    Args:
        img: 2D image (NumPy or CuPy); moved to the GPU if needed.
        sigma: Standard deviation for the Gaussian kernel (0 = pure Laplacian).

    Returns:
        CuPy float32 array: the whitened image.

    Raises:
        ImportError: if CuPy is not available.
    """
    if cp is None:
        raise ImportError("CuPy is required for GPU whitening")
    from cupyx.scipy import ndimage as cp_ndimage

    img = cp.asarray(img).astype(cp.float32)
    if sigma == 0:
        return cp_ndimage.convolve(img, _laplace_kernel_gpu, mode='reflect')
    return cp_ndimage.gaussian_laplace(img, sigma)
def whiten_gpu_vectorized(img_stack, sigma):
    """
    Whiten every 2D slice of an (N, H, W) stack on the GPU.

    Each slice gets the same filter as ``whiten_gpu``:
    - sigma == 0: 3x3 Laplacian ('reflect' borders). This is done in a single
      call by convolving the whole stack with a (1, 3, 3) kernel. The kernel
      has extent 1 along the stack axis, so slices never mix.
    - sigma > 0: Laplacian of Gaussian, applied slice by slice. A 3D
      gaussian_laplace would also differentiate along the stack axis.

    Args:
        img_stack: array-like of shape (N, H, W); moved to the GPU if needed.
        sigma: Standard deviation for the Gaussian kernel (0 = pure Laplacian).

    Returns:
        CuPy float32 array with the same shape as the input stack.

    Raises:
        ImportError: if CuPy is not available.
    """
    if cp is None:
        raise ImportError("CuPy is required for GPU whitening")
    from cupyx.scipy import ndimage as cp_ndimage

    img_stack = cp.asarray(img_stack).astype(cp.float32)
    if img_stack.size == 0:
        return cp.empty_like(img_stack)

    if sigma == 0:
        stack_kernel = _laplace_kernel_gpu[cp.newaxis]  # shape (1, 3, 3)
        return cp_ndimage.convolve(img_stack, stack_kernel, mode='reflect')
    return cp.stack([cp_ndimage.gaussian_laplace(img, sigma) for img in img_stack])
def ashlar_register_gpu(img1, img2, upsample=10):
    """
    Estimate the translation between two image regions with GPU phase correlation.

    Uses cuCIM's ``phase_cross_correlation`` (the GPU counterpart of
    ``skimage.registration.phase_cross_correlation``). Both inputs are Hann
    windowed first; no whitening is applied, matching the CPU version.

    Args:
        img1, img2: 2D arrays (NumPy or CuPy) of identical, non-empty shape.
        upsample: Upsampling factor for sub-pixel phase correlation.

    Returns:
        (shift, error): ``shift`` is always a NumPy array of 2 floats and
        ``error`` is a Python float. Invalid input (None, empty, mismatched
        shapes, not 2D) or a correlation failure yields ``(zeros(2), inf)``.
    """
    # Input validation needs neither CuPy nor cuCIM.
    if img1 is None or img2 is None:
        return np.zeros(2), np.inf
    if img1.size == 0 or img2.size == 0:
        return np.zeros(2), np.inf
    if img1.shape != img2.shape:
        return np.zeros(2), np.inf
    if len(img1.shape) != 2:
        return np.zeros(2), np.inf
    if img1.shape[0] < 1 or img1.shape[1] < 1:
        return np.zeros(2), np.inf

    import cucim.skimage.registration

    # Convert to float32 on the GPU and apply windowing (matches CPU version).
    img1w = cp.asarray(img1).astype(cp.float32) * _get_window(img1.shape)
    img2w = cp.asarray(img2).astype(cp.float32) * _get_window(img2.shape)

    try:
        shift, error, _phase_diff = cucim.skimage.registration.phase_cross_correlation(
            img1w, img2w, upsample_factor=upsample
        )
        # Host-side results so every return path has the same types.
        shift = cp.asnumpy(shift)
        error = float(error)
    except Exception as e:
        # Deliberate best-effort: a failed pair is reported as an infinite error.
        logger.error(f"Ashlar GPU: CORRELATION FAILED - Exception: {e}")
        logger.error(f"  Returning infinite error")
        return np.zeros(2), np.inf

    # NOTE(review): if cuCIM follows skimage, this error is a normalized RMS
    # value in [0, 1], so this threshold may rarely fire -- confirm.
    if error > 1.0:
        logger.warning(f"Ashlar GPU: HIGH CORRELATION ERROR - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
        logger.warning(f"  This indicates poor overlap or image quality between tiles")

    return shift, error
def ashlar_nccw_no_preprocessing_gpu(img1, img2):
    """
    Normalized cross-correlation error between two equally shaped images (GPU).

    Port of ``ashlar.utils.nccw`` without the whitening step. It computes
    ``-log(|<img1, img2>| / (||img1|| * ||img2||))``, which is 0 for perfectly
    correlated images and grows as the correlation drops.

    Sums and norms are accumulated in float64. The inputs are not rescaled
    here, so with raw integer intensities the dot product can be many orders
    of magnitude above 1. At that scale float32 rounding alone would exceed a
    small fixed tolerance.

    Args:
        img1, img2: 2D arrays (NumPy or CuPy) of the same shape.

    Returns:
        float: the error. It is ``inf`` when the correlation or either image's
        energy is zero. It is ``100.0`` when the correlation exceeds the
        Cauchy-Schwarz bound by more than rounding can explain.
    """
    img1w = cp.asarray(img1, dtype=cp.float64)
    img2w = cp.asarray(img2, dtype=cp.float64)

    correlation = float(cp.abs(cp.sum(img1w * img2w)))
    total_amplitude = float(cp.linalg.norm(img1w) * cp.linalg.norm(img2w))

    if correlation > 0 and total_amplitude > 0:
        diff = correlation - total_amplitude
        # By Cauchy-Schwarz, correlation <= total_amplitude, so a positive diff
        # can only be rounding (e.g. near-identical images). The tolerance
        # scales with the magnitudes involved and never drops below the
        # previous absolute 1e-3.
        tolerance = max(1e-3, 1e-7 * total_amplitude)
        if diff <= 0:
            error = -np.log(correlation / total_amplitude)
        elif diff < tolerance:
            error = 0
        else:
            # Large but finite so downstream min()/thresholding still work.
            logger.warning(f"Ashlar GPU: NCCW numerical precision issue - diff={diff:.6f}, using error=100.0")
            error = 100.0
    else:
        logger.warning(f"Ashlar GPU: NCCW invalid correlation - correlation={correlation:.6f}, total_amplitude={total_amplitude:.6f}")
        error = np.inf

    error_float = float(error)
    if error_float > 10.0:
        logger.warning(f"Ashlar GPU: HIGH NCCW ERROR - Error={error_float:.4f}")
        logger.warning(f"  This indicates poor image correlation between tiles")
    else:
        logger.info("Ashlar GPU: NCCW - Error=%.4f", error_float)

    return error_float
def ashlar_crop_gpu(img, offset, shape):
    """
    Crop a 2D GPU image, following ``ashlar.utils.crop``, with boundary handling.

    The offset is rounded to the nearest whole pixel. Unlike Ashlar, a window
    that extends past the image is clipped to the image bounds, so the result
    may be smaller than ``shape``.

    Args:
        img: 2D image (NumPy or CuPy); moved to the GPU if needed.
        offset: (row, col) offset of the window; NumPy, CuPy or sequence.
        shape: (rows, cols) size of the window; NumPy, CuPy or sequence.

    Returns:
        CuPy view of the cropped region.

    Raises:
        ValueError: if ``shape`` is not positive or the clipped window is empty.
    """
    img = cp.asarray(img)
    # offset/shape are tiny vectors. Resolve them on the host once so the bounds
    # logic works on plain integers, with no per-comparison device syncs, and
    # the slice bounds are Python ints.
    offset = cp.asnumpy(offset) if isinstance(offset, cp.ndarray) else np.asarray(offset)
    shape = cp.asnumpy(shape) if isinstance(shape, cp.ndarray) else np.asarray(shape)

    if np.any(shape <= 0):
        raise ValueError(f"Invalid crop shape: {shape}. Shape must be positive.")

    # Note that this only crops to the nearest whole-pixel offset.
    start = np.round(offset).astype(int)
    end = start + shape

    img_shape = np.array(img.shape)
    if np.any(start < 0) or np.any(end > img_shape):
        # Clip to valid bounds.
        start = np.maximum(start, 0)
        end = np.minimum(end, img_shape)

    new_shape = end - start
    if np.any(new_shape <= 0):
        raise ValueError(f"Invalid crop region after bounds checking: start={start}, end={end}, img_shape={img_shape}")

    (row0, col0), (row1, col1) = start.tolist(), end.tolist()
    return img[row0:row1, col0:col1]
class ArrayEdgeAlignerGPU:
    """
    Array-based port of Ashlar's EdgeAligner that runs on the GPU.

    It works directly on an in-memory (num_tiles, height, width) image stack
    instead of a file reader. Positions are 2-vectors per tile. ``_register``
    treats axis 0 as vertical and axis 1 as horizontal.

    The pipeline is driven by ``run``:
    1. Check nominal overlaps.
    2. Estimate an error threshold by permutation testing.
    3. Register every neighboring pair.
    4. Build a spanning tree over the accepted registrations.
    5. Integrate shifts along that tree.
    6. Fit a linear model that places the disconnected components.

    Attributes set by ``run``:
        shifts: CuPy (num_tiles, 2) per-tile corrections.
        final_positions: CuPy (num_tiles, 2) corrected positions, translated
            so that the minimum along each axis is 0.
        lr: sklearn LinearRegression mapping nominal to corrected positions.
        origin: CuPy vector subtracted from final_positions.
    """

    def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
                 max_shift=15, alpha=0.01, max_error=None,
                 randomize=False, verbose=False, upsample_factor=10,
                 permutation_upsample=1, permutation_samples=1000,
                 min_permutation_samples=10, max_permutation_tries=100,
                 window_size_factor=0.1):
        """
        Initialize array-based EdgeAligner for position calculation on GPU.

        Args:
            image_stack: 3D numpy/cupy array (num_tiles, height, width) - preprocessed grayscale.
            positions: 2D array of tile positions (num_tiles, 2) in pixels.
            tile_size: Array [height, width] of tile dimensions.
            pixel_size: Pixel size in micrometers (for max_shift conversion).
            max_shift: Maximum allowed shift in micrometers.
            alpha: Alpha value for error threshold (lower = stricter).
            max_error: Explicit error threshold (None = auto-compute).
            randomize: Use random seed for permutation testing.
            verbose: Print progress to stdout.
            upsample_factor: Phase-correlation upsampling for pair registration.
            permutation_upsample: Phase-correlation upsampling while estimating
                the threshold.
            permutation_samples: Random strip pairs drawn for the threshold
                (reduced for small datasets).
            min_permutation_samples: Samples per distant pair for small datasets.
            max_permutation_tries: Attempts to draw a non-overlapping strip pair.
            window_size_factor: Fraction of the tile size that the progressive
                registration window grows to.
        """
        # cp.asarray is a no-op for arrays already on the GPU.
        self.image_stack = cp.asarray(image_stack)

        if isinstance(positions, cp.ndarray):
            # astype copies, so later edits to the caller's array don't leak in.
            self.positions = positions.astype(cp.float64)
        else:
            self.positions = cp.asarray(positions, dtype=cp.float64)

        self.tile_size = cp.array(tile_size)
        self.pixel_size = pixel_size
        self.max_shift = max_shift
        self.max_shift_pixels = self.max_shift / self.pixel_size
        self.alpha = alpha
        self.max_error = max_error
        self.randomize = randomize
        self.verbose = verbose
        self.upsample_factor = upsample_factor
        self.permutation_upsample = permutation_upsample
        self.permutation_samples = permutation_samples
        self.min_permutation_samples = min_permutation_samples
        self.max_permutation_tries = max_permutation_tries
        self.window_size_factor = window_size_factor
        # Maps sorted (t1, t2) -> (shift as NumPy 2-vector, error as float).
        self._cache = {}
        self.errors_negative_sampled = cp.empty(0)

        # The neighbor graph is built on the CPU (SciPy + NetworkX).
        self.neighbors_graph = self._build_neighbors_graph()

    def _build_neighbors_graph(self):
        """
        Build the graph of neighboring tiles.

        Two tiles are neighbors when the city-block distance between their
        nominal positions is > 0 and < (largest tile dimension + 1). Every
        tile is a node, even if it has no neighbors.
        """
        positions_cpu = cp.asnumpy(self.positions)
        tile_size_cpu = cp.asnumpy(self.tile_size)

        pdist = scipy.spatial.distance.pdist(positions_cpu, metric='cityblock')
        sp = scipy.spatial.distance.squareform(pdist)
        max_distance = tile_size_cpu.max() + 1
        edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
        graph = nx.from_edgelist(edges)
        graph.add_nodes_from(range(len(positions_cpu)))
        return graph

    def run(self):
        """Run the complete Ashlar algorithm."""
        self.check_overlaps()
        self.compute_threshold()
        self.register_all()
        self.build_spanning_tree()
        self.calculate_positions()
        self.fit_model()

    def check_overlaps(self):
        """Warn (DataWarning) if some or all neighbor pairs have no nominal overlap."""
        overlaps = [
            self.tile_size - cp.abs(self.positions[t1] - self.positions[t2])
            for t1, t2 in self.neighbors_graph.edges
        ]
        if not overlaps:
            return

        failures = cp.asnumpy(cp.any(cp.stack(overlaps) < 1, axis=1))
        if len(failures) and failures.all():
            warn_data("No tiles overlap, attempting alignment anyway.")
        elif failures.any():
            warn_data("Some neighboring tiles have zero overlap.")

    def compute_threshold(self):
        """
        Set ``self.max_error`` by permutation testing, unless it was given.

        Random pairs of strips that should NOT overlap are registered. The
        ``alpha`` percentile of their errors becomes the acceptance threshold.
        With at most one neighbor edge the threshold is ``inf``.
        """
        if self.max_error is not None:
            if self.verbose:
                print(" using explicit error threshold")
            return

        edges = self.neighbors_graph.edges
        num_tiles = len(self.image_stack)

        # If not enough tiles overlap to matter, skip this whole thing.
        if len(edges) <= 1:
            self.max_error = np.inf
            return

        # Strip height: the widest of the per-edge minimum overlap extents.
        w = max(int(self.intersection(t1, t2).shape.min()) for t1, t2 in edges)
        max_offset = int(self.tile_size[0]) - w

        # Number of possible pairs minus number of actual neighbor pairs.
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)
        # Reduce permutation count for small datasets.
        if num_distant_pairs > 8:
            n = self.permutation_samples
        else:
            n = (num_distant_pairs + 1) * self.min_permutation_samples

        pairs, offsets = self._sample_negative_strips(n, w, max_offset, num_tiles, edges)

        errors = cp.empty(n)
        for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
            if self.verbose and (i % 10 == 9 or i == n - 1):
                sys.stdout.write(f'\r quantifying alignment error {i + 1}/{n}')
                sys.stdout.flush()
            img1 = self.image_stack[t1][offset1:offset1 + w, :]
            img2 = self.image_stack[t2][offset2:offset2 + w, :]
            _, errors[i] = ashlar_register_gpu(img1, img2, upsample=self.permutation_upsample)
        if self.verbose:
            print()

        self.errors_negative_sampled = errors
        self.max_error = float(cp.percentile(errors, self.alpha * 100))

    def _sample_negative_strips(self, n, w, max_offset, num_tiles, edges):
        """Draw ``n`` (tile pair, row-offset pair) samples of strips assumed not to overlap."""
        max_tries = self.max_permutation_tries
        # A fixed seed keeps thresholds reproducible unless randomization is requested.
        if self.randomize is False:
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()

        # Sampling stays on the CPU.
        pairs = np.empty((n, 2), dtype=int)
        offsets = np.empty((n, 2), dtype=int)
        for i in range(n):
            # Limit tries to avoid an infinite loop in pathological cases.
            for _ in range(max_tries):
                t1, t2 = random_state.randint(num_tiles, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)
                if self._is_negative_sample(t1, t2, o1, o2, w, edges):
                    break
            else:
                # Retries exhausted; the last draw is used anyway. Should be very rare.
                warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
            pairs[i] = t1, t2
            offsets[i] = o1, o2
        return pairs, offsets

    def _is_negative_sample(self, t1, t2, o1, o2, w, edges):
        """Return True if strips [o1, o1+w) of t1 and [o2, o2+w) of t2 can be assumed disjoint."""
        if t1 != t2 and (t1, t2) not in edges:
            # Different, non-neighboring tiles -- always OK.
            return True
        if t1 == t2 and abs(o1 - o2) > w:
            # Same tile is OK if the strips don't overlap within the image.
            return True
        if (t1, t2) in edges:
            # Neighbors are OK if either strip is entirely outside the expected
            # overlap region (based on nominal positions).
            its = self.intersection(t1, t2, cp.full(2, w))
            ioff1, ioff2 = cp.asnumpy(its.offsets[:, 0])
            height, width = cp.asnumpy(its.shape)
            return bool(
                height > width
                or o1 < ioff1 - w or o1 > ioff1 + w
                or o2 < ioff2 - w or o2 > ioff2 + w
            )
        return False

    def register_all(self):
        """
        Register all neighboring tile pairs, then reject bad registrations.

        A pair is rejected (its error set to inf) if its error exceeds
        ``max_error`` or any shift component exceeds ``max_shift_pixels``.
        ``all_errors`` keeps the errors as they were before rejection.
        """
        n = self.neighbors_graph.size()
        for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
            if self.verbose:
                sys.stdout.write(f'\r aligning edge {i}/{n}')
                sys.stdout.flush()
            self.register_pair(t1, t2)
        if self.verbose:
            print()
        self.all_errors = cp.array([error for _, error in self._cache.values()])

        # Set error values above the threshold to infinity.
        for key, (shift, error) in list(self._cache.items()):
            if error > self.max_error or bool(cp.any(cp.abs(cp.asarray(shift)) > self.max_shift_pixels)):
                self._cache[key] = (shift, cp.inf)

    def register_pair(self, t1, t2):
        """
        Return ``(shift, error)`` for tile t2 relative to tile t1 (cached).

        Results are cached under the sorted pair. The shift is negated when
        ``t1 > t2``, and a copy is returned so callers can't corrupt the cache.
        """
        key = tuple(sorted((t1, t2)))
        try:
            shift, error = self._cache[key]
        except KeyError:
            shift, error = self._register_uncached(key)
            self._cache[key] = (shift, error)
        if t1 > t2:
            shift = -shift
        return shift.copy(), error

    def _register_uncached(self, key):
        """
        Register a sorted tile pair over growing windows (best-effort).

        Failures are reported as a zero NumPy shift with infinite error.
        """
        t1, t2 = key
        try:
            results = []
            for size in self._window_sizes(t1, t2):
                try:
                    result = self._register(t1, t2, size)
                except Exception as e:
                    if self.verbose:
                        print(f" window size {size} failed: {e}")
                    continue
                if result is not None:
                    results.append(result)

            if not results:
                # All window sizes failed.
                return np.zeros(2), cp.inf

            # Use the shift from the window size that gave the lowest error.
            shift, _ = min(results, key=lambda r: r[1])
            # Re-score on the nominal overlap window with the shift applied to
            # the second tile; this should be even lower than the error above.
            try:
                _, o1, o2 = self.overlap(t1, t2, shift=shift)
                error = ashlar_nccw_no_preprocessing_gpu(o1, o2)
            except Exception as e:
                if self.verbose:
                    print(f" final error computation failed: {e}")
                error = cp.inf
            return shift, error
        except Exception as e:
            if self.verbose:
                print(f" registration failed for tiles {key}: {e}")
            return np.zeros(2), cp.inf

    def _window_sizes(self, t1, t2):
        """
        Return the list of progressively doubled overlap window sizes.

        Sizes start at the nominal intersection and double until every
        component reaches ``window_size_factor`` of the tile size. Growing the
        window helps when the stage error is large relative to the overlap,
        while starting small avoids the correlation penalty of large windows.
        """
        smin = self.intersection(t1, t2).shape
        smax = cp.round(self.tile_size * self.window_size_factor)
        sizes = [smin]
        # Terminates: IntersectionGPU makes every shape component at least 1.
        while bool(cp.any(sizes[-1] < smax)):
            sizes.append(sizes[-1] * 2)
        return sizes

    def _register(self, t1, t2, min_size=0):
        """Register one tile pair at a given minimum window size; None on failure."""
        try:
            its, img1, img2 = self.overlap(t1, t2, min_size)

            if img1.size == 0 or img2.size == 0:
                if self.verbose:
                    print(f" empty images for tiles {t1}, {t2} with min_size {min_size}")
                return None

            # Account for padding, flipping the sign depending on the
            # direction between the tiles.
            p1, p2 = self.positions[[t1, t2]]
            sx = 1 if p1[1] >= p2[1] else -1
            sy = 1 if p1[0] >= p2[0] else -1
            padding = cp.asarray(its.padding) * cp.array([sy, sx])
            shift, error = ashlar_register_gpu(img1, img2, upsample=self.upsample_factor)
            return cp.asnumpy(cp.asarray(shift) + padding), error
        except Exception as e:
            if self.verbose:
                print(f" _register failed for tiles {t1}, {t2}: {e}")
            return None

    def intersection(self, t1, t2, min_size=0, shift=None):
        """Intersection of tiles t1 and t2, optionally with ``shift`` added to t2's position."""
        # Advanced indexing already returns a copy, so self.positions is untouched.
        corners1 = self.positions[[t1, t2]]
        if shift is not None:
            corners1[1] += cp.asarray(shift)
        corners2 = corners1 + self.tile_size
        return IntersectionGPU(corners1, corners2, min_size)

    def crop(self, tile_id, offset, shape):
        """Crop image from tile at given offset and shape."""
        return ashlar_crop_gpu(self.image_stack[tile_id], offset, shape)

    def overlap(self, t1, t2, min_size=0, shift=None):
        """Return ``(intersection, crop_of_t1, crop_of_t2)`` for the pair's overlap."""
        its = self.intersection(t1, t2, min_size, shift)

        # Validate intersection shape before cropping.
        if cp.any(its.shape <= 0):
            raise ValueError(f"Invalid intersection shape {its.shape} for tiles {t1}, {t2}")

        img1 = self.crop(t1, its.offsets[0], its.shape)
        img2 = self.crop(t2, its.offsets[1], its.shape)
        return its, img1, img2

    def build_spanning_tree(self):
        """
        Build ``self.spanning_tree`` over pairs with a finite registration error.

        It tries the GPU Boruvka MST first. On any failure it falls back to
        Ashlar's construction.
        """
        # Cached errors are host floats, so check finiteness on the host.
        valid_edges = [
            (t1, t2, shift, error)
            for (t1, t2), (shift, error) in self._cache.items()
            if np.isfinite(float(error))
        ]

        if not valid_edges:
            # No valid edges: an edgeless graph containing every tile.
            self.spanning_tree = nx.Graph()
            self.spanning_tree.add_nodes_from(range(len(self.positions)))
            return

        try:
            self.spanning_tree = self._spanning_tree_boruvka(valid_edges)
        except Exception as e:
            logger.warning(f"Boruvka MST failed, falling back to NetworkX: {e}")
            self.spanning_tree = self._spanning_tree_shortest_paths()

    def _spanning_tree_boruvka(self, valid_edges):
        """Spanning tree from the GPU Boruvka MST, using -error as edge quality."""
        from openhcs.processing.backends.pos_gen.mist.boruvka_mst import build_mst_gpu_boruvka

        from_nodes, to_nodes, shifts, errors = zip(*valid_edges)
        connection_from = cp.array(from_nodes, dtype=cp.int32)
        connection_to = cp.array(to_nodes, dtype=cp.int32)
        connection_dx = cp.array([s[1] for s in shifts], dtype=cp.float32)  # x shift
        connection_dy = cp.array([s[0] for s in shifts], dtype=cp.float32)  # y shift
        # Higher quality = lower error.
        connection_quality = cp.array([-e for e in errors], dtype=cp.float32)
        num_nodes = len(self.positions)

        mst_result = build_mst_gpu_boruvka(
            connection_from, connection_to, connection_dx, connection_dy,
            connection_quality, num_nodes
        )

        # NOTE(review): assumes mst_result['edges'] is an iterable of dicts with
        # 'from', 'to' and optionally 'quality' keys, as used here -- confirm.
        tree = nx.Graph()
        tree.add_nodes_from(range(num_nodes))
        for edge in mst_result['edges']:
            error = -edge['quality'] if 'quality' in edge else 0.0
            tree.add_edge(edge['from'], edge['to'], weight=error)
        return tree

    def _spanning_tree_shortest_paths(self):
        """
        Build Ashlar's spanning tree: shortest paths by error from each component's center.
        """
        g = nx.Graph()
        g.add_nodes_from(self.neighbors_graph)
        g.add_weighted_edges_from(
            (t1, t2, error)
            for (t1, t2), (_, error) in self._cache.items()
            if np.isfinite(float(error))
        )
        spanning_tree = nx.Graph()
        spanning_tree.add_nodes_from(g)
        for component in nx.connected_components(g):
            cc = g.subgraph(component)
            center = nx.center(cc)[0]
            for path in nx.single_source_dijkstra_path(cc, center).values():
                nx.add_path(spanning_tree, path)
        return spanning_tree

    def calculate_positions(self):
        """
        Integrate pairwise shifts along the spanning tree into per-tile shifts.

        Each component's center gets a zero shift. Every other tile gets its
        BFS parent's shift plus the registered pair shift.
        """
        shifts = {}
        for component in nx.connected_components(self.spanning_tree):
            cc = self.spanning_tree.subgraph(component)
            center = nx.center(cc)[0]
            shifts[center] = cp.array([0, 0])
            for source, dest in nx.traversal.bfs_edges(cc, center):
                if source not in shifts:
                    source, dest = dest, source
                shift = self.register_pair(source, dest)[0]
                shifts[dest] = shifts[source] + cp.array(shift)

        if not shifts:
            # TODO: fill in shifts and positions with 0x2 arrays.
            raise NotImplementedError("No images")

        # Ordered by tile index; cp.stack promotes mixed int/float rows to float.
        self.shifts = cp.stack([s for _, s in sorted(shifts.items())])
        self.final_positions = self.positions + self.shifts

    def fit_model(self):
        """
        Fit a linear model on the largest component and place the other components.

        The model maps nominal to corrected positions. A degenerate fit
        (det < 1e-3) is replaced by the identity transform. Every other
        component is translated so its centroid matches the model prediction.
        Finally, positions and intercept are shifted so the minimum is at 0,0.
        """
        components = sorted(
            nx.connected_components(self.spanning_tree),
            key=len, reverse=True
        )
        # Fit the model on positions of the largest connected component (on the CPU).
        cc0 = list(components[0])
        self.lr = sklearn.linear_model.LinearRegression()
        self.lr.fit(
            cp.asnumpy(self.positions[cc0]),
            cp.asnumpy(self.final_positions[cc0]),
        )

        # A degenerate matrix happens when the tree is edgeless or cc0's
        # nominal positions are collinear; fall back to the identity.
        if np.linalg.det(self.lr.coef_) < 1e-3:
            warn_data(
                "Could not align enough edges, proceeding anyway with original"
                " stage positions."
            )
            self.lr.coef_ = np.diag(np.ones(2))
            self.lr.intercept_ = np.zeros(2)

        # Move the remaining components so their centroids match the predictions.
        for component in components[1:]:
            nodes = list(component)
            centroid_m = cp.asnumpy(cp.mean(self.positions[nodes], axis=0))
            centroid_f = cp.asnumpy(cp.mean(self.final_positions[nodes], axis=0))
            shift = self.lr.predict(centroid_m.reshape(1, -1))[0] - centroid_f
            self.final_positions[nodes] += cp.asarray(shift)

        # Put the origin at 0,0.
        self.origin = cp.min(self.final_positions, axis=0)
        self.final_positions -= self.origin
        self.lr.intercept_ -= cp.asnumpy(self.origin)
def _calculate_initial_positions_gpu(image_stack, grid_dims: tuple, overlap_ratio: float):
    """
    Compute the nominal (y, x) pixel position of every tile in a row-major grid.

    Tile ``i`` sits at row ``i // grid_cols`` and column ``i % grid_cols``.
    Adjacent tiles are spaced by ``(1 - overlap_ratio)`` times the tile
    height (vertically) or width (horizontally).

    Args:
        image_stack: NumPy or CuPy array whose axes 1 and 2 are tile height
            and width; its length is the number of tiles.
        grid_dims: (grid_rows, grid_cols). Only grid_cols affects the layout.
        overlap_ratio: Fractional overlap between adjacent tiles.

    Returns:
        CuPy float64 array with one [y, x] row per tile.
    """
    # Unpacking still enforces a 2-element grid_dims; rows are implied by indexing.
    _grid_rows, grid_cols = grid_dims
    # .shape is the same for NumPy and CuPy, so no type dispatch is needed.
    tile_height, tile_width = image_stack.shape[1:3]
    spacing_factor = 1.0 - overlap_ratio

    positions = [
        [
            (tile_idx // grid_cols) * tile_height * spacing_factor,
            (tile_idx % grid_cols) * tile_width * spacing_factor,
        ]
        for tile_idx in range(len(image_stack))
    ]
    return cp.array(positions, dtype=cp.float64)
def _convert_ashlar_positions_to_openhcs_gpu(ashlar_positions) -> List[Tuple[float, float]]:
    """
    Convert Ashlar (y, x) positions to the OpenHCS (x, y) format.

    Args:
        ashlar_positions: Per-tile (y, x) pairs, as a NumPy/CuPy array of shape
            (num_tiles, 2) or a sequence of pairs. CuPy arrays are copied to
            the host first.

    Returns:
        One ``(x, y)`` tuple of Python floats per tile, in input order.
    """
    # Check cp first: with CuPy unavailable (cp is None), cp.ndarray would
    # raise AttributeError even for plain NumPy input.
    if cp is not None and isinstance(ashlar_positions, cp.ndarray):
        ashlar_positions = cp.asnumpy(ashlar_positions)

    # OpenHCS uses (x, y) order.
    return [(float(x), float(y)) for y, x in ashlar_positions]
799@special_inputs("grid_dimensions")
800@special_outputs("positions")
801@cupy_func
802def ashlar_compute_tile_positions_gpu(
803 image_stack,
804 grid_dimensions: Tuple[int, int],
805 overlap_ratio: float = 0.1,
806 max_shift: float = 15.0,
807 stitch_alpha: float = 0.01,
808 max_error: float = None,
809 randomize: bool = False,
810 verbose: bool = False,
811 upsample_factor: int = 10,
812 permutation_upsample: int = 1,
813 permutation_samples: int = 1000,
814 min_permutation_samples: int = 10,
815 max_permutation_tries: int = 100,
816 window_size_factor: float = 0.1,
817 **kwargs
818) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
819 """
820 Compute tile positions using the Ashlar algorithm on GPU - matches CPU version.
822 This function implements the Ashlar edge-based stitching algorithm using GPU acceleration.
823 It performs position calculation with minimal preprocessing (windowing only, no whitening)
824 to match the CPU version behavior.
826 Args:
827 image_stack: 3D numpy/cupy array of shape (num_tiles, height, width) containing preprocessed
828 grayscale images. Each slice [i] should be a single-channel 2D image ready
829 for correlation analysis. No further preprocessing will be applied.
831 grid_dimensions: Tuple of (grid_rows, grid_cols) specifying the logical arrangement of
832 tiles. For example, (2, 3) means 2 rows and 3 columns of tiles, for a
833 total of 6 tiles. Must match the number of images in image_stack.
835 overlap_ratio: Expected fractional overlap between adjacent tiles (0.0-1.0). Default 0.1
836 means 10% overlap. This is used to calculate initial grid positions and
837 should match the actual overlap in your microscopy data. Typical values:
838 - 0.05-0.15 for well-controlled microscopes
839 - 0.15-0.25 for less precise stages
841 max_shift: Maximum allowed shift correction in micrometers. Default 15.0. This limits
842 how far tiles can be moved from their initial grid positions during alignment.
843 Should be set based on your microscope's stage accuracy:
844 - 5-15 μm for high-precision stages
845 - 15-50 μm for standard stages
846 - 50+ μm for low-precision or manual stages
848 stitch_alpha: Alpha value for statistical error threshold computation (0.0-1.0). Default
849 0.01 means 1% false positive rate. Lower values are stricter and reject more
850 alignments, higher values are more permissive. This controls the trade-off
851 between alignment quality and success rate:
852 - 0.001-0.01: Very strict, high quality alignments only
853 - 0.01-0.05: Balanced (recommended for most data)
854 - 0.05-0.1: Permissive, accepts lower quality alignments
856 max_error: Explicit error threshold for rejecting alignments (None = auto-compute).
857 When None (default), the threshold is computed automatically using permutation
858 testing. Set to a specific value to override automatic computation. Higher
859 values accept more alignments, lower values are stricter.
861 randomize: Whether to use random seed for permutation testing (bool). Default False uses
862 a fixed seed for reproducible results. Set True for different random sampling
863 in each run. Generally should be False for consistent results.
865 verbose: Enable detailed progress logging (bool). Default False. When True, prints
866 progress information including permutation testing, edge alignment, and
867 spanning tree construction. Useful for debugging and monitoring progress
868 on large datasets.
870 upsample_factor: Sub-pixel accuracy factor for phase cross correlation (int). Default 10.
871 Higher values provide better sub-pixel accuracy but increase computation time.
872 Range: 1-100+. Values of 10-50 are typical for high-accuracy stitching.
873 - 1: Pixel-level accuracy (fastest)
874 - 10: 0.1 pixel accuracy (balanced)
875 - 50: 0.02 pixel accuracy (high precision)
877 permutation_upsample: Upsample factor for permutation testing (int). Default 1.
878 Lower than upsample_factor for speed during threshold computation.
879 Usually kept at 1 since permutation testing doesn't need sub-pixel accuracy.
881 permutation_samples: Number of random samples for error threshold computation (int). Default 1000.
882 Higher values give more accurate thresholds but slower computation.
883 Automatically reduced for small datasets to avoid infinite loops.
885 min_permutation_samples: Minimum permutation samples for small datasets (int). Default 10.
886 When there are few non-overlapping pairs, this sets the minimum
887 number of samples to ensure statistical validity.
889 max_permutation_tries: Maximum attempts to find non-overlapping strips (int). Default 100.
890 Prevents infinite loops in pathological cases where valid strips
891 are hard to find. Rarely needs adjustment.
893 window_size_factor: Fraction of tile size for maximum window size (float). Default 0.1.
894 Controls the largest overlap window tested during progressive sizing.
895 Larger values allow detection of bigger stage errors but may reduce
896 correlation quality. Range: 0.05-0.2 typical.
898 filter_sigma: Whitening filter sigma for preprocessing (float). Default 0.
899 Controls the whitening filter applied before correlation:
900 - 0: Pure Laplacian filter (high-pass, matches original Ashlar)
901 - >0: Gaussian-Laplacian (LoG) filter with specified sigma
902 - Typical values: 0-2.0 for most microscopy data
904 **kwargs: Additional parameters (ignored). Allows compatibility with other stitching
905 algorithms that may have different parameter sets.
907 Returns:
908 Tuple of (image_stack, positions) where:
909 - image_stack: The original input image array (unchanged)
910 - positions: List of (x, y) position tuples in OpenHCS format, one per tile.
911 Positions are in pixel coordinates with (0, 0) at the top-left.
912 The positions represent the optimal tile placement after Ashlar
913 alignment, accounting for stage errors and image correlation.
915 Raises:
916 Exception: If the Ashlar algorithm fails (e.g., insufficient overlap, correlation
917 errors), the function automatically falls back to grid-based positioning
918 using the specified overlap_ratio.
920 Notes:
921 - This implementation contains the complete Ashlar algorithm including whitening
922 filter preprocessing, permutation testing, progressive window sizing, minimum
923 spanning tree construction, and linear model fitting for disconnected components.
924 - The correlation functions are identical to original Ashlar including proper
925 whitening/filtering preprocessing as specified by filter_sigma parameter.
926 - For best results, ensure your image_stack contains single-channel grayscale
927 images. The whitening filter will be applied automatically during correlation.
928 """
929 grid_rows, grid_cols = grid_dimensions
931 if verbose:
932 logger.info(f"Ashlar GPU: Processing {grid_rows}x{grid_cols} grid with {len(image_stack)} tiles")
934 try:
935 # Convert to CuPy array if needed
936 if not isinstance(image_stack, cp.ndarray):
937 image_stack_gpu = cp.asarray(image_stack)
938 else:
939 image_stack_gpu = image_stack
941 # Calculate initial grid positions
942 initial_positions = _calculate_initial_positions_gpu(image_stack_gpu, grid_dimensions, overlap_ratio)
943 tile_size = cp.array(image_stack_gpu.shape[1:3]) # (height, width)
945 # Create and run ArrayEdgeAlignerGPU with complete Ashlar algorithm
946 logger.info("Running complete Ashlar edge-based stitching algorithm on GPU")
947 aligner = ArrayEdgeAlignerGPU(
948 image_stack=image_stack_gpu,
949 positions=initial_positions,
950 tile_size=tile_size,
951 pixel_size=1.0, # Assume 1 micrometer per pixel if not specified
952 max_shift=max_shift,
953 alpha=stitch_alpha,
954 max_error=max_error,
955 randomize=randomize,
956 verbose=verbose,
957 upsample_factor=upsample_factor,
958 permutation_upsample=permutation_upsample,
959 permutation_samples=permutation_samples,
960 min_permutation_samples=min_permutation_samples,
961 max_permutation_tries=max_permutation_tries,
962 window_size_factor=window_size_factor
963 )
965 # Run the complete algorithm
966 aligner.run()
968 # Convert to OpenHCS format
969 positions = _convert_ashlar_positions_to_openhcs_gpu(aligner.final_positions)
971 # Convert result back to original format (CPU if input was CPU)
972 if not isinstance(image_stack, cp.ndarray):
973 result_image_stack = cp.asnumpy(image_stack_gpu)
974 else:
975 result_image_stack = image_stack_gpu
977 logger.info("Ashlar GPU algorithm completed successfully")
979 except Exception as e:
980 logger.error(f"Ashlar GPU algorithm failed: {e}")
981 # Fallback to grid positions if Ashlar fails
982 logger.warning("Falling back to grid-based positioning")
983 positions = []
985 # Use original image_stack for fallback dimensions
986 if isinstance(image_stack, cp.ndarray):
987 tile_height, tile_width = image_stack.shape[1:3]
988 else:
989 tile_height, tile_width = image_stack.shape[1:3]
991 spacing_factor = 1.0 - overlap_ratio
993 for tile_idx in range(len(image_stack)):
994 r = tile_idx // grid_cols
995 c = tile_idx % grid_cols
996 x_pos = c * tile_width * spacing_factor
997 y_pos = r * tile_height * spacing_factor
998 positions.append((float(x_pos), float(y_pos)))
1000 # Set result_image_stack for fallback case
1001 if not isinstance(image_stack, cp.ndarray):
1002 result_image_stack = image_stack
1003 else:
1004 result_image_stack = image_stack
1006 logger.info(f"Ashlar GPU: Completed processing {len(positions)} tile positions")
1008 return result_image_stack, positions
def _group_ashlar_coordinates(values, tolerance):
    """Cluster sorted distinct coordinate values; return the mean of each cluster.

    Consecutive sorted values whose gap is ``<= tolerance`` fall into the same
    cluster. With ``tolerance == 0`` every distinct value forms its own cluster,
    which matches a plain ``sorted(unique(values))``.
    """
    groups = []
    for value in sorted(values):
        if groups and value - groups[-1][-1] <= tolerance:
            groups[-1].append(value)
        else:
            groups.append([value])
    return [sum(group) / len(group) for group in groups]


def materialize_ashlar_gpu_positions(
    data: List[Tuple[float, float]],
    path: str,
    filemanager,
    *,
    position_tolerance: float = 0.0,
) -> str:
    """Materialize Ashlar GPU tile positions as a CSV file with grid metadata.

    Args:
        data: ``(x, y)`` tile positions, one per tile, in tile order. Grid
            row/column indices are assigned from that order assuming row-major
            layout (``row = tile_id // grid_cols``).
            NOTE(review): the CSV columns are labelled ``*_um``, but the
            stitching entry point documents its positions as pixel coordinates.
            Confirm the intended units.
        path: Path of the primary artifact. A trailing ``.pkl`` is replaced by
            ``_ashlar_positions_gpu.csv``. Any other path gets that suffix
            appended, so the CSV never overwrites ``path`` itself.
        filemanager: Object exposing ``save(content, path, backend)``. The CSV
            text is saved with backend ``"disk"``.
        position_tolerance: Largest gap between sorted coordinate values that
            still counts as the same grid column/row. The default ``0.0`` groups
            only exactly-equal values, which is the historical behaviour. Aligned
            (sub-pixel) positions need a value above the alignment jitter but
            below the tile spacing; otherwise every tile is counted as its own
            column and row.

    Returns:
        The path the CSV was saved to.

    Raises:
        ValueError: If ``position_tolerance`` is negative.
    """
    if position_tolerance < 0:
        raise ValueError(f"position_tolerance must be >= 0, got {position_tolerance}")

    suffix = '_ashlar_positions_gpu.csv'
    stem = path[:-len('.pkl')] if path.endswith('.pkl') else path
    csv_path = stem + suffix

    df = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
    n_tiles = len(df)
    df['tile_id'] = range(n_tiles)

    # Estimate grid dimensions from the position layout.
    x_centers = _group_ashlar_coordinates(df['x_position_um'].unique(), position_tolerance)
    y_centers = _group_ashlar_coordinates(df['y_position_um'].unique(), position_tolerance)
    grid_cols = len(x_centers)
    grid_rows = len(y_centers)

    # Row-major grid coordinates. max(..., 1) avoids dividing by zero when
    # there are no tiles (the lists are then empty anyway).
    cols_divisor = max(grid_cols, 1)
    df['grid_row'] = [i // cols_divisor for i in range(n_tiles)]
    df['grid_col'] = [i % cols_divisor for i in range(n_tiles)]

    # Spacing is taken from the first two columns/rows only.
    df['x_spacing_um'] = x_centers[1] - x_centers[0] if grid_cols > 1 else 0
    df['y_spacing_um'] = y_centers[1] - y_centers[0] if grid_rows > 1 else 0

    # Metadata
    df['algorithm'] = 'ashlar_gpu'
    df['grid_dimensions'] = f"{grid_rows}x{grid_cols}"

    csv_content = df.to_csv(index=False)
    filemanager.save(csv_content, csv_path, "disk")
    return csv_path