Coverage for openhcs/processing/backends/pos_gen/ashlar_main_cpu.py: 79.1%
402 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2OpenHCS Interface for Ashlar CPU Stitching Algorithm
4Array-based EdgeAligner implementation that works directly with numpy arrays
5instead of file-based readers. This is the complete Ashlar algorithm modified
6to accept arrays directly.
7"""
8from __future__ import annotations
9import logging
10import sys
11from typing import TYPE_CHECKING, Tuple, List
12import numpy as np
13import networkx as nx
14import scipy.spatial.distance
15import sklearn.linear_model
16import pandas as pd
18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs
19from openhcs.core.memory.decorators import numpy as numpy_func
20from openhcs.core.utils import optional_import
22import warnings
24if TYPE_CHECKING: 24 ↛ 25line 24 didn't jump to line 25 because the condition on line 24 was never true
25 pass
27logger = logging.getLogger(__name__)
class DataWarning(Warning):
    """Warning category for problems found in user-provided image data."""
def warn_data(message):
    """Emit ``message`` through the ``warnings`` module as a ``DataWarning``."""
    warnings.warn(message, category=DataWarning)
class Intersection:
    """Intersection region between two tiles (adapted from Ashlar).

    After construction the instance exposes:

    * ``shape`` - integer (rows, cols) size of the overlap window.
    * ``padding`` - how much the window was enlarged beyond the nominal
      overlap in order to honour ``min_size`` (computed before the final
      boundary clamping).
    * ``offsets`` - integer array of shape (2, 2); row ``i`` is the
      (row, col) offset of the window inside tile ``i``.
    * ``offset_diff_frac`` - sub-pixel remainder of the difference between
      the two tiles' unrounded offsets.
    """

    def __init__(self, corners1, corners2, min_size=0):
        """Compute the intersection of two tiles.

        Args:
            corners1: Array of shape (2, 2); row ``i`` is the upper-left
                corner of tile ``i``.
            corners2: Array of shape (2, 2); row ``i`` is the lower-right
                corner of tile ``i``.
            min_size: Minimum window size, either a scalar applied to both
                axes or a length-2 sequence.
        """
        if np.isscalar(min_size):
            min_size = np.repeat(min_size, 2)
        # Accept plain lists/tuples as well; arrays pass through unchanged.
        self._calculate(corners1, corners2, np.asarray(min_size))

    def _calculate(self, corners1, corners2, min_size):
        """Calculate intersection parameters with robust boundary validation."""
        # The requested minimum can never exceed the tile itself and is at
        # least one pixel.
        max_shape = (corners2 - corners1).max(axis=0)
        min_size = min_size.clip(1, max_shape)

        # Nominal overlap: from the larger upper-left corner to the smaller
        # lower-right corner.
        position = corners1.max(axis=0)
        initial_shape = np.floor(corners2.min(axis=0) - position).astype(int)
        clipped_shape = np.maximum(initial_shape, min_size)
        self.shape = np.ceil(clipped_shape).astype(int)
        self.padding = self.shape - initial_shape

        # Offsets of the (possibly padded) window inside each tile, never
        # negative.
        raw_offsets = np.maximum(position - corners1 - self.padding, 0)

        # Keep offset + shape inside each tile. Row i of ``tile_sizes`` and
        # ``raw_offsets`` belongs to tile i.
        tile_sizes = corners2 - corners1
        for i in range(2):
            max_offset = tile_sizes[i] - self.shape
            raw_offsets[i] = np.minimum(raw_offsets[i], np.maximum(max_offset, 0))

            # Shrink the window if the tile cannot accommodate it.
            available_space = tile_sizes[i] - raw_offsets[i]
            self.shape = np.minimum(self.shape, available_space.astype(int))

        # The window is always at least one pixel along each axis.
        self.shape = np.maximum(self.shape, 1)

        self.offsets = raw_offsets.astype(int)

        # The sub-pixel remainder must come from the unrounded offsets; the
        # integer ``self.offsets`` would always yield zero here.
        offset_diff = raw_offsets[1] - raw_offsets[0]
        self.offset_diff_frac = offset_diff - offset_diff.round()
def _get_window(shape):
    """Return a 2-D float32 Hann window of size ``shape[0]`` x ``shape[1]``.

    Mirrors Ashlar's ``utils.get_window``: the window is the outer product of
    a vertical and a horizontal 1-D Hann window.
    """
    vertical = np.hanning(shape[0]).astype(np.float32)
    horizontal = np.hanning(shape[1]).astype(np.float32)
    # Column-vector times row-vector broadcasting is the outer product.
    return vertical[:, np.newaxis] * horizontal[np.newaxis, :]
def ashlar_register_no_preprocessing(img1, img2, upsample=10):
    """
    Robust Ashlar register function with comprehensive input validation.

    This is based on ashlar.utils.register() but adds validation to handle
    edge cases that can occur with real microscopy data. No whitening or
    filtering is applied; the images are only cast to float32 and multiplied
    by a Hann window before phase correlation.

    Args:
        img1: 2-D reference image.
        img2: 2-D image to register against ``img1``; must have the same
            shape as ``img1``.
        upsample: ``upsample_factor`` handed to
            ``skimage.registration.phase_cross_correlation``.

    Returns:
        Tuple ``(shift, error)``. ``shift`` is always a length-2 float numpy
        array (row, col). ``error`` is ``-log`` of the normalized correlation
        at that shift. Invalid inputs, a failed phase correlation, or all-zero
        correlations yield ``(array([0., 0.]), inf)``.
    """
    # Input validation: anything unusable is reported as "no alignment".
    if img1 is None or img2 is None:
        return np.array([0.0, 0.0]), np.inf

    if img1.size == 0 or img2.size == 0:
        return np.array([0.0, 0.0]), np.inf

    if img1.shape != img2.shape:
        return np.array([0.0, 0.0]), np.inf

    if len(img1.shape) != 2:
        return np.array([0.0, 0.0]), np.inf

    if img1.shape[0] < 1 or img1.shape[1] < 1:
        return np.array([0.0, 0.0]), np.inf

    # Imported lazily, and only once the inputs are known to be usable.
    import itertools
    import scipy.ndimage
    import skimage.registration

    # Convert to float32 (equivalent to what whiten() does) - match GPU version
    img1w = img1.astype(np.float32)
    img2w = img2.astype(np.float32)

    # Apply windowing function (from original Ashlar)
    img1w = img1w * _get_window(img1w.shape)
    img2w = img2w * _get_window(img2w.shape)

    # Use skimage's phase cross correlation with error handling
    try:
        shift = skimage.registration.phase_cross_correlation(
            img1w,
            img2w,
            upsample_factor=upsample,
            normalization=None
        )[0]
    except Exception as e:
        # If phase correlation fails, return large error
        logger.error(f"Ashlar CPU: PHASE CORRELATION FAILED - Exception: {e}")
        logger.error(" Returning infinite error")
        return np.array([0.0, 0.0]), np.inf

    # At this point we may have a shift in the wrong quadrant since the FFT
    # assumes the signal is periodic. We test all four possibilities and return
    # the shift that gives the highest direct correlation (sum of products).
    shape = np.array(img1.shape)
    shift_pos = (shift + shape) % shape
    shift_neg = shift_pos - shape
    shifts = list(itertools.product(*zip(shift_pos, shift_neg)))
    correlations = []
    for s in shifts:
        try:
            shifted_img = scipy.ndimage.shift(img2w, s, order=0)
            corr = np.abs(np.sum(img1w * shifted_img))
            correlations.append(corr)
        except Exception:
            # A candidate that cannot be evaluated simply never wins.
            correlations.append(0.0)

    if not correlations or max(correlations) == 0:
        logger.warning("Ashlar CPU: NO VALID CORRELATIONS - All correlations failed or zero")
        return np.array([0.0, 0.0]), np.inf

    idx = np.argmax(correlations)
    # Return an ndarray on the success path too, matching every failure path
    # above (the candidates produced by itertools.product are tuples).
    shift = np.asarray(shifts[idx], dtype=float)
    correlation = correlations[idx]
    total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w)
    if correlation > 0 and total_amplitude > 0:
        error = -np.log(correlation / total_amplitude)
    else:
        # Reached e.g. when the correlation is NaN.
        error = np.inf

    # Results above the threshold are logged as warnings, the rest at INFO.
    if error > 1.0:  # High error threshold for Ashlar
        logger.warning(f"Ashlar CPU: HIGH CORRELATION ERROR - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
        logger.warning(" This indicates poor overlap or image quality between tiles")
    else:
        logger.info(f"Ashlar CPU: Correlation - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")

    return shift, error
def ashlar_nccw_no_preprocessing(img1, img2):
    """
    Normalized cross-correlation error between two equally shaped 2-D images.

    Derived from ashlar.utils.nccw() with extra input validation and without
    any whitening: both images are only cast to float32. The result is
    ``-log(|sum(img1 * img2)| / (||img1|| * ||img2||))``. ``np.inf`` is
    returned for unusable inputs or a non-positive correlation/amplitude, and
    ``100.0`` when the correlation exceeds the amplitude product by 1e-3 or
    more.
    """
    # Guard clauses: anything that cannot be correlated scores infinitely bad.
    if img1 is None or img2 is None:
        return np.inf
    unusable = (
        img1.size == 0
        or img2.size == 0
        or img1.shape != img2.shape
        or len(img1.shape) != 2
        or img1.shape[0] < 1
        or img1.shape[1] < 1
    )
    if unusable:
        return np.inf

    # Convert to float32 (equivalent to what whiten() does) - match GPU version
    first = img1.astype(np.float32)
    second = img2.astype(np.float32)

    correlation = np.abs(np.sum(first * second))
    total_amplitude = np.linalg.norm(first) * np.linalg.norm(second)

    if not (correlation > 0 and total_amplitude > 0):
        logger.warning(f"Ashlar CPU: NCCW invalid correlation - correlation={correlation:.6f}, total_amplitude={total_amplitude:.6f}")
        error = np.inf
    else:
        diff = correlation - total_amplitude
        if diff <= 0:
            error = -np.log(correlation / total_amplitude)
        elif diff < 1e-3:
            # Rounding can push the correlation marginally above the
            # amplitude product when the images are (nearly) identical;
            # treat that as a perfect match.
            error = 0
        else:
            # Large but finite error instead of raising.
            logger.warning(f"Ashlar CPU: NCCW numerical precision issue - diff={diff:.6f}, using error=100.0")
            error = 100.0

    # High errors are reported as warnings, everything else at INFO level.
    if error > 10.0:
        logger.warning(f"Ashlar CPU: HIGH NCCW ERROR - Error={error:.4f}")
        logger.warning(f" This indicates poor image correlation between tiles")
    else:
        logger.info(f"Ashlar CPU: NCCW - Error={error:.4f}")

    return error
def ashlar_crop(img, offset, shape):
    """
    Crop a window from ``img`` while clamping it to the image bounds.

    Derived from ashlar.utils.crop() with added boundary handling.

    Args:
        img: Image array; only its first two axes are cropped.
        offset: Numpy array with the (row, col) start of the window; rounded
            to whole pixels and clamped into the image.
        shape: (rows, cols) size of the window; rounded and forced to be at
            least one pixel per axis.

    Returns:
        A slice of ``img`` (not a copy) covering the clamped window.

    Raises:
        ValueError: If ``img`` is None or has no elements.
    """
    if img is None or img.size == 0:
        raise ValueError("Cannot crop from empty or None image")

    img_height, img_width = img.shape[:2]

    # Whole-pixel window: non-negative start, extent of at least one pixel.
    first = np.maximum(offset.round().astype(int), 0)
    extent = np.maximum(np.round(shape).astype(int), 1)
    last = first + extent

    # Clamp both corners to the image boundaries.
    row0 = min(first[0], img_height - 1)
    col0 = min(first[1], img_width - 1)
    row1 = min(last[0], img_height)
    col1 = min(last[1], img_width)

    if row1 <= row0 or col1 <= col0:
        # Degenerate window: fall back to the single pixel at the start.
        return img[row0:row0 + 1, col0:col0 + 1]

    return img[row0:row1, col0:col1]
class ArrayEdgeAligner:
    """
    Array-based EdgeAligner that implements the complete Ashlar algorithm
    but works directly with numpy arrays instead of file readers.

    Construct the aligner and call ``run()``; afterwards ``final_positions``
    holds the aligned (row, col) position of every tile, shifted so that the
    minimum along each axis is zero.
    """

    def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
                 max_shift=15, alpha=0.01, max_error=None,
                 randomize=False, verbose=False):
        """
        Initialize array-based EdgeAligner for pure position calculation.

        Args:
            image_stack: 3D numpy array (num_tiles, height, width) - preprocessed grayscale
            positions: 2D array of tile positions (num_tiles, 2) in pixels
            tile_size: Array [height, width] of tile dimensions
            pixel_size: Pixel size in micrometers (for max_shift conversion)
            max_shift: Maximum allowed shift in micrometers
            alpha: Alpha value for error threshold (lower = stricter)
            max_error: Explicit error threshold (None = auto-compute)
            randomize: If False the permutation test uses a fixed seed (0);
                otherwise an unseeded random state is used
            verbose: Enable progress output on stdout
        """
        self.image_stack = image_stack
        self.positions = positions.astype(float)
        self.tile_size = np.array(tile_size)
        self.pixel_size = pixel_size
        self.max_shift = max_shift
        # Shift limit expressed in pixels; used when rejecting edges.
        self.max_shift_pixels = self.max_shift / self.pixel_size
        self.alpha = alpha
        self.max_error = max_error
        self.randomize = randomize
        self.verbose = verbose
        # Maps a sorted tile pair (t1, t2) to its (shift, error) result.
        self._cache = {}
        self.errors_negative_sampled = np.empty(0)

        # Build neighbors graph
        self.neighbors_graph = self._build_neighbors_graph()

    def _build_neighbors_graph(self):
        """Build graph of neighboring (overlapping) tiles.

        Two tiles are neighbors when the cityblock distance between their
        positions is positive and smaller than the largest tile dimension
        plus one. Every tile index is added as a node, so isolated tiles are
        present in the graph too.
        """
        pdist = scipy.spatial.distance.pdist(self.positions, metric='cityblock')
        sp = scipy.spatial.distance.squareform(pdist)
        max_distance = self.tile_size.max() + 1
        edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
        graph = nx.from_edgelist(edges)
        graph.add_nodes_from(range(len(self.positions)))
        return graph

    def run(self):
        """Run the complete Ashlar algorithm.

        The steps run in a fixed order because each one consumes state set by
        the previous ones (error threshold, registration cache, spanning
        tree, shifted positions).
        """
        self.check_overlaps()
        self.compute_threshold()
        self.register_all()
        self.build_spanning_tree()
        self.calculate_positions()
        self.fit_model()

    def check_overlaps(self):
        """Check if tiles actually overlap based on positions.

        Only issues ``DataWarning``s; alignment proceeds regardless.
        """
        # Per-edge overlap along each axis: tile size minus position distance.
        overlaps = np.array([
            self.tile_size - abs(self.positions[t1] - self.positions[t2])
            for t1, t2 in self.neighbors_graph.edges
        ])
        # An edge fails when the overlap is below one pixel on any axis.
        failures = np.any(overlaps < 1, axis=1) if len(overlaps) else []
        if len(failures) and all(failures):
            warn_data("No tiles overlap, attempting alignment anyway.")
        elif any(failures):
            warn_data("Some neighboring tiles have zero overlap.")

    def compute_threshold(self):
        """Compute error threshold using permutation testing.

        Registers randomly chosen image strips that should not overlap and
        sets ``max_error`` to the ``alpha * 100`` percentile of the resulting
        errors. An explicitly provided ``max_error`` is left untouched, and
        with at most one neighbor edge the threshold becomes ``inf``.
        """
        if self.max_error is not None:
            if self.verbose:
                print(" using explicit error threshold")
            return

        edges = self.neighbors_graph.edges
        num_tiles = len(self.image_stack)

        # If not enough tiles overlap to matter, skip this whole thing
        if len(edges) <= 1:
            self.max_error = np.inf
            return

        # Strip height: the largest, over all edges, of the smaller side of
        # the nominal intersection.
        widths = np.array([
            self.intersection(t1, t2).shape.min()
            for t1, t2 in edges
        ])
        w = widths.max()
        # Exclusive upper bound for the randomly drawn strip start rows.
        max_offset = self.tile_size[0] - w

        # Number of possible pairs minus number of actual neighbor pairs
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)

        # Reduce permutation count for small datasets
        n = 1000 if num_distant_pairs > 8 else (num_distant_pairs + 1) * 10
        pairs = np.empty((n, 2), dtype=int)
        offsets = np.empty((n, 2), dtype=int)

        # Generate n random non-overlapping image strips
        max_tries = 100
        if self.randomize is False:
            # Fixed seed so repeated runs produce the same threshold.
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()

        for i in range(n):
            # Limit tries to avoid infinite loop in pathological cases
            for current_try in range(max_tries):
                t1, t2 = random_state.randint(num_tiles, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)

                # Check for non-overlapping strips and abort the retry loop
                if t1 != t2 and (t1, t2) not in edges:
                    # Different, non-neighboring tiles -- always OK
                    break
                elif t1 == t2 and abs(o1 - o2) > w:
                    # Same tile OK if strips don't overlap within the image
                    break
                elif (t1, t2) in edges:
                    # Neighbors OK if either strip is entirely outside the
                    # expected overlap region (based on nominal positions)
                    its = self.intersection(t1, t2, np.repeat(w, 2))
                    ioff1, ioff2 = its.offsets[:, 0]
                    if (
                        its.shape[0] > its.shape[1]
                        or o1 < ioff1 - w or o1 > ioff1 + w
                        or o2 < ioff2 - w or o2 > ioff2 + w
                    ):
                        break
            else:
                # Retries exhausted. This should be very rare.
                # The last drawn pair is still used below.
                warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
            pairs[i] = t1, t2
            offsets[i] = o1, o2

        errors = np.empty(n)
        for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
            # if self.verbose and (i % 10 == 9 or i == n - 1):
            #     sys.stdout.write(f'\r quantifying alignment error {i + 1}/{n}')
            #     sys.stdout.flush()
            # Full-width horizontal strips of height w from each tile.
            img1 = self.image_stack[t1][offset1:offset1+w, :]
            img2 = self.image_stack[t2][offset2:offset2+w, :]
            # Only the error matters here; the shift is discarded.
            _, errors[i] = ashlar_register_no_preprocessing(img1, img2, upsample=1)
        # if self.verbose:
        #     print()
        self.errors_negative_sampled = errors
        self.max_error = np.percentile(errors, self.alpha * 100)

    def register_all(self):
        """Register all neighboring tile pairs.

        Stores every raw error in ``all_errors`` and then marks cached edges
        as unusable (error ``inf``) when their error exceeds ``max_error`` or
        any shift component exceeds ``max_shift_pixels``.
        """
        n = self.neighbors_graph.size()
        for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
            if self.verbose:
                sys.stdout.write(f'\r aligning edge {i}/{n}')
                sys.stdout.flush()
            self.register_pair(t1, t2)
        if self.verbose:
            print()
        # Snapshot of the errors before thresholding.
        self.all_errors = np.array([x[1] for x in self._cache.values()])

        # Set error values above the threshold to infinity
        for k, v in self._cache.items():
            if v[1] > self.max_error or any(np.abs(v[0]) > self.max_shift_pixels):
                self._cache[k] = (v[0], np.inf)

    def register_pair(self, t1, t2):
        """Return relative shift between images and the alignment error.

        Results are cached under the sorted tile pair. The returned shift is
        a copy and is negated when ``t1 > t2``.
        """
        key = tuple(sorted((t1, t2)))
        try:
            shift, error = self._cache[key]
        except KeyError:
            # Test a series of increasing overlap window sizes to help avoid
            # missing alignments when the stage position error is large relative
            # to the tile overlap. Simply using a large overlap in all cases
            # limits the maximum achievable correlation thus increasing the
            # error metric, leading to worse overall results. The window size
            # starts at the nominal size and doubles until it's at least 10% of
            # the tile size. If the nominal overlap is already 10% or greater,
            # we only use that one size.
            smin = self.intersection(key[0], key[1]).shape
            smax = np.round(self.tile_size * 0.1)
            sizes = [smin]
            while any(sizes[-1] < smax):
                sizes.append(sizes[-1] * 2)
            # Test each window size with validation
            results = []
            for s in sizes:
                try:
                    result = self._register(key[0], key[1], s)
                    results.append(result)
                except Exception:
                    # If this window size fails, use infinite error
                    results.append((np.array([0.0, 0.0]), np.inf))
            # Use the shift from the window size that gave the lowest error
            shift, _ = min(results, key=lambda r: r[1])
            # Extract the images from the nominal overlap window but with the
            # shift applied to the second tile's position, and compute the error
            # metric on these images. This should be even lower than the error
            # computed above.
            _, o1, o2 = self.overlap(key[0], key[1], shift=shift)
            error = ashlar_nccw_no_preprocessing(o1, o2)
            self._cache[key] = (shift, error)
        if t1 > t2:
            # The cache stores the shift for the sorted order.
            shift = -shift
        # Return copy of shift to prevent corruption of cached values
        return shift.copy(), error

    def _register(self, t1, t2, min_size=0):
        """Register a single tile pair with given minimum size.

        Returns ``(shift, error)`` where the shift already includes the
        padding that was added to reach ``min_size``.
        """
        its, img1, img2 = self.overlap(t1, t2, min_size)
        # Account for padding, flipping the sign depending on the direction
        # between the tiles
        p1, p2 = self.positions[[t1, t2]]
        sx = 1 if p1[1] >= p2[1] else -1
        sy = 1 if p1[0] >= p2[0] else -1
        padding = its.padding * [sy, sx]
        shift, error = ashlar_register_no_preprocessing(img1, img2)
        shift += padding
        return shift, error

    def intersection(self, t1, t2, min_size=0, shift=None):
        """Calculate intersection region between two tiles.

        When ``shift`` is given it is added to the second tile's position
        before intersecting. ``self.positions`` is not modified, because the
        list index below produces a copy.
        """
        corners1 = self.positions[[t1, t2]]
        if shift is not None:
            corners1[1] += shift
        corners2 = corners1 + self.tile_size
        return Intersection(corners1, corners2, min_size)

    def crop(self, tile_id, offset, shape):
        """Crop image from tile at given offset and shape."""
        img = self.image_stack[tile_id]
        return ashlar_crop(img, offset, shape)

    def overlap(self, t1, t2, min_size=0, shift=None):
        """Extract overlapping regions between two tiles.

        Returns the ``Intersection`` together with the crop taken from each
        tile.
        """
        its = self.intersection(t1, t2, min_size, shift)
        img1 = self.crop(t1, its.offsets[0], its.shape)
        img2 = self.crop(t2, its.offsets[1], its.shape)
        return its, img1, img2

    def build_spanning_tree(self):
        """Build the alignment spanning tree from registered edges.

        Only cached edges with finite error are used, weighted by that error.
        For each connected component the tree is the union of the Dijkstra
        shortest paths from the component's center node to all other nodes.
        """
        g = nx.Graph()
        g.add_nodes_from(self.neighbors_graph)
        g.add_weighted_edges_from(
            (t1, t2, error)
            for (t1, t2), (_, error) in self._cache.items()
            if np.isfinite(error)
        )
        spanning_tree = nx.Graph()
        spanning_tree.add_nodes_from(g)
        for c in nx.connected_components(g):
            cc = g.subgraph(c)
            center = nx.center(cc)[0]
            paths = nx.single_source_dijkstra_path(cc, center).values()
            for path in paths:
                nx.add_path(spanning_tree, path)
        self.spanning_tree = spanning_tree

    def calculate_positions(self):
        """Calculate final positions from spanning tree.

        Each component's center keeps a zero shift; shifts of the other tiles
        are accumulated along breadth-first edges from the center. Sets
        ``shifts`` (ordered by tile index) and ``final_positions``.

        Raises:
            NotImplementedError: If the spanning tree has no nodes.
        """
        shifts = {}
        for c in nx.connected_components(self.spanning_tree):
            cc = self.spanning_tree.subgraph(c)
            center = nx.center(cc)[0]
            shifts[center] = np.array([0, 0])
            for edge in nx.traversal.bfs_edges(cc, center):
                source, dest = edge
                if source not in shifts:
                    source, dest = dest, source
                shift = self.register_pair(source, dest)[0]
                shifts[dest] = shifts[source] + shift
        if shifts:
            self.shifts = np.array([s for _, s in sorted(shifts.items())])
            self.final_positions = self.positions + self.shifts
        else:
            # TODO: fill in shifts and positions with 0x2 arrays
            raise NotImplementedError("No images")

    def fit_model(self):
        """Fit linear model to handle disconnected components.

        A linear regression from nominal to aligned positions is fitted on
        the largest connected component and used to place the remaining
        components. Finally all positions are shifted so their minimum is at
        the origin.
        """
        components = sorted(
            nx.connected_components(self.spanning_tree),
            key=len, reverse=True
        )
        # Fit LR model on positions of largest connected component
        cc0 = list(components[0])
        self.lr = sklearn.linear_model.LinearRegression()
        self.lr.fit(self.positions[cc0], self.final_positions[cc0])

        # Fix up degenerate transform matrix. This happens when the spanning
        # tree is completely edgeless or cc0's metadata positions fall in a
        # straight line. In this case we fall back to the identity transform.
        if np.linalg.det(self.lr.coef_) < 1e-3:
            warn_data(
                "Could not align enough edges, proceeding anyway with original"
                " stage positions."
            )
            self.lr.coef_ = np.diag(np.ones(2))
            self.lr.intercept_ = np.zeros(2)

        # Adjust position of remaining components so their centroids match
        # the predictions of the model
        for cc in components[1:]:
            nodes = list(cc)
            centroid_m = np.mean(self.positions[nodes], axis=0)
            centroid_f = np.mean(self.final_positions[nodes], axis=0)
            shift = self.lr.predict([centroid_m])[0] - centroid_f
            self.final_positions[nodes] += shift

        # Adjust positions and model intercept to put origin at 0,0
        self.origin = self.final_positions.min(axis=0)
        self.final_positions -= self.origin
        self.lr.intercept_ -= self.origin
def _calculate_initial_positions(image_stack: np.ndarray, grid_dims: tuple, overlap_ratio: float) -> np.ndarray:
    """Return nominal (y, x) grid positions for every tile in the stack.

    Tiles are laid out row-major with ``grid_dims[1]`` columns; neighboring
    tiles are spaced by the tile size scaled by ``1 - overlap_ratio``.
    """
    _, columns = grid_dims
    height, width = image_stack.shape[1:3]
    step = 1.0 - overlap_ratio

    def place(index):
        # Row-major index -> (row, col) -> (y, x) in pixels.
        row, col = divmod(index, columns)
        return [row * height * step, col * width * step]

    return np.array([place(i) for i in range(len(image_stack))], dtype=float)
def _convert_ashlar_positions_to_openhcs(ashlar_positions: np.ndarray) -> List[Tuple[float, float]]:
    """Turn Ashlar (y, x) rows into a list of OpenHCS (x, y) float tuples."""
    # OpenHCS uses (x, y) ordering, so each row is swapped.
    return [(float(x), float(y)) for y, x in ashlar_positions]
@special_inputs("grid_dimensions")
@special_outputs("positions")
@numpy_func
def ashlar_compute_tile_positions_cpu(
    image_stack: np.ndarray,
    grid_dimensions: Tuple[int, int],
    overlap_ratio: float = 0.1,
    max_shift: float = 15.0,
    stitch_alpha: float = 0.01,
    max_error: float = None,
    randomize: bool = False,
    verbose: bool = False,
    upsample_factor: int = 10,
    permutation_upsample: int = 1,
    permutation_samples: int = 1000,
    min_permutation_samples: int = 10,
    max_permutation_tries: int = 100,
    window_size_factor: float = 0.1,
    **kwargs
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Compute tile positions using the complete Ashlar algorithm - pure position calculation only.

    This function runs the Ashlar edge-based stitching algorithm directly on
    preprocessed grayscale image arrays. It performs ONLY position calculation
    without any file I/O, channel selection, or image preprocessing.

    Args:
        image_stack: 3D numpy array of shape (num_tiles, height, width) containing
            preprocessed grayscale images. No further preprocessing is applied.

        grid_dimensions: Tuple of (grid_rows, grid_cols) specifying the logical
            arrangement of tiles in row-major order. For example, (2, 3) means
            2 rows and 3 columns for a total of 6 tiles. A mismatch with the
            number of tiles is logged as a warning; only ``grid_cols`` is used
            for the layout.

        overlap_ratio: Expected fractional overlap between adjacent tiles
            (0.0-1.0). Default 0.1 means 10% overlap. Used to calculate the
            initial grid positions and should match the actual overlap in the
            data.

        max_shift: Maximum allowed shift correction. Default 15.0. The aligner
            is created with ``pixel_size=1.0``, so this limit is effectively
            applied in pixels. Edges whose shift exceeds it on either axis are
            rejected.

        stitch_alpha: Alpha value for the statistical error threshold
            (0.0-1.0). Default 0.01. The threshold is the ``alpha * 100``
            percentile of the permutation-test errors, so lower values are
            stricter and reject more alignments.

        max_error: Explicit error threshold for rejecting alignments. When
            None (default), the threshold is computed by permutation testing.

        randomize: Default False uses a fixed seed for reproducible
            permutation testing. Set True for different random sampling in
            each run.

        verbose: Enable progress output from the aligner. Default False.

        upsample_factor: Currently NOT forwarded to the aligner, which
            registers edges with a fixed upsample factor of 10.

        permutation_upsample: Currently NOT forwarded; permutation testing
            uses an upsample factor of 1.

        permutation_samples: Currently NOT forwarded; the aligner uses 1000
            samples (fewer for small datasets).

        min_permutation_samples: Currently NOT forwarded; for small datasets
            the aligner uses ``(num_distant_pairs + 1) * 10`` samples.

        max_permutation_tries: Currently NOT forwarded; the aligner retries
            at most 100 times per sample.

        window_size_factor: Currently NOT forwarded; the aligner grows the
            overlap window up to 10% of the tile size.

        **kwargs: Additional parameters (ignored). Allows compatibility with
            other stitching algorithms that may have different parameter sets.

    Returns:
        Tuple of (image_stack, positions) where:
            - image_stack: The original input image array (unchanged)
            - positions: List of (x, y) position tuples in OpenHCS format, one
              per tile, in pixel coordinates.

    Notes:
        - Failures inside the Ashlar algorithm (e.g. insufficient overlap or
          correlation errors) are NOT raised. They are logged, and the
          function falls back to the nominal grid positions derived from
          ``overlap_ratio``.
        - ``grid_dimensions`` is unpacked before that protection starts, so a
          value that is not a 2-sequence raises immediately.
        - The correlation functions skip Ashlar's whitening/filtering,
          allowing OpenHCS to handle preprocessing in separate pipeline steps.
    """
    grid_rows, grid_cols = grid_dimensions
    num_tiles = len(image_stack)

    logger.info(f"Ashlar CPU: Processing {grid_rows}x{grid_cols} grid with {len(image_stack)} tiles")

    if grid_rows * grid_cols != num_tiles:
        # The layout below only depends on grid_cols, so this is not fatal,
        # but it almost always indicates a caller error worth surfacing.
        logger.warning(
            f"Ashlar CPU: grid {grid_rows}x{grid_cols} does not match {num_tiles} tiles"
        )

    try:
        # Calculate initial grid positions
        initial_positions = _calculate_initial_positions(image_stack, grid_dimensions, overlap_ratio)
        tile_size = np.array(image_stack.shape[1:3])  # (height, width)

        # Create and run ArrayEdgeAligner with complete Ashlar algorithm
        logger.info("Running complete Ashlar edge-based stitching algorithm")
        aligner = ArrayEdgeAligner(
            image_stack=image_stack,
            positions=initial_positions,
            tile_size=tile_size,
            pixel_size=1.0,  # Assume 1 micrometer per pixel if not specified
            max_shift=max_shift,
            alpha=stitch_alpha,
            max_error=max_error,
            randomize=randomize,
            verbose=verbose
        )

        # Run the complete algorithm
        aligner.run()

        # Convert to OpenHCS format
        positions = _convert_ashlar_positions_to_openhcs(aligner.final_positions)

        logger.info("Ashlar algorithm completed successfully")

    except Exception as e:
        # Deliberate best-effort boundary: keep the traceback in the log so
        # the cause is not lost, then fall back to the nominal grid.
        logger.error(f"Ashlar algorithm failed: {e}", exc_info=True)
        logger.warning("Falling back to grid-based positioning")
        # Reuse the same helpers as the main path so both layouts agree.
        positions = _convert_ashlar_positions_to_openhcs(
            _calculate_initial_positions(image_stack, grid_dimensions, overlap_ratio)
        )

    logger.info(f"Ashlar CPU: Completed processing {len(positions)} tile positions")

    return image_stack, positions
