Coverage for openhcs/processing/backends/pos_gen/ashlar_main_gpu.py: 7.8%
507 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2OpenHCS Interface for Ashlar GPU Stitching Algorithm
4Array-based EdgeAligner implementation that works directly with CuPy arrays
5instead of file-based readers. This is the complete Ashlar algorithm modified
6to accept arrays directly and run on GPU.
7"""
8from __future__ import annotations
9import logging
10import sys
11from typing import TYPE_CHECKING, Tuple, List
12import numpy as np
13import networkx as nx
14import scipy.spatial.distance
15import sklearn.linear_model
16import pandas as pd
18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs
19from openhcs.core.memory.decorators import cupy as cupy_func
20from openhcs.core.utils import optional_import
22# Import CuPy using the established optional import pattern
23cp = optional_import("cupy")
25import warnings
27if TYPE_CHECKING: 27 ↛ 28line 27 didn't jump to line 28 because the condition on line 27 was never true
28 pass
30logger = logging.getLogger(__name__)
class DataWarning(Warning):
    """Warning category for problems found in user-provided image data."""
def warn_data(message):
    """Emit ``message`` through the ``warnings`` module as a ``DataWarning``."""
    warnings.warn(message, category=DataWarning)
class IntersectionGPU:
    """Overlap window between two tiles, computed with CuPy.

    ``corners1`` and ``corners2`` each hold one row per tile; within this file
    they are built as the tile positions and the positions plus the tile size.

    Attributes set on construction:
        shape: integer size of the overlap window, never smaller than the
            (clipped) ``min_size``.
        padding: how much ``shape`` was enlarged beyond the nominal overlap.
        offsets: per-tile offset of the window, clamped to be non-negative.
        offset_diff_frac: fractional part of the difference between the two
            tiles' offsets.
    """

    def __init__(self, corners1, corners2, min_size=0):
        if not cp:
            raise ImportError("CuPy is required for GPU intersection calculations")
        # Normalise min_size to a CuPy array: plain numbers are repeated for
        # both axes, any other non-CuPy value is converted as-is.
        if isinstance(min_size, (int, float)):
            min_size = cp.full(2, min_size)
        elif not isinstance(min_size, cp.ndarray):
            min_size = cp.array(min_size)
        self._calculate(corners1, corners2, min_size)

    def _calculate(self, corners1, corners2, min_size):
        """Derive ``shape``, ``padding``, ``offsets`` and ``offset_diff_frac``."""
        # The requested minimum may not exceed the largest tile extent and is
        # at least one pixel.
        largest_extent = (corners2 - corners1).max(axis=0)
        lower_bound = cp.clip(min_size, 1, largest_extent)

        # Nominal overlap: from the larger of the first corners to the smaller
        # of the second corners.
        origin = corners1.max(axis=0)
        nominal_shape = cp.floor(corners2.min(axis=0) - origin).astype(int)

        self.shape = cp.ceil(cp.maximum(nominal_shape, lower_bound)).astype(int)
        self.padding = self.shape - nominal_shape
        self.offsets = cp.maximum(origin - corners1 - self.padding, 0)

        delta = self.offsets[1] - self.offsets[0]
        self.offset_diff_frac = delta - cp.round(delta)
def _get_window(shape):
    """Return a float32 2D Hann window of size ``(shape[0], shape[1])`` as a CuPy array.

    Raises:
        ImportError: if CuPy is not available.
    """
    if cp is None:
        raise ImportError("CuPy is required for GPU window functions")
    # The 2D window is the outer product of one 1-D Hann window per axis.
    return cp.outer(
        cp.hanning(shape[0]).astype(cp.float32),
        cp.hanning(shape[1]).astype(cp.float32),
    )
# 3x3 Laplacian kernel used by the sigma == 0 whitening path. It is built once
# at import time and stays None when CuPy could not be imported.
_laplace_kernel_gpu = (
    cp.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=cp.float32)
    if cp
    else None
)
def whiten_gpu(img, sigma):
    """
    Apply a whitening (high-pass) filter to a single image on the GPU.

    The image is converted to a float32 CuPy array and then filtered:
    - sigma == 0: convolution with the module-level Laplacian kernel
      (``mode='reflect'``)
    - otherwise: ``cupyx.scipy.ndimage.gaussian_laplace`` with the given sigma

    Args:
        img: Image to filter; anything ``cp.asarray`` accepts
        sigma: Gaussian standard deviation (0 selects the plain Laplacian)

    Returns:
        CuPy array: the filtered float32 image
    """
    pixels = img if isinstance(img, cp.ndarray) else cp.asarray(img)
    pixels = pixels.astype(cp.float32)

    from cupyx.scipy import ndimage as gpu_ndimage

    if sigma == 0:
        return gpu_ndimage.convolve(pixels, _laplace_kernel_gpu, mode='reflect')
    return gpu_ndimage.gaussian_laplace(pixels, sigma)
def whiten_gpu_vectorized(img_stack, sigma):
    """
    Whiten every image of a stack on the GPU.

    The stack is converted to a float32 CuPy array and each plane along the
    first axis is filtered in turn with the same filter ``whiten_gpu`` uses
    (Laplacian convolution for sigma == 0, Gaussian-Laplacian otherwise).

    Args:
        img_stack: Stack of images indexed along axis 0
        sigma: Gaussian standard deviation (0 selects the plain Laplacian)

    Returns:
        CuPy array: float32 stack of filtered images, same shape as the input
    """
    stack = img_stack if isinstance(img_stack, cp.ndarray) else cp.asarray(img_stack)
    stack = stack.astype(cp.float32)

    use_laplacian = bool(sigma == 0)

    from cupyx.scipy import ndimage as gpu_ndimage

    whitened = cp.empty_like(stack)
    for index in range(stack.shape[0]):
        plane = stack[index]
        if use_laplacian:
            whitened[index] = gpu_ndimage.convolve(plane, _laplace_kernel_gpu, mode='reflect')
        else:
            whitened[index] = gpu_ndimage.gaussian_laplace(plane, sigma)
    return whitened
def ashlar_register_gpu(img1, img2, upsample=10):
    """
    Estimate the shift between two equally shaped 2D images on the GPU.

    Both images are converted to float32, multiplied by a Hann window and
    passed to cuCIM's ``phase_cross_correlation``. No whitening is applied.

    Args:
        img1, img2: Input images
        upsample: Upsampling factor handed to the phase correlation

    Returns:
        (shift, error). On success ``shift`` is a NumPy array and ``error`` a
        float. For unusable input (None, empty, mismatched or non-2D shapes)
        or when the correlation raises, ``shift`` is a zero CuPy array and
        ``error`` is infinite.
    """
    import cucim.skimage.registration

    def _rejected():
        # Sentinel result: zero shift with infinite error.
        return cp.array([0.0, 0.0]), cp.inf

    # None must be ruled out first; the remaining checks read attributes.
    if img1 is None or img2 is None:
        return _rejected()
    if (
        img1.size == 0 or img2.size == 0
        or img1.shape != img2.shape
        or len(img1.shape) != 2
        or img1.shape[0] < 1 or img1.shape[1] < 1
    ):
        return _rejected()

    if not isinstance(img1, cp.ndarray):
        img1 = cp.asarray(img1)
    if not isinstance(img2, cp.ndarray):
        img2 = cp.asarray(img2)

    # Window both images before correlating.
    windowed1 = img1.astype(cp.float32) * _get_window(img1.shape)
    windowed2 = img2.astype(cp.float32) * _get_window(img2.shape)

    try:
        shift, error, _phase = cucim.skimage.registration.phase_cross_correlation(
            windowed1, windowed2, upsample_factor=upsample
        )
        # Hand the result back as host-side values.
        shift = cp.asnumpy(shift)
        error = float(error)

        # Only unusually large errors are reported, to keep the log quiet.
        if error > 1.0:
            logger.warning(f"Ashlar GPU: HIGH CORRELATION ERROR - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
            logger.warning(" This indicates poor overlap or image quality between tiles")
    except Exception as e:
        # Any failure inside the correlation degrades to the sentinel result.
        logger.error(f"Ashlar GPU: CORRELATION FAILED - Exception: {e}")
        logger.error(" Returning infinite error")
        shift, error = _rejected()

    return shift, error
def ashlar_nccw_no_preprocessing_gpu(img1, img2):
    """
    Compute the negative-log normalized cross-correlation error of two images.

    The images are converted to float32 CuPy arrays without any filtering.
    The error is ``-log(correlation / total_amplitude)`` where ``correlation``
    is ``|sum(img1 * img2)|`` and ``total_amplitude`` is the product of the
    two norms.

    Returns:
        float: the error. It is 0 when the correlation exceeds the amplitude
        by less than 1e-3, 100.0 when it exceeds it by more, and infinite
        when either quantity is not positive.
    """
    first = img1 if isinstance(img1, cp.ndarray) else cp.asarray(img1)
    second = img2 if isinstance(img2, cp.ndarray) else cp.asarray(img2)
    first = first.astype(cp.float32)
    second = second.astype(cp.float32)

    correlation = float(cp.abs(cp.sum(first * second)))
    total_amplitude = float(cp.linalg.norm(first) * cp.linalg.norm(second))

    if not (correlation > 0 and total_amplitude > 0):
        logger.warning(f"Ashlar GPU: NCCW invalid correlation - correlation={correlation:.6f}, total_amplitude={total_amplitude:.6f}")
        error = cp.inf
    else:
        diff = correlation - total_amplitude
        if diff <= 0:
            error = -cp.log(correlation / total_amplitude)
        elif diff < 1e-3:
            # Rounding can push the correlation marginally above the amplitude
            # for (nearly) identical images; treat that as a perfect match.
            error = 0
        else:
            # Report a large but finite error rather than raising.
            logger.warning(f"Ashlar GPU: NCCW numerical precision issue - diff={diff:.6f}, using error=100.0")
            error = 100.0

    error_float = float(error)
    if error_float > 10.0:
        logger.warning(f"Ashlar GPU: HIGH NCCW ERROR - Error={error_float:.4f}")
        logger.warning(" This indicates poor image correlation between tiles")
    else:
        logger.info(f"Ashlar GPU: NCCW - Error={error_float:.4f}")

    return error_float
def ashlar_crop_gpu(img, offset, shape):
    """
    Crop a 2D window out of ``img`` at a whole-pixel offset.

    ``offset`` is rounded to the nearest integer. A window that extends
    beyond the image is clipped to the image bounds.

    Args:
        img: Image to crop
        offset: Start of the window along each axis
        shape: Requested window size along each axis

    Returns:
        CuPy array: the cropped region

    Raises:
        ValueError: if ``shape`` has a non-positive entry, or if nothing is
            left of the window after clipping.
    """
    img, offset, shape = (
        value if isinstance(value, cp.ndarray) else cp.asarray(value)
        for value in (img, offset, shape)
    )

    if cp.any(shape <= 0):
        raise ValueError(f"Invalid crop shape: {shape}. Shape must be positive.")

    start = cp.round(offset).astype(int)
    end = start + shape

    img_shape = cp.array(img.shape)
    outside = cp.any(start < 0) or cp.any(end > img_shape)
    if outside:
        # Pull the window back inside the image, then make sure it is not empty.
        start = cp.maximum(start, 0)
        end = cp.minimum(end, img_shape)
        if cp.any(end - start <= 0):
            raise ValueError(f"Invalid crop region after bounds checking: start={start}, end={end}, img_shape={img_shape}")

    return img[start[0]:end[0], start[1]:end[1]]
class ArrayEdgeAlignerGPU:
    """
    Array-based EdgeAligner that implements the complete Ashlar algorithm
    but works directly with CuPy arrays instead of file readers and runs on GPU.

    Typical use is to construct the aligner and call ``run()``. The individual
    steps then populate ``max_error``, ``all_errors``, ``spanning_tree``,
    ``shifts``, ``final_positions``, ``lr`` and ``origin``.
    """

    def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
                 max_shift=15, alpha=0.01, max_error=None,
                 randomize=False, verbose=False, upsample_factor=10,
                 permutation_upsample=1, permutation_samples=1000,
                 min_permutation_samples=10, max_permutation_tries=100,
                 window_size_factor=0.1):
        """
        Initialize array-based EdgeAligner for position calculation on GPU.

        Args:
            image_stack: 3D numpy/cupy array (num_tiles, height, width) - preprocessed grayscale
            positions: 2D array of tile positions (num_tiles, 2) in pixels
            tile_size: Array [height, width] of tile dimensions
            pixel_size: Pixel size in micrometers (for max_shift conversion)
            max_shift: Maximum allowed shift in micrometers
            alpha: Alpha value for error threshold (lower = stricter)
            max_error: Explicit error threshold (None = auto-compute)
            randomize: Use random seed for permutation testing
            verbose: Enable verbose logging
            upsample_factor: Upsampling factor used when registering neighbor pairs
            permutation_upsample: Upsampling factor used during threshold sampling
            permutation_samples: Number of random strip pairs sampled for the
                threshold when more than 8 non-neighboring pairs exist
            min_permutation_samples: Per-pair multiplier for the sample count
                on smaller datasets
            max_permutation_tries: Retry limit when drawing one random strip pair
            window_size_factor: Fraction of the tile size up to which the
                registration window is doubled
        """
        # Convert to CuPy arrays for GPU processing
        if not isinstance(image_stack, cp.ndarray):
            self.image_stack = cp.asarray(image_stack)
        else:
            self.image_stack = image_stack

        if not isinstance(positions, cp.ndarray):
            self.positions = cp.asarray(positions, dtype=cp.float64)
        else:
            self.positions = positions.astype(cp.float64)

        self.tile_size = cp.array(tile_size)
        self.pixel_size = pixel_size
        self.max_shift = max_shift
        self.max_shift_pixels = self.max_shift / self.pixel_size
        self.alpha = alpha
        self.max_error = max_error
        self.randomize = randomize
        self.verbose = verbose
        self.upsample_factor = upsample_factor
        self.permutation_upsample = permutation_upsample
        self.permutation_samples = permutation_samples
        self.min_permutation_samples = min_permutation_samples
        self.max_permutation_tries = max_permutation_tries
        self.window_size_factor = window_size_factor
        # Maps a sorted (t1, t2) pair to its (shift, error) registration result.
        self._cache = {}
        self.errors_negative_sampled = cp.empty(0)

        # Build neighbors graph (this uses CPU operations with NetworkX)
        self.neighbors_graph = self._build_neighbors_graph()

    def _build_neighbors_graph(self):
        """Build graph of neighboring (overlapping) tiles.

        Two tiles are neighbors when the city-block distance between their
        positions is positive and at most the largest tile dimension. Every
        tile is added as a node, even without neighbors.
        """
        # Convert to CPU for scipy operations
        positions_cpu = cp.asnumpy(self.positions)
        tile_size_cpu = cp.asnumpy(self.tile_size)

        pdist = scipy.spatial.distance.pdist(positions_cpu, metric='cityblock')
        sp = scipy.spatial.distance.squareform(pdist)
        max_distance = tile_size_cpu.max() + 1
        edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
        graph = nx.from_edgelist(edges)
        graph.add_nodes_from(range(len(positions_cpu)))
        return graph

    def run(self):
        """Run the complete Ashlar algorithm."""
        self.check_overlaps()
        self.compute_threshold()
        self.register_all()
        self.build_spanning_tree()
        self.calculate_positions()
        self.fit_model()

    def check_overlaps(self):
        """Warn when neighboring tiles do not actually overlap by at least one pixel."""
        overlaps = [
            self.tile_size - cp.abs(self.positions[t1] - self.positions[t2])
            for t1, t2 in self.neighbors_graph.edges
        ]
        if not overlaps:
            return

        failures = cp.any(cp.stack(overlaps) < 1, axis=1)
        failures_cpu = cp.asnumpy(failures)

        if len(failures_cpu) and all(failures_cpu):
            warn_data("No tiles overlap, attempting alignment anyway.")
        elif any(failures_cpu):
            warn_data("Some neighboring tiles have zero overlap.")

    def compute_threshold(self):
        """Compute error threshold using permutation testing.

        Random pairs of horizontal strips that should not overlap are
        registered against each other. The ``alpha`` percentile of the
        resulting errors becomes ``max_error``. Nothing is computed when an
        explicit ``max_error`` was supplied.
        """
        if self.max_error is not None:
            if self.verbose:
                print(" using explicit error threshold")
            return

        edges = self.neighbors_graph.edges
        num_tiles = len(self.image_stack)

        # If not enough tiles overlap to matter, skip this whole thing
        if len(edges) <= 1:
            self.max_error = np.inf
            return

        widths = []
        for t1, t2 in edges:
            shape = self.intersection(t1, t2).shape
            widths.append(cp.min(cp.array(shape)))

        widths = cp.array(widths)
        w = int(cp.max(widths))
        max_offset = int(self.tile_size[0]) - w

        # randint() below needs a positive upper bound. When the widest
        # overlap is as tall as the tile there is no room to place a strip, so
        # no threshold can be sampled; accept all alignments instead of
        # crashing.
        if max_offset < 1:
            warn_data(
                "Overlap is as large as the tile height; cannot sample strips"
                " for the error threshold, accepting all alignments."
            )
            self.max_error = np.inf
            return

        # Number of possible pairs minus number of actual neighbor pairs
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)

        # Reduce permutation count for small datasets
        n = self.permutation_samples if num_distant_pairs > 8 else (num_distant_pairs + 1) * self.min_permutation_samples
        pairs = np.empty((n, 2), dtype=int)  # Keep on CPU for random generation
        offsets = np.empty((n, 2), dtype=int)  # Keep on CPU for random generation

        # Generate n random non-overlapping image strips
        max_tries = self.max_permutation_tries
        if self.randomize is False:
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()

        for i in range(n):
            # Limit tries to avoid infinite loop in pathological cases
            for current_try in range(max_tries):
                t1, t2 = random_state.randint(num_tiles, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)

                # Check for non-overlapping strips and abort the retry loop
                if t1 != t2 and (t1, t2) not in edges:
                    # Different, non-neighboring tiles -- always OK
                    break
                elif t1 == t2 and abs(o1 - o2) > w:
                    # Same tile OK if strips don't overlap within the image
                    break
                elif (t1, t2) in edges:
                    # Neighbors OK if either strip is entirely outside the
                    # expected overlap region (based on nominal positions)
                    its = self.intersection(t1, t2, cp.full(2, w))
                    ioff1, ioff2 = its.offsets[:, 0]
                    if (
                        its.shape[0] > its.shape[1]
                        or o1 < ioff1 - w or o1 > ioff1 + w
                        or o2 < ioff2 - w or o2 > ioff2 + w
                    ):
                        break
            else:
                # Retries exhausted. This should be very rare.
                warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
            pairs[i] = t1, t2
            offsets[i] = o1, o2

        # Errors come back as host floats; collect them on the host and upload
        # once rather than writing single elements to the device.
        errors_cpu = np.empty(n)
        for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
            if self.verbose and (i % 10 == 9 or i == n - 1):
                sys.stdout.write(f'\r quantifying alignment error {i + 1}/{n}')
                sys.stdout.flush()
            img1 = self.image_stack[t1][offset1:offset1+w, :]
            img2 = self.image_stack[t2][offset2:offset2+w, :]
            _, errors_cpu[i] = ashlar_register_gpu(img1, img2, upsample=self.permutation_upsample)
        if self.verbose:
            print()
        errors = cp.asarray(errors_cpu)
        self.errors_negative_sampled = errors
        self.max_error = float(cp.percentile(errors, self.alpha * 100))

    def register_all(self):
        """Register all neighboring tile pairs.

        Afterwards ``all_errors`` holds the raw errors, and every cached pair
        whose error exceeds ``max_error`` or whose shift exceeds
        ``max_shift_pixels`` has its error replaced by infinity.
        """
        n = self.neighbors_graph.size()
        for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
            if self.verbose:
                sys.stdout.write(f'\r aligning edge {i}/{n}')
                sys.stdout.flush()
            self.register_pair(t1, t2)
        if self.verbose:
            print()
        self.all_errors = cp.array([x[1] for x in self._cache.values()])

        # Set error values above the threshold to infinity
        for k, v in self._cache.items():
            shift_array = cp.array(v[0]) if not isinstance(v[0], cp.ndarray) else v[0]
            if v[1] > self.max_error or cp.any(cp.abs(shift_array) > self.max_shift_pixels):
                self._cache[k] = (v[0], cp.inf)

    def register_pair(self, t1, t2):
        """Return relative shift between images and the alignment error.

        Results are cached per unordered pair. The shift is negated when
        ``t1 > t2`` so that it always refers to the requested direction.
        Failures yield a zero shift with infinite error instead of raising.
        """
        key = tuple(sorted((t1, t2)))
        try:
            shift, error = self._cache[key]
        except KeyError:
            # Test a series of increasing overlap window sizes to help avoid
            # missing alignments when the stage position error is large relative
            # to the tile overlap. Simply using a large overlap in all cases
            # limits the maximum achievable correlation thus increasing the
            # error metric, leading to worse overall results. The window size
            # starts at the nominal size and doubles until it reaches
            # window_size_factor times the tile size. If the nominal overlap
            # is already that large, we only use that one size.
            try:
                smin = self.intersection(key[0], key[1]).shape
                smax = cp.round(self.tile_size * self.window_size_factor)
                sizes = [smin]
                while any(cp.array(sizes[-1]) < smax):
                    sizes.append(cp.array(sizes[-1]) * 2)

                # Try each window size and collect results
                results = []
                for s in sizes:
                    try:
                        result = self._register(key[0], key[1], s)
                        if result is not None:
                            results.append(result)
                    except Exception as e:
                        logger.debug("window size %s failed for tiles %s: %s", s, key, e)
                        if self.verbose:
                            print(f" window size {s} failed: {e}")
                        continue

                if not results:
                    # All window sizes failed, return large error
                    shift = cp.array([0.0, 0.0])
                    error = cp.inf
                else:
                    # Use the shift from the window size that gave the lowest error
                    shift, _ = min(results, key=lambda r: r[1])
                    # Extract the images from the nominal overlap window but with the
                    # shift applied to the second tile's position, and compute the error
                    # metric on these images. This should be even lower than the error
                    # computed above.
                    try:
                        _, o1, o2 = self.overlap(key[0], key[1], shift=shift)
                        error = ashlar_nccw_no_preprocessing_gpu(o1, o2)
                    except Exception as e:
                        logger.debug("final error computation failed for tiles %s: %s", key, e)
                        if self.verbose:
                            print(f" final error computation failed: {e}")
                        error = cp.inf

            except Exception as e:
                logger.debug("registration failed for tiles %s: %s", key, e)
                if self.verbose:
                    print(f" registration failed for tiles {key}: {e}")
                shift = cp.array([0.0, 0.0])
                error = cp.inf

            self._cache[key] = (shift, error)
        if t1 > t2:
            shift = -shift
        # Return copy of shift to prevent corruption of cached values
        return shift.copy(), error

    def _register(self, t1, t2, min_size=0):
        """Register a single tile pair with given minimum size.

        Returns:
            ``(shift, error)`` with ``shift`` as a NumPy array, or ``None``
            when the overlap is empty or anything goes wrong.
        """
        try:
            its, img1, img2 = self.overlap(t1, t2, min_size)

            # Validate that we got valid images
            if img1.size == 0 or img2.size == 0:
                if self.verbose:
                    print(f" empty images for tiles {t1}, {t2} with min_size {min_size}")
                return None

            # Account for padding, flipping the sign depending on the direction
            # between the tiles
            p1, p2 = self.positions[[t1, t2]]
            sx = 1 if p1[1] >= p2[1] else -1
            sy = 1 if p1[0] >= p2[0] else -1
            padding = cp.array(its.padding) * cp.array([sy, sx])
            shift, error = ashlar_register_gpu(img1, img2, upsample=self.upsample_factor)
            shift = cp.array(shift) + padding
            return shift.get(), error
        except Exception as e:
            logger.debug("_register failed for tiles %s, %s: %s", t1, t2, e)
            if self.verbose:
                print(f" _register failed for tiles {t1}, {t2}: {e}")
            return None

    def intersection(self, t1, t2, min_size=0, shift=None):
        """Calculate intersection region between two tiles.

        ``shift``, when given, is added to the second tile's position first.
        """
        corners1 = self.positions[[t1, t2]].copy()
        if shift is not None:
            if not isinstance(shift, cp.ndarray):
                shift = cp.array(shift)
            corners1[1] += shift
        corners2 = corners1 + self.tile_size
        return IntersectionGPU(corners1, corners2, min_size)

    def crop(self, tile_id, offset, shape):
        """Crop image from tile at given offset and shape."""
        img = self.image_stack[tile_id]
        return ashlar_crop_gpu(img, offset, shape)

    def overlap(self, t1, t2, min_size=0, shift=None):
        """Extract overlapping regions between two tiles.

        Returns:
            ``(intersection, img1, img2)``.

        Raises:
            ValueError: if the intersection has a non-positive dimension.
        """
        its = self.intersection(t1, t2, min_size, shift)

        # Validate intersection shape before cropping
        if cp.any(its.shape <= 0):
            raise ValueError(f"Invalid intersection shape {its.shape} for tiles {t1}, {t2}")

        img1 = self.crop(t1, its.offsets[0], its.shape)
        img2 = self.crop(t2, its.offsets[1], its.shape)
        return its, img1, img2

    def build_spanning_tree(self):
        """Build the spanning tree over all pairs with a finite error.

        The GPU Boruvka implementation is tried first. If it raises, the tree
        is built with NetworkX from shortest paths out of each component's
        center instead.
        """
        # Import the Boruvka MST implementation
        from openhcs.processing.backends.pos_gen.mist.boruvka_mst import build_mst_gpu_boruvka

        # Convert cache to Boruvka format
        valid_edges = [(t1, t2, shift, error) for (t1, t2), (shift, error) in self._cache.items() if cp.isfinite(error)]

        if len(valid_edges) == 0:
            # No valid edges - create empty graph with all nodes
            self.spanning_tree = nx.Graph()
            self.spanning_tree.add_nodes_from(range(len(self.positions)))
            return

        # Prepare arrays for Boruvka MST
        connection_from = cp.array([t1 for t1, t2, shift, error in valid_edges], dtype=cp.int32)
        connection_to = cp.array([t2 for t1, t2, shift, error in valid_edges], dtype=cp.int32)
        connection_dx = cp.array([shift[1] for t1, t2, shift, error in valid_edges], dtype=cp.float32)  # x shift
        connection_dy = cp.array([shift[0] for t1, t2, shift, error in valid_edges], dtype=cp.float32)  # y shift
        # Use negative error as quality (higher quality = lower error)
        connection_quality = cp.array([-error for t1, t2, shift, error in valid_edges], dtype=cp.float32)

        num_nodes = len(self.positions)

        try:
            # Run GPU Boruvka MST
            mst_result = build_mst_gpu_boruvka(
                connection_from, connection_to, connection_dx, connection_dy,
                connection_quality, num_nodes
            )

            # Convert back to NetworkX format for compatibility with rest of algorithm
            spanning_tree = nx.Graph()
            spanning_tree.add_nodes_from(range(num_nodes))

            for edge in mst_result['edges']:
                # Coerce to plain Python numbers so the edge endpoints are the
                # same hashable node ids as the ints added above, whatever
                # scalar type the MST helper returns.
                t1, t2 = int(edge['from']), int(edge['to'])
                # Reconstruct error from quality
                error = -float(edge['quality']) if 'quality' in edge else 0.0
                spanning_tree.add_edge(t1, t2, weight=error)

            self.spanning_tree = spanning_tree

        except Exception as e:
            # Fallback to NetworkX if Boruvka fails
            logger.warning("Boruvka MST failed, falling back to NetworkX: %s", e)
            g = nx.Graph()
            g.add_nodes_from(self.neighbors_graph)
            g.add_weighted_edges_from(
                (t1, t2, error)
                for (t1, t2), (_, error) in self._cache.items()
                if cp.isfinite(error)
            )
            spanning_tree = nx.Graph()
            spanning_tree.add_nodes_from(g)
            for c in nx.connected_components(g):
                cc = g.subgraph(c)
                center = nx.center(cc)[0]
                paths = nx.single_source_dijkstra_path(cc, center).values()
                for path in paths:
                    nx.add_path(spanning_tree, path)
            self.spanning_tree = spanning_tree

    def calculate_positions(self):
        """Calculate final positions from spanning tree.

        Within each connected component the center tile keeps its position and
        pairwise shifts are accumulated outward along the tree edges.

        Raises:
            NotImplementedError: if the spanning tree has no nodes.
        """
        shifts = {}
        for c in nx.connected_components(self.spanning_tree):
            cc = self.spanning_tree.subgraph(c)
            center = nx.center(cc)[0]
            shifts[center] = cp.array([0, 0])
            for edge in nx.traversal.bfs_edges(cc, center):
                source, dest = edge
                if source not in shifts:
                    source, dest = dest, source
                shift = self.register_pair(source, dest)[0]
                shifts[dest] = shifts[source] + cp.array(shift)
        if shifts:
            self.shifts = cp.array([s for _, s in sorted(shifts.items())])
            self.final_positions = self.positions + self.shifts
        else:
            # TODO: fill in shifts and positions with 0x2 arrays
            raise NotImplementedError("No images")

    def fit_model(self):
        """Fit linear model to handle disconnected components.

        The model maps nominal to final positions of the largest component.
        Smaller components are moved so their centroids match its predictions,
        and everything is then translated so the minimum position is (0, 0).
        """
        components = sorted(
            nx.connected_components(self.spanning_tree),
            key=len, reverse=True
        )
        # Fit LR model on positions of largest connected component
        cc0 = list(components[0])
        self.lr = sklearn.linear_model.LinearRegression()

        # Convert to CPU for sklearn operations
        positions_cpu = cp.asnumpy(self.positions[cc0])
        final_positions_cpu = cp.asnumpy(self.final_positions[cc0])
        self.lr.fit(positions_cpu, final_positions_cpu)

        # Fix up degenerate transform matrix. This happens when the spanning
        # tree is completely edgeless or cc0's metadata positions fall in a
        # straight line. In this case we fall back to the identity transform.
        if np.linalg.det(self.lr.coef_) < 1e-3:
            warn_data(
                "Could not align enough edges, proceeding anyway with original"
                " stage positions."
            )
            self.lr.coef_ = np.diag(np.ones(2))
            self.lr.intercept_ = np.zeros(2)

        # Adjust position of remaining components so their centroids match
        # the predictions of the model
        for cc in components[1:]:
            nodes = list(cc)
            centroid_m = cp.mean(self.positions[nodes], axis=0)
            centroid_f = cp.mean(self.final_positions[nodes], axis=0)

            # Convert to CPU for prediction, then back to GPU
            centroid_m_cpu = cp.asnumpy(centroid_m).reshape(1, -1)
            shift_cpu = self.lr.predict(centroid_m_cpu)[0] - cp.asnumpy(centroid_f)
            shift = cp.array(shift_cpu)

            self.final_positions[nodes] += shift

        # Adjust positions and model intercept to put origin at 0,0
        self.origin = cp.min(self.final_positions, axis=0)
        self.final_positions -= self.origin
        self.lr.intercept_ -= cp.asnumpy(self.origin)
def _calculate_initial_positions_gpu(image_stack, grid_dims: tuple, overlap_ratio: float):
    """Calculate initial grid positions based on overlap ratio (GPU version).

    Tiles are assumed to be laid out row by row: tile ``i`` sits in row
    ``i // grid_cols`` and column ``i % grid_cols``. Neighboring tiles are
    spaced by the tile size times ``1 - overlap_ratio``.

    Args:
        image_stack: Stack of tiles; only ``len()`` and ``shape[1:3]`` are used,
            so NumPy and CuPy arrays are handled identically.
        grid_dims: ``(grid_rows, grid_cols)``; only the column count affects
            the result.
        overlap_ratio: Fractional overlap between adjacent tiles.

    Returns:
        CuPy float64 array with one ``[y, x]`` row per tile.
    """
    _, grid_cols = grid_dims
    tile_height, tile_width = image_stack.shape[1:3]
    spacing_factor = 1.0 - overlap_ratio

    positions = []
    for tile_idx in range(len(image_stack)):
        row, col = divmod(tile_idx, grid_cols)
        positions.append([row * tile_height * spacing_factor,
                          col * tile_width * spacing_factor])

    return cp.array(positions, dtype=cp.float64)
def _convert_ashlar_positions_to_openhcs_gpu(ashlar_positions) -> List[Tuple[float, float]]:
    """Convert ``(y, x)`` position rows into a list of ``(x, y)`` float tuples.

    CuPy input is copied to the host first; other sequences are used as given.
    """
    if isinstance(ashlar_positions, cp.ndarray):
        ashlar_positions = cp.asnumpy(ashlar_positions)

    # Swap the axis order: rows come in as (y, x), OpenHCS uses (x, y).
    return [(float(x), float(y)) for y, x in ashlar_positions]
798@special_inputs("grid_dimensions")
799@special_outputs("positions")
800@cupy_func
801def ashlar_compute_tile_positions_gpu(
802 image_stack,
803 grid_dimensions: Tuple[int, int],
804 overlap_ratio: float = 0.1,
805 max_shift: float = 15.0,
806 stitch_alpha: float = 0.01,
807 max_error: float = None,
808 randomize: bool = False,
809 verbose: bool = False,
810 upsample_factor: int = 10,
811 permutation_upsample: int = 1,
812 permutation_samples: int = 1000,
813 min_permutation_samples: int = 10,
814 max_permutation_tries: int = 100,
815 window_size_factor: float = 0.1,
816 **kwargs
817) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
818 """
819 Compute tile positions using the Ashlar algorithm on GPU - matches CPU version.
821 This function implements the Ashlar edge-based stitching algorithm using GPU acceleration.
822 It performs position calculation with minimal preprocessing (windowing only, no whitening)
823 to match the CPU version behavior.
825 Args:
826 image_stack: 3D numpy/cupy array of shape (num_tiles, height, width) containing preprocessed
827 grayscale images. Each slice [i] should be a single-channel 2D image ready
828 for correlation analysis. No further preprocessing will be applied.
830 grid_dimensions: Tuple of (grid_rows, grid_cols) specifying the logical arrangement of
831 tiles. For example, (2, 3) means 2 rows and 3 columns of tiles, for a
832 total of 6 tiles. Must match the number of images in image_stack.
834 overlap_ratio: Expected fractional overlap between adjacent tiles (0.0-1.0). Default 0.1
835 means 10% overlap. This is used to calculate initial grid positions and
836 should match the actual overlap in your microscopy data. Typical values:
837 - 0.05-0.15 for well-controlled microscopes
838 - 0.15-0.25 for less precise stages
840 max_shift: Maximum allowed shift correction in micrometers. Default 15.0. This limits
841 how far tiles can be moved from their initial grid positions during alignment.
842 Should be set based on your microscope's stage accuracy:
843 - 5-15 μm for high-precision stages
844 - 15-50 μm for standard stages
845 - 50+ μm for low-precision or manual stages
847 stitch_alpha: Alpha value for statistical error threshold computation (0.0-1.0). Default
848 0.01 means 1% false positive rate. Lower values are stricter and reject more
849 alignments, higher values are more permissive. This controls the trade-off
850 between alignment quality and success rate:
851 - 0.001-0.01: Very strict, high quality alignments only
852 - 0.01-0.05: Balanced (recommended for most data)
853 - 0.05-0.1: Permissive, accepts lower quality alignments
855 max_error: Explicit error threshold for rejecting alignments (None = auto-compute).
856 When None (default), the threshold is computed automatically using permutation
857 testing. Set to a specific value to override automatic computation. Higher
858 values accept more alignments, lower values are stricter.
860 randomize: Whether to use random seed for permutation testing (bool). Default False uses
861 a fixed seed for reproducible results. Set True for different random sampling
862 in each run. Generally should be False for consistent results.
864 verbose: Enable detailed progress logging (bool). Default False. When True, prints
865 progress information including permutation testing, edge alignment, and
866 spanning tree construction. Useful for debugging and monitoring progress
867 on large datasets.
869 upsample_factor: Sub-pixel accuracy factor for phase cross correlation (int). Default 10.
870 Higher values provide better sub-pixel accuracy but increase computation time.
871 Range: 1-100+. Values of 10-50 are typical for high-accuracy stitching.
872 - 1: Pixel-level accuracy (fastest)
873 - 10: 0.1 pixel accuracy (balanced)
874 - 50: 0.02 pixel accuracy (high precision)
876 permutation_upsample: Upsample factor for permutation testing (int). Default 1.
877 Lower than upsample_factor for speed during threshold computation.
878 Usually kept at 1 since permutation testing doesn't need sub-pixel accuracy.
880 permutation_samples: Number of random samples for error threshold computation (int). Default 1000.
881 Higher values give more accurate thresholds but slower computation.
882 Automatically reduced for small datasets to avoid infinite loops.
884 min_permutation_samples: Minimum permutation samples for small datasets (int). Default 10.
885 When there are few non-overlapping pairs, this sets the minimum
886 number of samples to ensure statistical validity.
888 max_permutation_tries: Maximum attempts to find non-overlapping strips (int). Default 100.
889 Prevents infinite loops in pathological cases where valid strips
890 are hard to find. Rarely needs adjustment.
892 window_size_factor: Fraction of tile size for maximum window size (float). Default 0.1.
893 Controls the largest overlap window tested during progressive sizing.
894 Larger values allow detection of bigger stage errors but may reduce
895 correlation quality. Range: 0.05-0.2 typical.
897 filter_sigma: Whitening filter sigma for preprocessing (float). Default 0.
898 Controls the whitening filter applied before correlation:
899 - 0: Pure Laplacian filter (high-pass, matches original Ashlar)
900 - >0: Gaussian-Laplacian (LoG) filter with specified sigma
901 - Typical values: 0-2.0 for most microscopy data
903 **kwargs: Additional parameters (ignored). Allows compatibility with other stitching
904 algorithms that may have different parameter sets.
906 Returns:
907 Tuple of (image_stack, positions) where:
908 - image_stack: The original input image array (unchanged)
909 - positions: List of (x, y) position tuples in OpenHCS format, one per tile.
910 Positions are in pixel coordinates with (0, 0) at the top-left.
911 The positions represent the optimal tile placement after Ashlar
912 alignment, accounting for stage errors and image correlation.
914 Raises:
915 Exception: If the Ashlar algorithm fails (e.g., insufficient overlap, correlation
916 errors), the function automatically falls back to grid-based positioning
917 using the specified overlap_ratio.
919 Notes:
920 - This implementation contains the complete Ashlar algorithm including whitening
921 filter preprocessing, permutation testing, progressive window sizing, minimum
922 spanning tree construction, and linear model fitting for disconnected components.
923 - The correlation functions are identical to original Ashlar including proper
924 whitening/filtering preprocessing as specified by filter_sigma parameter.
925 - For best results, ensure your image_stack contains single-channel grayscale
926 images. The whitening filter will be applied automatically during correlation.
927 """
928 grid_rows, grid_cols = grid_dimensions
930 if verbose:
931 logger.info(f"Ashlar GPU: Processing {grid_rows}x{grid_cols} grid with {len(image_stack)} tiles")
933 try:
934 # Convert to CuPy array if needed
935 if not isinstance(image_stack, cp.ndarray):
936 image_stack_gpu = cp.asarray(image_stack)
937 else:
938 image_stack_gpu = image_stack
940 # Calculate initial grid positions
941 initial_positions = _calculate_initial_positions_gpu(image_stack_gpu, grid_dimensions, overlap_ratio)
942 tile_size = cp.array(image_stack_gpu.shape[1:3]) # (height, width)
944 # Create and run ArrayEdgeAlignerGPU with complete Ashlar algorithm
945 logger.info("Running complete Ashlar edge-based stitching algorithm on GPU")
946 aligner = ArrayEdgeAlignerGPU(
947 image_stack=image_stack_gpu,
948 positions=initial_positions,
949 tile_size=tile_size,
950 pixel_size=1.0, # Assume 1 micrometer per pixel if not specified
951 max_shift=max_shift,
952 alpha=stitch_alpha,
953 max_error=max_error,
954 randomize=randomize,
955 verbose=verbose,
956 upsample_factor=upsample_factor,
957 permutation_upsample=permutation_upsample,
958 permutation_samples=permutation_samples,
959 min_permutation_samples=min_permutation_samples,
960 max_permutation_tries=max_permutation_tries,
961 window_size_factor=window_size_factor
962 )
964 # Run the complete algorithm
965 aligner.run()
967 # Convert to OpenHCS format
968 positions = _convert_ashlar_positions_to_openhcs_gpu(aligner.final_positions)
970 # Convert result back to original format (CPU if input was CPU)
971 if not isinstance(image_stack, cp.ndarray):
972 result_image_stack = cp.asnumpy(image_stack_gpu)
973 else:
974 result_image_stack = image_stack_gpu
976 logger.info("Ashlar GPU algorithm completed successfully")
978 except Exception as e:
979 logger.error(f"Ashlar GPU algorithm failed: {e}")
980 # Fallback to grid positions if Ashlar fails
981 logger.warning("Falling back to grid-based positioning")
982 positions = []
984 # Use original image_stack for fallback dimensions
985 if isinstance(image_stack, cp.ndarray):
986 tile_height, tile_width = image_stack.shape[1:3]
987 else:
988 tile_height, tile_width = image_stack.shape[1:3]
990 spacing_factor = 1.0 - overlap_ratio
992 for tile_idx in range(len(image_stack)):
993 r = tile_idx // grid_cols
994 c = tile_idx % grid_cols
995 x_pos = c * tile_width * spacing_factor
996 y_pos = r * tile_height * spacing_factor
997 positions.append((float(x_pos), float(y_pos)))
999 # Set result_image_stack for fallback case
1000 if not isinstance(image_stack, cp.ndarray):
1001 result_image_stack = image_stack
1002 else:
1003 result_image_stack = image_stack
1005 logger.info(f"Ashlar GPU: Completed processing {len(positions)} tile positions")
1007 return result_image_stack, positions
def materialize_ashlar_gpu_positions(data: List[Tuple[float, float]], path: str, filemanager) -> str:
    """Materialize Ashlar GPU tile positions as scientific CSV with grid metadata.

    Builds a table with one row per tile (position, tile id, estimated grid
    coordinates, spacing and algorithm metadata), serializes it to CSV text
    and passes it to ``filemanager.save`` with the ``"disk"`` backend.

    Args:
        data: (x, y) tile positions, one tuple per tile, in tile order.
        path: Target path supplied by the caller. A trailing ``.pkl``
            extension is swapped for ``_ashlar_positions_gpu.csv``.
        filemanager: Object exposing ``save(content, path, backend)``.

    Returns:
        The path under which the CSV content was saved.

    Note:
        Grid dimensions are estimated by counting distinct x and y values, so
        they are only meaningful when tiles in the same column/row share
        exactly the same coordinate.
        NOTE(review): positions refined to sub-pixel precision will presumably
        all be distinct, which would inflate the estimate - confirm whether
        any consumer relies on the grid columns.
    """
    pkl_ext = '.pkl'
    csv_suffix = '_ashlar_positions_gpu.csv'
    if path.endswith(pkl_ext):
        # Swap only the trailing extension; str.replace would also rewrite
        # any earlier ".pkl" occurrence (e.g. inside a directory name).
        csv_path = path[:-len(pkl_ext)] + csv_suffix
    else:
        # No trailing extension: keep the historical substitution behaviour.
        csv_path = path.replace(pkl_ext, csv_suffix)

    df = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
    df['tile_id'] = range(len(df))

    # Estimate grid dimensions from the number of distinct coordinates.
    unique_x = sorted(df['x_position_um'].unique())
    unique_y = sorted(df['y_position_um'].unique())
    grid_cols = len(unique_x)
    grid_rows = len(unique_y)

    # Tiles are assumed to be listed row by row, so the row/column follow
    # directly from the tile index and the estimated column count.
    df['grid_row'] = df.index // grid_cols
    df['grid_col'] = df.index % grid_cols

    # Spacing is the gap between the two smallest distinct coordinates,
    # or 0 when there is only a single column/row.
    df['x_spacing_um'] = unique_x[1] - unique_x[0] if grid_cols > 1 else 0
    df['y_spacing_um'] = unique_y[1] - unique_y[0] if grid_rows > 1 else 0

    # Constant metadata columns.
    df['algorithm'] = 'ashlar_gpu'
    df['grid_dimensions'] = f"{grid_rows}x{grid_cols}"

    csv_content = df.to_csv(index=False)
    filemanager.save(csv_content, csv_path, "disk")
    return csv_path