Coverage for openhcs/processing/backends/pos_gen/ashlar_main_cpu.py: 77.5%
401 statements
coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2OpenHCS Interface for Ashlar CPU Stitching Algorithm
4Array-based EdgeAligner implementation that works directly with numpy arrays
5instead of file-based readers. This is the complete Ashlar algorithm modified
6to accept arrays directly.
7"""
8from __future__ import annotations
9import logging
10import sys
11from typing import TYPE_CHECKING, Tuple, List
12import numpy as np
13import networkx as nx
14import scipy.spatial.distance
15import sklearn.linear_model
16import pandas as pd
18from openhcs.core.pipeline.function_contracts import special_inputs, special_outputs
19from openhcs.core.memory.decorators import numpy as numpy_func
21import warnings
23 if TYPE_CHECKING:  23 ↛ 24 (condition never true)
24 pass
26logger = logging.getLogger(__name__)
29class DataWarning(Warning):
30 """Warnings about the content of user-provided image data."""
31 pass
34def warn_data(message):
35 """Issue a warning about image data."""
36 warnings.warn(message, DataWarning)
39class Intersection:
40 """Calculate intersection region between two tiles (extracted from Ashlar)."""
42 def __init__(self, corners1, corners2, min_size=0):
43 if np.isscalar(min_size):
44 min_size = np.repeat(min_size, 2)
45 self._calculate(corners1, corners2, min_size)
47 def _calculate(self, corners1, corners2, min_size):
48 """Calculate intersection parameters with robust boundary validation."""
49 # corners1 and corners2 are arrays of shape (2, 2) containing
50 # the upper-left and lower-right corners of the two tiles
51 max_shape = (corners2 - corners1).max(axis=0)
52 min_size = min_size.clip(1, max_shape)
53 position = corners1.max(axis=0)
54 initial_shape = np.floor(corners2.min(axis=0) - position).astype(int)
55 clipped_shape = np.maximum(initial_shape, min_size)
56 self.shape = np.ceil(clipped_shape).astype(int)
57 self.padding = self.shape - initial_shape
59 # Calculate offsets with boundary validation
60 raw_offsets = np.maximum(position - corners1 - self.padding, 0)
62 # Validate that offsets + shape don't exceed tile boundaries
63 tile_sizes = corners2 - corners1
64 for i in range(2):
65 # Ensure offset + shape <= tile_size for each tile
66 max_offset = tile_sizes[i] - self.shape
67 raw_offsets[i] = np.minimum(raw_offsets[i], np.maximum(max_offset, 0))
69 # Ensure shape doesn't exceed available space
70 available_space = tile_sizes[i] - raw_offsets[i]
71 self.shape = np.minimum(self.shape, available_space.astype(int))
73 # Final validation - ensure shape is positive
74 self.shape = np.maximum(self.shape, 1)
76 self.offsets = raw_offsets.astype(int)
78 # Calculate fractional offset difference for subpixel accuracy
79 offset_diff = self.offsets[1] - self.offsets[0]
80 self.offset_diff_frac = offset_diff - offset_diff.round()
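# Illustrative usage sketch (not part of the module; values are hypothetical):
# geometry for two 100x100 tiles whose upper-left corners sit 90 px apart along x,
# i.e. a ~10 px nominal overlap. corners1 holds the upper-left (y, x) corners of
# both tiles, corners2 the lower-right corners.
corners1 = np.array([[0.0, 0.0], [0.0, 90.0]])
corners2 = corners1 + 100.0
its = Intersection(corners1, corners2)
its.shape     # -> array([100, 10]): a full-height, 10 px wide overlap strip
its.offsets   # -> array([[ 0, 90], [ 0,  0]]): crop origin within each tile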
83def _get_window(shape):
84 """Build a 2D Hann window (from Ashlar utils.get_window)."""
85 # Build a 2D Hann window by taking the outer product of two 1-D windows.
86 wy = np.hanning(shape[0]).astype(np.float32)
87 wx = np.hanning(shape[1]).astype(np.float32)
88 window = np.outer(wy, wx)
89 return window
92def ashlar_register_no_preprocessing(img1, img2, upsample=10):
93 """
94 Robust Ashlar register function with comprehensive input validation.
96 This is based on ashlar.utils.register() but adds validation to handle
97 edge cases that can occur with real microscopy data.
98 """
99 import itertools
100 import scipy.ndimage
101 import skimage.registration
103 # Input validation
104 if img1 is None or img2 is None:  104 ↛ 105 (condition never true)
105 return np.array([0.0, 0.0]), np.inf
107 if img1.size == 0 or img2.size == 0:  107 ↛ 108 (condition never true)
108 return np.array([0.0, 0.0]), np.inf
110 if img1.shape != img2.shape:  110 ↛ 111 (condition never true)
111 return np.array([0.0, 0.0]), np.inf
113 if len(img1.shape) != 2:  113 ↛ 114 (condition never true)
114 return np.array([0.0, 0.0]), np.inf
116 if img1.shape[0] < 1 or img1.shape[1] < 1:  116 ↛ 117 (condition never true)
117 return np.array([0.0, 0.0]), np.inf
119 # Convert to float32 (equivalent to what whiten() does) - match GPU version
120 img1w = img1.astype(np.float32)
121 img2w = img2.astype(np.float32)
123 # Apply windowing function (from original Ashlar)
124 img1w = img1w * _get_window(img1w.shape)
125 img2w = img2w * _get_window(img2w.shape)
127 # Use skimage's phase cross correlation with error handling
128 try:
129 shift = skimage.registration.phase_cross_correlation(
130 img1w,
131 img2w,
132 upsample_factor=upsample,
133 normalization=None
134 )[0]
135 except Exception as e:
136 # If phase correlation fails, return large error
137 logger.error(f"Ashlar CPU: PHASE CORRELATION FAILED - Exception: {e}")
138 logger.error(" Returning infinite error")
139 return np.array([0.0, 0.0]), np.inf
141 # At this point we may have a shift in the wrong quadrant since the FFT
142 # assumes the signal is periodic. We test all four possibilities and return
143 # the shift that gives the highest direct correlation (sum of products).
144 shape = np.array(img1.shape)
145 shift_pos = (shift + shape) % shape
146 shift_neg = shift_pos - shape
147 shifts = list(itertools.product(*zip(shift_pos, shift_neg)))
148 correlations = []
149 for s in shifts:
150 try:
151 shifted_img = scipy.ndimage.shift(img2w, s, order=0)
152 corr = np.abs(np.sum(img1w * shifted_img))
153 correlations.append(corr)
154 except Exception:
155 correlations.append(0.0)
157 if not correlations or max(correlations) == 0:  157 ↛ 158 (condition never true)
158 logger.warning("Ashlar CPU: NO VALID CORRELATIONS - All correlations failed or zero")
159 return np.array([0.0, 0.0]), np.inf
161 idx = np.argmax(correlations)
162 shift = shifts[idx]
163 correlation = correlations[idx]
164 total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w)
165 if correlation > 0 and total_amplitude > 0:  165 ↛ 168 (condition always true)
166 error = -np.log(correlation / total_amplitude)
167 else:
168 error = np.inf
170 # Log all correlation results at INFO level for user visibility
171 if error > 1.0:  # High error threshold for Ashlar  171 ↛ 172 (condition never true)
172 logger.warning(f"Ashlar CPU: HIGH CORRELATION ERROR - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
173 logger.warning(" This indicates poor overlap or image quality between tiles")
174 else:
175 logger.info(f"Ashlar CPU: Correlation - Error={error:.4f}, Shift=({shift[0]:.2f}, {shift[1]:.2f})")
177 return shift, error
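# Illustrative usage sketch (not part of the module): register two synthetic tiles
# where the second is the first translated by an integer (3, -2) pixel offset. The
# recovered shift matches that displacement up to the sign convention of
# skimage.registration.phase_cross_correlation, and the error is small because the
# overlapping content is identical.
import scipy.ndimage
_rng = np.random.default_rng(0)
_base = _rng.random((128, 128)).astype(np.float32)
_moved = scipy.ndimage.shift(_base, (3, -2), order=1)
_shift, _error = ashlar_register_no_preprocessing(_base, _moved, upsample=10)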
180def ashlar_nccw_no_preprocessing(img1, img2):
181 """
182 Robust Ashlar nccw function with comprehensive input validation.
184 This is based on ashlar.utils.nccw() but adds validation to handle
185 edge cases that can occur with real microscopy data.
186 """
187 # Input validation
188 if img1 is None or img2 is None:  188 ↛ 189 (condition never true)
189 return np.inf
191 if img1.size == 0 or img2.size == 0:  191 ↛ 192 (condition never true)
192 return np.inf
194 if img1.shape != img2.shape:  194 ↛ 195 (condition never true)
195 return np.inf
197 if len(img1.shape) != 2:  197 ↛ 198 (condition never true)
198 return np.inf
200 if img1.shape[0] < 1 or img1.shape[1] < 1:  200 ↛ 201 (condition never true)
201 return np.inf
203 # Convert to float32 (equivalent to what whiten() does) - match GPU version
204 img1w = img1.astype(np.float32)
205 img2w = img2.astype(np.float32)
207 correlation = np.abs(np.sum(img1w * img2w))
208 total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w)
209 if correlation > 0 and total_amplitude > 0:  209 ↛ 223 (condition always true)
210 diff = correlation - total_amplitude
211 if diff <= 0:
212 error = -np.log(correlation / total_amplitude)
213 elif diff < 1e-3:  # Increased tolerance for robustness  213 ↛ 217 (condition never true)
214 # This situation can occur due to numerical precision issues when
215 # img1 and img2 are very nearly or exactly identical. If the
216 # difference is small enough, let it slide.
217 error = 0
218 else:
219 # Instead of raising error, return large but finite error
220 logger.warning(f"Ashlar CPU: NCCW numerical precision issue - diff={diff:.6f}, using error=100.0")
221 error = 100.0 # Large error but not infinite
222 else:
223 logger.warning(f"Ashlar CPU: NCCW invalid correlation - correlation={correlation:.6f}, total_amplitude={total_amplitude:.6f}")
224 error = np.inf
226 # Log all NCCW results at INFO level for user visibility
227 if error > 10.0: # High NCCW error threshold
228 logger.warning(f"Ashlar CPU: HIGH NCCW ERROR - Error={error:.4f}")
229 logger.warning(" This indicates poor image correlation between tiles")
230 else:
231 logger.info(f"Ashlar CPU: NCCW - Error={error:.4f}")
233 return error
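# Illustrative sanity sketch (not part of the module): by Cauchy-Schwarz,
# |sum(img1*img2)| <= ||img1||*||img2||, so -log(correlation / total_amplitude) is
# non-negative; identical tiles score ~0 and decorrelated tiles score higher.
_a = np.random.default_rng(1).random((64, 64)).astype(np.float32)
ashlar_nccw_no_preprocessing(_a, _a)          # ~0.0, near-perfect correlation
ashlar_nccw_no_preprocessing(_a, _a[::-1])    # > 0, weaker correlation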
236def ashlar_crop(img, offset, shape):
237 """
238 Robust Ashlar crop function with comprehensive boundary validation.
240 This is based on ashlar.utils.crop() but adds validation to handle
241 edge cases that can occur with real microscopy data.
242 """
243 # Input validation
244 if img is None or img.size == 0:  244 ↛ 245 (condition never true)
245 raise ValueError("Cannot crop from empty or None image")
247 # Convert to integers and validate
248 start = offset.round().astype(int)
249 shape = np.round(shape).astype(int)
251 # Ensure start is non-negative
252 start = np.maximum(start, 0)
254 # Ensure shape is positive
255 shape = np.maximum(shape, 1)
257 # Validate bounds
258 img_height, img_width = img.shape[:2]
259 end = start + shape
261 # Clamp to image boundaries
262 start[0] = min(start[0], img_height - 1)
263 start[1] = min(start[1], img_width - 1)
264 end[0] = min(end[0], img_height)
265 end[1] = min(end[1], img_width)
267 # Ensure we have a valid region
268 if end[0] <= start[0] or end[1] <= start[1]:  268 ↛ 270 (condition never true)
269 # Return minimum valid region if bounds are invalid
270 return img[start[0]:start[0]+1, start[1]:start[1]+1]
272 return img[start[0]:end[0], start[1]:end[1]]
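# Illustrative usage sketch (not part of the module): offsets and shapes are rounded
# to integers and clamped to the image bounds instead of raising, so a crop that
# would spill past the edge is silently shrunk.
_tile = np.zeros((512, 512), dtype=np.float32)
ashlar_crop(_tile, np.array([5.0, 10.0]), np.array([20, 30])).shape     # (20, 30)
ashlar_crop(_tile, np.array([500.0, 500.0]), np.array([64, 64])).shape  # (12, 12), clamped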
275class ArrayEdgeAligner:
276 """
277 Array-based EdgeAligner that implements the complete Ashlar algorithm
278 but works directly with numpy arrays instead of file readers.
279 """
281 def __init__(self, image_stack, positions, tile_size, pixel_size=1.0,
282 max_shift=15, alpha=0.01, max_error=None,
283 randomize=False, verbose=False):
284 """
285 Initialize array-based EdgeAligner for pure position calculation.
287 Args:
288 image_stack: 3D numpy array (num_tiles, height, width) - preprocessed grayscale
289 positions: 2D array of tile positions (num_tiles, 2) in pixels
290 tile_size: Array [height, width] of tile dimensions
291 pixel_size: Pixel size in micrometers (for max_shift conversion)
292 max_shift: Maximum allowed shift in micrometers
293 alpha: Alpha value for error threshold (lower = stricter)
294 max_error: Explicit error threshold (None = auto-compute)
295 randomize: Use random seed for permutation testing
296 verbose: Enable verbose logging
297 """
298 self.image_stack = image_stack
299 self.positions = positions.astype(float)
300 self.tile_size = np.array(tile_size)
301 self.pixel_size = pixel_size
302 self.max_shift = max_shift
303 self.max_shift_pixels = self.max_shift / self.pixel_size
304 self.alpha = alpha
305 self.max_error = max_error
306 self.randomize = randomize
307 self.verbose = verbose
308 self._cache = {}
309 self.errors_negative_sampled = np.empty(0)
311 # Build neighbors graph
312 self.neighbors_graph = self._build_neighbors_graph()
314 def _build_neighbors_graph(self):
315 """Build graph of neighboring (overlapping) tiles."""
316 pdist = scipy.spatial.distance.pdist(self.positions, metric='cityblock')
317 sp = scipy.spatial.distance.squareform(pdist)
318 max_distance = self.tile_size.max() + 1
319 edges = zip(*np.nonzero((sp > 0) & (sp < max_distance)))
320 graph = nx.from_edgelist(edges)
321 graph.add_nodes_from(range(len(self.positions)))
322 return graph
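# Illustrative usage sketch (not part of the module; values are hypothetical):
# for a 2x2 grid of 100x100 tiles at 10% overlap, nominal positions are 90 px
# apart, so adjacent tiles (city-block distance 90 < 101) become edges while
# diagonal pairs (distance 180) do not:
#     positions = np.array([[0, 0], [0, 90], [90, 0], [90, 90]], dtype=float)
#     stack = np.zeros((4, 100, 100), dtype=np.float32)
#     aligner = ArrayEdgeAligner(stack, positions, tile_size=(100, 100))
#     sorted(aligner.neighbors_graph.edges)   # -> [(0, 1), (0, 2), (1, 3), (2, 3)]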
325 def run(self):
326 """Run the complete Ashlar algorithm."""
327 self.check_overlaps()
328 self.compute_threshold()
329 self.register_all()
330 self.build_spanning_tree()
331 self.calculate_positions()
332 self.fit_model()
334 def check_overlaps(self):
335 """Check if tiles actually overlap based on positions."""
336 overlaps = np.array([
337 self.tile_size - abs(self.positions[t1] - self.positions[t2])
338 for t1, t2 in self.neighbors_graph.edges
339 ])
340 failures = np.any(overlaps < 1, axis=1) if len(overlaps) else []
341 if len(failures) and all(failures):  341 ↛ 342 (condition never true)
342 warn_data("No tiles overlap, attempting alignment anyway.")
343 elif any(failures):  343 ↛ 344 (condition never true)
344 warn_data("Some neighboring tiles have zero overlap.")
346 def compute_threshold(self):
347 """Compute error threshold using permutation testing."""
348 if self.max_error is not None:  348 ↛ 349 (condition never true)
349 if self.verbose:
350 print(" using explicit error threshold")
351 return
353 edges = self.neighbors_graph.edges
354 num_tiles = len(self.image_stack)
356 # If not enough tiles overlap to matter, skip this whole thing
357 if len(edges) <= 1:  357 ↛ 358 (condition never true)
358 self.max_error = np.inf
359 return
361 widths = np.array([
362 self.intersection(t1, t2).shape.min()
363 for t1, t2 in edges
364 ])
365 w = widths.max()
366 max_offset = self.tile_size[0] - w
368 # Number of possible pairs minus number of actual neighbor pairs
369 num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)
371 # Reduce permutation count for small datasets
372 n = 1000 if num_distant_pairs > 8 else (num_distant_pairs + 1) * 10
373 pairs = np.empty((n, 2), dtype=int)
374 offsets = np.empty((n, 2), dtype=int)
376 # Generate n random non-overlapping image strips
377 max_tries = 100
378 if self.randomize is False:  378 ↛ 381 (condition always true)
379 random_state = np.random.RandomState(0)
380 else:
381 random_state = np.random.RandomState()
383 for i in range(n):
384 # Limit tries to avoid infinite loop in pathological cases
385 for current_try in range(max_tries):  385 ↛ 409 (loop never ran to completion)
386 t1, t2 = random_state.randint(num_tiles, size=2)
387 o1, o2 = random_state.randint(max_offset, size=2)
389 # Check for non-overlapping strips and abort the retry loop
390 if t1 != t2 and (t1, t2) not in edges:
391 # Different, non-neighboring tiles -- always OK
392 break
393 elif t1 == t2 and abs(o1 - o2) > w:
394 # Same tile OK if strips don't overlap within the image
395 break
396 elif (t1, t2) in edges:
397 # Neighbors OK if either strip is entirely outside the
398 # expected overlap region (based on nominal positions)
399 its = self.intersection(t1, t2, np.repeat(w, 2))
400 ioff1, ioff2 = its.offsets[:, 0]
401 if (
402 its.shape[0] > its.shape[1]
403 or o1 < ioff1 - w or o1 > ioff1 + w
404 or o2 < ioff2 - w or o2 > ioff2 + w
405 ):
406 break
407 else:
408 # Retries exhausted. This should be very rare.
409 warn_data(f"Could not find non-overlapping strips in {max_tries} tries")
410 pairs[i] = t1, t2
411 offsets[i] = o1, o2
413 errors = np.empty(n)
414 for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
415 # if self.verbose and (i % 10 == 9 or i == n - 1):
416 # sys.stdout.write(f'\r quantifying alignment error {i + 1}/{n}')
417 # sys.stdout.flush()
418 img1 = self.image_stack[t1][offset1:offset1+w, :]
419 img2 = self.image_stack[t2][offset2:offset2+w, :]
420 _, errors[i] = ashlar_register_no_preprocessing(img1, img2, upsample=1)
421 # if self.verbose:
422 # print()
423 self.errors_negative_sampled = errors
424 self.max_error = np.percentile(errors, self.alpha * 100)
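# Illustrative note (not part of the module; numbers are hypothetical): the rejection
# threshold is the alpha-quantile of this "null" error sample, i.e. errors measured
# between strips that should NOT align. With a tiny sample and alpha = 0.01:
#     null_errors = np.array([0.8, 1.1, 1.4, 2.0, 2.3])
#     np.percentile(null_errors, 0.01 * 100)   # -> 0.812
# Edges whose registration error exceeds this value are set to np.inf in
# register_all() below.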
427 def register_all(self):
428 """Register all neighboring tile pairs."""
429 n = self.neighbors_graph.size()
430 for i, (t1, t2) in enumerate(self.neighbors_graph.edges, 1):
431 if self.verbose:  431 ↛ 432 (condition never true)
432 sys.stdout.write(f'\r aligning edge {i}/{n}')
433 sys.stdout.flush()
434 self.register_pair(t1, t2)
435 if self.verbose:  435 ↛ 436 (condition never true)
436 print()
437 self.all_errors = np.array([x[1] for x in self._cache.values()])
439 # Set error values above the threshold to infinity
440 for k, v in self._cache.items():
441 if v[1] > self.max_error or any(np.abs(v[0]) > self.max_shift_pixels):
442 self._cache[k] = (v[0], np.inf)
444 def register_pair(self, t1, t2):
445 """Return relative shift between images and the alignment error."""
446 key = tuple(sorted((t1, t2)))
447 try:
448 shift, error = self._cache[key]
449 except KeyError:
450 # Test a series of increasing overlap window sizes to help avoid
451 # missing alignments when the stage position error is large relative
452 # to the tile overlap. Simply using a large overlap in all cases
453 # limits the maximum achievable correlation thus increasing the
454 # error metric, leading to worse overall results. The window size
455 # starts at the nominal size and doubles until it's at least 10% of
456 # the tile size. If the nominal overlap is already 10% or greater,
457 # we only use that one size.
458 smin = self.intersection(key[0], key[1]).shape
459 smax = np.round(self.tile_size * 0.1)
460 sizes = [smin]
461 while any(sizes[-1] < smax):
462 sizes.append(sizes[-1] * 2)
463 # Test each window size with validation
464 results = []
465 for s in sizes:
466 try:
467 result = self._register(key[0], key[1], s)
468 results.append(result)
469 except Exception:
470 # If this window size fails, use infinite error
471 results.append((np.array([0.0, 0.0]), np.inf))
472 # Use the shift from the window size that gave the lowest error
473 shift, _ = min(results, key=lambda r: r[1])
474 # Extract the images from the nominal overlap window but with the
475 # shift applied to the second tile's position, and compute the error
476 # metric on these images. This should be even lower than the error
477 # computed above.
478 _, o1, o2 = self.overlap(key[0], key[1], shift=shift)
479 error = ashlar_nccw_no_preprocessing(o1, o2)
480 self._cache[key] = (shift, error)
481 if t1 > t2:
482 shift = -shift
483 # Return copy of shift to prevent corruption of cached values
484 return shift.copy(), error
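# Illustrative note (not part of the module; numbers are hypothetical): progressive
# window sizing for 2048x2048 tiles whose nominal intersection is only ~20 px wide.
# The tested sizes double until every dimension reaches 10% of the tile (205 px):
#     smin = np.array([2048, 20])
#     smax = np.round(np.array([2048, 2048]) * 0.1)    # -> array([205., 205.])
#     sizes = [smin]
#     while any(sizes[-1] < smax):
#         sizes.append(sizes[-1] * 2)
#     # widths tested: 20, 40, 80, 160, 320 -- the lowest-error shift wins.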
486 def _register(self, t1, t2, min_size=0):
487 """Register a single tile pair with given minimum size."""
488 its, img1, img2 = self.overlap(t1, t2, min_size)
489 # Account for padding, flipping the sign depending on the direction
490 # between the tiles
491 p1, p2 = self.positions[[t1, t2]]
492 sx = 1 if p1[1] >= p2[1] else -1
493 sy = 1 if p1[0] >= p2[0] else -1
494 padding = its.padding * [sy, sx]
495 shift, error = ashlar_register_no_preprocessing(img1, img2)
496 shift += padding
497 return shift, error
500 def intersection(self, t1, t2, min_size=0, shift=None):
501 """Calculate intersection region between two tiles."""
502 corners1 = self.positions[[t1, t2]]
503 if shift is not None:
504 corners1[1] += shift
505 corners2 = corners1 + self.tile_size
506 return Intersection(corners1, corners2, min_size)
508 def crop(self, tile_id, offset, shape):
509 """Crop image from tile at given offset and shape."""
510 img = self.image_stack[tile_id]
511 return ashlar_crop(img, offset, shape)
513 def overlap(self, t1, t2, min_size=0, shift=None):
514 """Extract overlapping regions between two tiles."""
515 its = self.intersection(t1, t2, min_size, shift)
516 img1 = self.crop(t1, its.offsets[0], its.shape)
517 img2 = self.crop(t2, its.offsets[1], its.shape)
518 return its, img1, img2
524 def build_spanning_tree(self):
525 """Build minimum spanning tree from registered edges."""
526 g = nx.Graph()
527 g.add_nodes_from(self.neighbors_graph)
528 g.add_weighted_edges_from(
529 (t1, t2, error)
530 for (t1, t2), (_, error) in self._cache.items()
531 if np.isfinite(error)
532 )
533 spanning_tree = nx.Graph()
534 spanning_tree.add_nodes_from(g)
535 for c in nx.connected_components(g):
536 cc = g.subgraph(c)
537 center = nx.center(cc)[0]
538 paths = nx.single_source_dijkstra_path(cc, center).values()
539 for path in paths:
540 nx.add_path(spanning_tree, path)
541 self.spanning_tree = spanning_tree
543 def calculate_positions(self):
544 """Calculate final positions from spanning tree."""
545 shifts = {}
546 for c in nx.connected_components(self.spanning_tree):
547 cc = self.spanning_tree.subgraph(c)
548 center = nx.center(cc)[0]
549 shifts[center] = np.array([0, 0])
550 for edge in nx.traversal.bfs_edges(cc, center):
551 source, dest = edge
552 if source not in shifts:  552 ↛ 553 (condition never true)
553 source, dest = dest, source
554 shift = self.register_pair(source, dest)[0]
555 shifts[dest] = shifts[source] + shift
556 if shifts:
557 self.shifts = np.array([s for _, s in sorted(shifts.items())])
558 self.final_positions = self.positions + self.shifts
559 else:
560 # TODO: fill in shifts and positions with 0x2 arrays
561 raise NotImplementedError("No images")
564 def fit_model(self):
565 """Fit linear model to handle disconnected components."""
566 components = sorted(
567 nx.connected_components(self.spanning_tree),
568 key=len, reverse=True
569 )
570 # Fit LR model on positions of largest connected component
571 cc0 = list(components[0])
572 self.lr = sklearn.linear_model.LinearRegression()
573 self.lr.fit(self.positions[cc0], self.final_positions[cc0])
575 # Fix up degenerate transform matrix. This happens when the spanning
576 # tree is completely edgeless or cc0's metadata positions fall in a
577 # straight line. In this case we fall back to the identity transform.
578 if np.linalg.det(self.lr.coef_) < 1e-3:  578 ↛ 579 (condition never true)
579 warn_data(
580 "Could not align enough edges, proceeding anyway with original"
581 " stage positions."
582 )
583 self.lr.coef_ = np.diag(np.ones(2))
584 self.lr.intercept_ = np.zeros(2)
586 # Adjust position of remaining components so their centroids match
587 # the predictions of the model
588 for cc in components[1:]:
589 nodes = list(cc)
590 centroid_m = np.mean(self.positions[nodes], axis=0)
591 centroid_f = np.mean(self.final_positions[nodes], axis=0)
592 shift = self.lr.predict([centroid_m])[0] - centroid_f
593 self.final_positions[nodes] += shift
595 # Adjust positions and model intercept to put origin at 0,0
596 self.origin = self.final_positions.min(axis=0)
597 self.final_positions -= self.origin
598 self.lr.intercept_ -= self.origin
601def _calculate_initial_positions(image_stack: np.ndarray, grid_dims: tuple, overlap_ratio: float) -> np.ndarray:
602 """Calculate initial grid positions based on overlap ratio."""
603 grid_rows, grid_cols = grid_dims
604 tile_height, tile_width = image_stack.shape[1:3]
605 spacing_factor = 1.0 - overlap_ratio
607 positions = []
608 for tile_idx in range(len(image_stack)):
609 r = tile_idx // grid_cols
610 c = tile_idx % grid_cols
612 y_pos = r * tile_height * spacing_factor
613 x_pos = c * tile_width * spacing_factor
614 positions.append([y_pos, x_pos])
616 return np.array(positions, dtype=float)
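# Illustrative worked example (not part of the module): a 2x3 grid of 100x200 tiles
# at 10% overlap uses a spacing factor of 0.9, i.e. 90 px steps in y and 180 px
# steps in x.
_demo_stack = np.zeros((6, 100, 200), dtype=np.float32)
_calculate_initial_positions(_demo_stack, (2, 3), 0.1)
# -> array([[  0.,   0.], [  0., 180.], [  0., 360.],
#           [ 90.,   0.], [ 90., 180.], [ 90., 360.]])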
619def _convert_ashlar_positions_to_openhcs(ashlar_positions: np.ndarray) -> List[Tuple[float, float]]:
620 """Convert Ashlar positions to OpenHCS format."""
621 positions = []
622 for tile_idx in range(len(ashlar_positions)):
623 y, x = ashlar_positions[tile_idx]
624 positions.append((float(x), float(y))) # OpenHCS uses (x, y) format
625 return positions
628@special_inputs("grid_dimensions")
629@special_outputs("positions")
630@numpy_func
631def ashlar_compute_tile_positions_cpu(
632 image_stack: np.ndarray,
633 grid_dimensions: Tuple[int, int],
634 overlap_ratio: float = 0.1,
635 max_shift: float = 15.0,
636 stitch_alpha: float = 0.01,
637 max_error: float = None,
638 randomize: bool = False,
639 verbose: bool = False,
640 upsample_factor: int = 10,
641 permutation_upsample: int = 1,
642 permutation_samples: int = 1000,
643 min_permutation_samples: int = 10,
644 max_permutation_tries: int = 100,
645 window_size_factor: float = 0.1,
646 **kwargs
647) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
648 """
649 Compute tile positions using the complete Ashlar algorithm - pure position calculation only.
651 This function implements the full Ashlar edge-based stitching algorithm but works directly
652 on preprocessed grayscale image arrays. It performs ONLY position calculation without any
653 file I/O, channel selection, or image preprocessing. All the mathematical sophistication
654 and robustness of the original Ashlar algorithm are preserved.
656 Args:
657 image_stack: 3D numpy array of shape (num_tiles, height, width) containing preprocessed
658 grayscale images. Each slice [i] should be a single-channel 2D image ready
659 for correlation analysis. No further preprocessing will be applied.
661 grid_dimensions: Tuple of (grid_rows, grid_cols) specifying the logical arrangement of
662 tiles. For example, (2, 3) means 2 rows and 3 columns of tiles, for a
663 total of 6 tiles. Must match the number of images in image_stack.
665 overlap_ratio: Expected fractional overlap between adjacent tiles (0.0-1.0). Default 0.1
666 means 10% overlap. This is used to calculate initial grid positions and
667 should match the actual overlap in your microscopy data. Typical values:
668 - 0.05-0.15 for well-controlled microscopes
669 - 0.15-0.25 for less precise stages
671 max_shift: Maximum allowed shift correction in micrometers. Default 15.0. This limits
672 how far tiles can be moved from their initial grid positions during alignment.
673 Should be set based on your microscope's stage accuracy:
674 - 5-15 μm for high-precision stages
675 - 15-50 μm for standard stages
676 - 50+ μm for low-precision or manual stages
678 stitch_alpha: Alpha value for statistical error threshold computation (0.0-1.0). Default
679 0.01 means 1% false positive rate. Lower values are stricter and reject more
680 alignments, higher values are more permissive. This controls the trade-off
681 between alignment quality and success rate:
682 - 0.001-0.01: Very strict, high quality alignments only
683 - 0.01-0.05: Balanced (recommended for most data)
684 - 0.05-0.1: Permissive, accepts lower quality alignments
686 max_error: Explicit error threshold for rejecting alignments (None = auto-compute).
687 When None (default), the threshold is computed automatically using permutation
688 testing. Set to a specific value to override automatic computation. Higher
689 values accept more alignments, lower values are stricter.
691 randomize: Whether to use random seed for permutation testing (bool). Default False uses
692 a fixed seed for reproducible results. Set True for different random sampling
693 in each run. Generally should be False for consistent results.
695 verbose: Enable detailed progress logging (bool). Default False. When True, prints
696 progress information including permutation testing, edge alignment, and
697 spanning tree construction. Useful for debugging and monitoring progress
698 on large datasets.
700 upsample_factor: Sub-pixel accuracy factor for phase cross correlation (int). Default 10.
701 Higher values provide better sub-pixel accuracy but increase computation time.
702 Range: 1-100+. Values of 10-50 are typical for high-accuracy stitching.
703 - 1: Pixel-level accuracy (fastest)
704 - 10: 0.1 pixel accuracy (balanced)
705 - 50: 0.02 pixel accuracy (high precision)
707 permutation_upsample: Upsample factor for permutation testing (int). Default 1.
708 Lower than upsample_factor for speed during threshold computation.
709 Usually kept at 1 since permutation testing doesn't need sub-pixel accuracy.
711 permutation_samples: Number of random samples for error threshold computation (int). Default 1000.
712 Higher values give more accurate thresholds but slower computation.
713 Automatically reduced for small datasets to avoid infinite loops.
715 min_permutation_samples: Minimum permutation samples for small datasets (int). Default 10.
716 When there are few non-overlapping pairs, this sets the minimum
717 number of samples to ensure statistical validity.
719 max_permutation_tries: Maximum attempts to find non-overlapping strips (int). Default 100.
720 Prevents infinite loops in pathological cases where valid strips
721 are hard to find. Rarely needs adjustment.
723 window_size_factor: Fraction of tile size for maximum window size (float). Default 0.1.
724 Controls the largest overlap window tested during progressive sizing.
725 Larger values allow detection of bigger stage errors but may reduce
726 correlation quality. Range: 0.05-0.2 typical.
728 **kwargs: Additional parameters (ignored). Allows compatibility with other stitching
729 algorithms that may have different parameter sets.
731 Returns:
732 Tuple of (image_stack, positions) where:
733 - image_stack: The original input image array (unchanged)
734 - positions: List of (x, y) position tuples in OpenHCS format, one per tile.
735 Positions are in pixel coordinates with (0, 0) at the top-left.
736 The positions represent the optimal tile placement after Ashlar
737 alignment, accounting for stage errors and image correlation.
739 Raises:
740 Exception: If the Ashlar algorithm fails (e.g., insufficient overlap, correlation
741 errors), the function automatically falls back to grid-based positioning
742 using the specified overlap_ratio.
744 Notes:
745 - This implementation contains the complete Ashlar algorithm including permutation
746 testing, progressive window sizing, minimum spanning tree construction, and
747 linear model fitting for disconnected components.
748 - The correlation functions are identical to original Ashlar but without image
749 preprocessing (whitening/filtering), allowing OpenHCS to handle preprocessing
750 in separate pipeline steps.
751 - For best results, ensure your image_stack contains properly preprocessed,
752 single-channel grayscale images with good contrast and minimal noise.
753 """
754 grid_rows, grid_cols = grid_dimensions
756 logger.info(f"Ashlar CPU: Processing {grid_rows}x{grid_cols} grid with {len(image_stack)} tiles")
758 try:
759 # Calculate initial grid positions
760 initial_positions = _calculate_initial_positions(image_stack, grid_dimensions, overlap_ratio)
761 tile_size = np.array(image_stack.shape[1:3]) # (height, width)
763 # Create and run ArrayEdgeAligner with complete Ashlar algorithm
764 logger.info("Running complete Ashlar edge-based stitching algorithm")
765 aligner = ArrayEdgeAligner(
766 image_stack=image_stack,
767 positions=initial_positions,
768 tile_size=tile_size,
769 pixel_size=1.0, # Assume 1 micrometer per pixel if not specified
770 max_shift=max_shift,
771 alpha=stitch_alpha,
772 max_error=max_error,
773 randomize=randomize,
774 verbose=verbose
775 )
777 # Run the complete algorithm
778 aligner.run()
780 # Convert to OpenHCS format
781 positions = _convert_ashlar_positions_to_openhcs(aligner.final_positions)
783 logger.info("Ashlar algorithm completed successfully")
785 except Exception as e:
786 logger.error(f"Ashlar algorithm failed: {e}")
787 # Fallback to grid positions if Ashlar fails
788 logger.warning("Falling back to grid-based positioning")
789 positions = []
790 tile_height, tile_width = image_stack.shape[1:3]
791 spacing_factor = 1.0 - overlap_ratio
793 for tile_idx in range(len(image_stack)):
794 r = tile_idx // grid_cols
795 c = tile_idx % grid_cols
796 x_pos = c * tile_width * spacing_factor
797 y_pos = r * tile_height * spacing_factor
798 positions.append((float(x_pos), float(y_pos)))
800 logger.info(f"Ashlar CPU: Completed processing {len(positions)} tile positions")
802 return image_stack, positions
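# Illustrative end-to-end sketch (not part of the module; assumes the OpenHCS
# decorators pass a direct call through unchanged): carve a 2x2 grid of overlapping
# tiles out of one synthetic image and let the algorithm recover their placement.
_rng = np.random.default_rng(42)
_scene = _rng.random((190, 190)).astype(np.float32)
_tiles = np.stack([
    _scene[0:100, 0:100], _scene[0:100, 90:190],
    _scene[90:190, 0:100], _scene[90:190, 90:190],
])
_stack_out, _positions = ashlar_compute_tile_positions_cpu(
    _tiles, grid_dimensions=(2, 2), overlap_ratio=0.1
)
# _positions -> four (x, y) tuples close to [(0, 0), (90, 0), (0, 90), (90, 90)]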
805def materialize_ashlar_cpu_positions(data: List[Tuple[float, float]], path: str, filemanager) -> str:
806 """Materialize Ashlar CPU tile positions as scientific CSV with grid metadata."""
807 csv_path = path.replace('.pkl', '_ashlar_positions.csv')
809 df = pd.DataFrame(data, columns=['x_position_um', 'y_position_um'])
810 df['tile_id'] = range(len(df))
812 # Estimate grid dimensions from position layout
813 unique_x = sorted(df['x_position_um'].unique())
814 unique_y = sorted(df['y_position_um'].unique())
816 grid_cols = len(unique_x)
817 grid_rows = len(unique_y)
819 # Add grid coordinates
820 df['grid_row'] = df.index // grid_cols
821 df['grid_col'] = df.index % grid_cols
823 # Add spacing information
824 if len(unique_x) > 1:
825 x_spacing = unique_x[1] - unique_x[0]
826 df['x_spacing_um'] = x_spacing
827 else:
828 df['x_spacing_um'] = 0
830 if len(unique_y) > 1:
831 y_spacing = unique_y[1] - unique_y[0]
832 df['y_spacing_um'] = y_spacing
833 else:
834 df['y_spacing_um'] = 0
836 # Add metadata
837 df['algorithm'] = 'ashlar_cpu'
838 df['grid_dimensions'] = f"{grid_rows}x{grid_cols}"
840 csv_content = df.to_csv(index=False)
841 filemanager.save(csv_content, csv_path, "disk")
842 return csv_path
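# Illustrative usage sketch (not part of the module): _StubFileManager is a
# hypothetical stand-in for the OpenHCS filemanager, which exposes
# save(content, path, backend) as called above.
class _StubFileManager:
    def save(self, content, path, backend):
        with open(path, "w") as fh:
            fh.write(content)

materialize_ashlar_cpu_positions(
    [(0.0, 0.0), (180.0, 0.0), (0.0, 90.0), (180.0, 90.0)],
    "step_outputs.pkl",
    _StubFileManager(),
)
# -> "step_outputs_ashlar_positions.csv" with tile_id, grid_row/grid_col,
#    spacing columns and algorithm metadata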