def materialize_ashlar_cpu_positions(data: List[Tuple[float, float]], path: str, filemanager) -> str:
    """Write Ashlar CPU tile positions as a CSV table and return its path.

    The CSV path is ``path`` with every ``.pkl`` replaced by
    ``_ashlar_positions.csv``. Grid size and spacing are estimated from the
    distinct x and y values in ``data``. The table is saved through
    ``filemanager.save(content, csv_path, "disk")``.
    """
    csv_path = path.replace('.pkl', '_ashlar_positions.csv')

    table = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
    table['tile_id'] = range(len(table))

    # Distinct coordinates give an estimate of the grid layout.
    xs = sorted(table['x_position_um'].unique())
    ys = sorted(table['y_position_um'].unique())
    n_cols, n_rows = len(xs), len(ys)

    # Row-major grid coordinates derived from the tile order.
    table['grid_row'] = table.index // n_cols
    table['grid_col'] = table.index % n_cols

    # Spacing is the gap between the two smallest distinct coordinates.
    table['x_spacing_um'] = xs[1] - xs[0] if n_cols > 1 else 0
    table['y_spacing_um'] = ys[1] - ys[0] if n_rows > 1 else 0

    # Constant metadata columns.
    table['algorithm'] = 'ashlar_cpu'
    table['grid_dimensions'] = f"{n_rows}x{n_cols}"

    filemanager.save(table.to_csv(index=False), csv_path, "disk")
    return csv_path