Coverage for openhcs/processing/backends/assemblers/assemble_stack_cpu.py: 61.9%
183 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2CPU implementation of image assembly functions with fixed blending.
3"""
4from __future__ import annotations
6import logging
7from typing import TYPE_CHECKING, List, Tuple, Union
9from openhcs.core.memory.decorators import numpy as numpy_func
10from openhcs.core.pipeline.function_contracts import special_inputs
12# For type checking only
13if TYPE_CHECKING: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true
14 import numpy as np
15 from scipy.ndimage import shift as subpixel_shift
17# Import NumPy
18import numpy as np # type: ignore
19from scipy.ndimage import shift as subpixel_shift # type: ignore
# Module-level logger named after this module; used to warn about empty or
# degenerate inputs in assemble_stack_cpu.
logger = logging.getLogger(__name__)
24def _get_all_overlapping_pairs(positions: "np.ndarray", tile_shape: tuple) -> list:
25 """
26 Find ALL overlapping tile pairs with edge directions.
27 [Keep this exactly as it was - it works fine]
28 """
29 height, width = tile_shape
30 N = positions.shape[0]
32 if N <= 1: 32 ↛ 33line 32 didn't jump to line 33 because the condition on line 32 was never true
33 return []
35 # Vectorized computation of ALL pairwise overlaps
36 pos_i = positions[:, np.newaxis, :]
37 pos_j = positions[np.newaxis, :, :]
39 xi, yi = pos_i[:, :, 0], pos_i[:, :, 1]
40 xj, yj = pos_j[:, :, 0], pos_j[:, :, 1]
42 left_i, right_i = xi, xi + width
43 top_i, bottom_i = yi, yi + height
44 left_j, right_j = xj, xj + width
45 top_j, bottom_j = yj, yj + height
47 x_overlap = np.maximum(0, np.minimum(right_i, right_j) - np.maximum(left_i, left_j))
48 y_overlap = np.maximum(0, np.minimum(bottom_i, bottom_j) - np.maximum(top_i, top_j))
50 valid_overlap = (x_overlap > 0) & (y_overlap > 0) & (np.arange(N)[:, None] != np.arange(N)[None, :])
52 edge_pairs = []
53 overlapping_pairs = np.where(valid_overlap)
55 for idx in range(len(overlapping_pairs[0])):
56 i, j = overlapping_pairs[0][idx], overlapping_pairs[1][idx]
58 x_overlap_val = float(x_overlap[i, j])
59 y_overlap_val = float(y_overlap[i, j])
61 xi_val, yi_val = positions[i, 0], positions[i, 1]
62 xj_val, yj_val = positions[j, 0], positions[j, 1]
64 if x_overlap_val > 0: 64 ↛ 70line 64 didn't jump to line 70 because the condition on line 64 was always true
65 if xj_val < xi_val:
66 edge_pairs.append((i, j, 'left', x_overlap_val))
67 elif xj_val > xi_val:
68 edge_pairs.append((i, j, 'right', x_overlap_val))
70 if y_overlap_val > 0: 70 ↛ 55line 70 didn't jump to line 55 because the condition on line 70 was always true
71 if yj_val < yi_val:
72 edge_pairs.append((i, j, 'top', y_overlap_val))
73 elif yj_val > yi_val:
74 edge_pairs.append((i, j, 'bottom', y_overlap_val))
76 return edge_pairs
79def _create_fixed_blend_mask(
80 tile_shape: tuple,
81 edge_overlaps: dict,
82 margin_ratio: float = 0.1
83) -> "np.ndarray":
84 """
85 Create blend mask with FIXED margin ratio using WORKING logic from old version.
86 CRITICAL: Uses endpoint=False like the old working version.
87 """
88 height, width = tile_shape
90 # Create 1D weights
91 y_weight = np.ones(height, dtype=np.float32)
92 x_weight = np.ones(width, dtype=np.float32)
94 # Fixed margins (same as old working version)
95 margin_pixels_y = int(height * margin_ratio)
96 margin_pixels_x = int(width * margin_ratio)
98 # Apply gradients ONLY where there are overlaps (same as old working version)
99 # CRITICAL: endpoint=False (this is what made the old version work!)
100 if 'top' in edge_overlaps and margin_pixels_y > 0:
101 y_weight[:margin_pixels_y] = np.linspace(0, 1, margin_pixels_y, endpoint=False)
103 if 'bottom' in edge_overlaps and margin_pixels_y > 0:
104 y_weight[-margin_pixels_y:] = np.linspace(1, 0, margin_pixels_y, endpoint=False)
106 if 'left' in edge_overlaps and margin_pixels_x > 0:
107 x_weight[:margin_pixels_x] = np.linspace(0, 1, margin_pixels_x, endpoint=False)
109 if 'right' in edge_overlaps and margin_pixels_x > 0:
110 x_weight[-margin_pixels_x:] = np.linspace(1, 0, margin_pixels_x, endpoint=False)
112 # Use outer product (same as old working version)
113 mask = np.outer(y_weight, x_weight)
114 return mask.astype(np.float32)
117def _create_dynamic_blend_mask(
118 tile_shape: tuple,
119 edge_overlaps: dict,
120 overlap_fraction: float = 1.0
121) -> "np.ndarray":
122 """
123 Create blend mask based on actual overlap amounts using WORKING logic from old version.
124 CRITICAL: Uses endpoint=False and same logic as old working version.
125 """
126 height, width = tile_shape
128 # Create 1D weights
129 y_weight = np.ones(height, dtype=np.float32)
130 x_weight = np.ones(width, dtype=np.float32)
132 # Process each edge based on actual overlap (same as old working version)
133 # CRITICAL: endpoint=False (this is what made the old version work!)
134 if 'top' in edge_overlaps:
135 overlap_pixels = int(edge_overlaps['top'] * overlap_fraction)
136 if overlap_pixels > 0:
137 y_weight[:overlap_pixels] = np.linspace(0, 1, overlap_pixels, endpoint=False)
139 if 'bottom' in edge_overlaps:
140 overlap_pixels = int(edge_overlaps['bottom'] * overlap_fraction)
141 if overlap_pixels > 0:
142 y_weight[-overlap_pixels:] = np.linspace(1, 0, overlap_pixels, endpoint=False)
144 if 'left' in edge_overlaps:
145 overlap_pixels = int(edge_overlaps['left'] * overlap_fraction)
146 if overlap_pixels > 0:
147 x_weight[:overlap_pixels] = np.linspace(0, 1, overlap_pixels, endpoint=False)
149 if 'right' in edge_overlaps:
150 overlap_pixels = int(edge_overlaps['right'] * overlap_fraction)
151 if overlap_pixels > 0:
152 x_weight[-overlap_pixels:] = np.linspace(1, 0, overlap_pixels, endpoint=False)
154 # Use outer product (same as old working version)
155 mask = np.outer(y_weight, x_weight)
156 return mask.astype(np.float32)
@special_inputs("positions")
@numpy_func
def assemble_stack_cpu(
    image_tiles: "np.ndarray",
    positions: Union[List[Tuple[float, float]], "np.ndarray"],
    blend_method: str = "fixed",
    fixed_margin_ratio: float = 0.1,
    overlap_blend_fraction: float = 1.0
) -> "np.ndarray":
    """
    Assemble a stack of tiles into a single blended mosaic on the CPU.

    Each tile is placed at its (possibly sub-pixel) position and weighted by a
    per-tile blend mask. The weighted tiles are accumulated onto a shared
    canvas, and the result is normalised by the accumulated weights.

    Args:
        image_tiles: 3D array of tiles with shape (N, H, W).
        positions: (x, y) origin of each tile. Accepted forms are a list or
            tuple of 2-element tuples/lists, or an array of shape (N, 2).
            Non-NumPy arrays are converted with ``to_numpy``.
        blend_method: ``"none"`` (uniform weights), ``"fixed"`` (ramps of
            ``fixed_margin_ratio`` of the tile size on overlapping edges) or
            ``"dynamic"`` (ramps sized from the measured overlaps).
        fixed_margin_ratio: Ramp length for ``"fixed"`` blending as a fraction
            of the tile size (e.g. 0.1 = 10%).
        overlap_blend_fraction: For ``"dynamic"`` blending, the fraction of each
            measured overlap that is ramped.

    Returns:
        uint16 array of shape (1, canvas_height, canvas_width), with values
        clipped to [0, 65535]. With zero tiles, ``np.array([[[]]], dtype=np.uint16)``
        is returned. If the computed canvas is empty, a 1-D empty uint16 array
        is returned.

    Raises:
        TypeError: If ``image_tiles`` is not a 3D ndarray or ``positions`` has
            the wrong form/shape.
        ValueError: If the tile and position counts differ, or
            ``blend_method`` is unknown.
    """
    # --- 1. Validate inputs ---
    if not isinstance(image_tiles, np.ndarray) or image_tiles.ndim != 3:
        raise TypeError("image_tiles must be a 3D NumPy ndarray of shape (N, H, W).")

    if image_tiles.shape[0] == 0:
        logger.warning("image_tiles array is empty (0 tiles).")
        return np.array([[[]]], dtype=np.uint16)

    # Convert positions to numpy; (x, y) tuples and [x, y] lists are both accepted.
    if isinstance(positions, (list, tuple)):
        if not positions or not isinstance(positions[0], (tuple, list)) or len(positions[0]) != 2:
            raise TypeError("positions must be a list of (x, y) tuples.")
        positions = np.array(positions, dtype=np.float32)
    elif not isinstance(positions, np.ndarray):
        positions = to_numpy(positions)
    if positions.ndim != 2 or positions.shape[1] != 2:
        raise TypeError("positions must be an array of shape [N, 2].")

    if image_tiles.shape[0] != positions.shape[0]:
        raise ValueError(f"Mismatch: {image_tiles.shape[0]} tiles vs {positions.shape[0]} positions.")

    num_tiles, tile_h, tile_w = image_tiles.shape
    tile_shape = (tile_h, tile_w)

    image_tiles_float = image_tiles.astype(np.float32)

    # --- 2. Compute canvas bounds ---
    # The canvas origin is the floored minimum position, so every tile lands at
    # a non-negative canvas offset.
    min_x = int(np.floor(np.min(positions[:, 0])))
    min_y = int(np.floor(np.min(positions[:, 1])))
    max_x = int(np.ceil(np.max(positions[:, 0]) + tile_w))
    max_y = int(np.ceil(np.max(positions[:, 1]) + tile_h))

    canvas_width = max_x - min_x
    canvas_height = max_y - min_y

    if canvas_width <= 0 or canvas_height <= 0:
        logger.warning(f"Invalid canvas size: {canvas_height}x{canvas_width}")
        return np.array([], dtype=np.uint16)

    composite_accum = np.zeros((canvas_height, canvas_width), dtype=np.float32)
    weight_accum = np.zeros((canvas_height, canvas_width), dtype=np.float32)

    # --- 3. Create blend masks ---
    if blend_method == "none":
        blend_masks = [np.ones(tile_shape, dtype=np.float32) for _ in range(num_tiles)]
    else:
        # Reject unknown methods before paying for the O(N^2) overlap search.
        if blend_method not in ("fixed", "dynamic"):
            raise ValueError(f"Unknown blend_method: {blend_method}")

        # Per tile: the largest overlap seen on each edge.
        tile_overlaps = [{} for _ in range(num_tiles)]
        for tile_i, _tile_j, edge_direction, pixel_overlap in _get_all_overlapping_pairs(positions, tile_shape):
            previous = tile_overlaps[tile_i].get(edge_direction, pixel_overlap)
            tile_overlaps[tile_i][edge_direction] = max(previous, pixel_overlap)

        if blend_method == "fixed":
            blend_masks = [
                _create_fixed_blend_mask(tile_shape, overlaps, margin_ratio=fixed_margin_ratio)
                for overlaps in tile_overlaps
            ]
        else:
            blend_masks = [
                _create_dynamic_blend_mask(tile_shape, overlaps, overlap_fraction=overlap_blend_fraction)
                for overlaps in tile_overlaps
            ]

    # --- 4. Place tiles ---
    for i in range(num_tiles):
        pos_x, pos_y = positions[i]

        # Canvas position, split into integer and fractional parts.
        target_x = pos_x - min_x
        target_y = pos_y - min_y
        x_int = int(np.floor(target_x))
        y_int = int(np.floor(target_y))
        x_frac = target_x - x_int
        y_frac = target_y - y_int

        # scipy.ndimage.shift moves content towards higher indices for a
        # positive shift (output[k] = input[k - shift]). Shifting by +frac
        # therefore puts tile pixel 0 at canvas coordinate x_int + x_frac ==
        # target_x, consistent with the integer placement below. A negative
        # shift would land it at x_int - x_frac, up to 2 px from the target.
        tile = image_tiles_float[i]
        if x_frac or y_frac:
            tile = subpixel_shift(
                tile,
                shift=(y_frac, x_frac),
                order=1,
                mode='constant',
                cval=0.0
            )

        mask = blend_masks[i]
        blended_tile = tile * mask

        # Canvas window for this tile, clipped to the canvas. The tile window
        # is trimmed by the same amounts.
        y_start, x_start = y_int, x_int
        y_end, x_end = y_int + tile_h, x_int + tile_w

        tile_y_start = max(0, -y_start)
        tile_x_start = max(0, -x_start)
        tile_y_end = tile_h - max(0, y_end - canvas_height)
        tile_x_end = tile_w - max(0, x_end - canvas_width)

        y_start, x_start = max(y_start, 0), max(x_start, 0)
        y_end, x_end = min(y_end, canvas_height), min(x_end, canvas_width)

        # Skip tiles that fall entirely outside the canvas.
        if (tile_y_start >= tile_y_end or tile_x_start >= tile_x_end or
                y_start >= y_end or x_start >= x_end):
            continue

        composite_accum[y_start:y_end, x_start:x_end] += \
            blended_tile[tile_y_start:tile_y_end, tile_x_start:tile_x_end]
        weight_accum[y_start:y_end, x_start:x_end] += \
            mask[tile_y_start:tile_y_end, tile_x_start:tile_x_end]

    # --- 5. Normalize ---
    # epsilon avoids division by zero where no tile contributed weight.
    epsilon = 1e-7
    stitched = composite_accum / (weight_accum + epsilon)

    stitched_uint16 = np.clip(stitched, 0, 65535).astype(np.uint16)
    return stitched_uint16.reshape(1, canvas_height, canvas_width)
def to_numpy(tensor):
    """
    Convert an array-like from a supported backend to a NumPy array.

    Inputs are checked in this order:

    * NumPy arrays (including subclasses) and NumPy scalars are returned unchanged.
    * Objects with a ``.get()`` method that are not mappings (CuPy-style
      device arrays) return ``tensor.get()``.
    * Objects with ``.detach()`` (PyTorch tensors) return
      ``tensor.detach().cpu().numpy()``.
    * Objects with ``.numpy()`` and ``.device`` (TensorFlow eager tensors)
      return ``tensor.numpy()``.
    * Any other object implementing the NumPy ``__array__`` protocol returns
      ``np.asarray(tensor)``.

    Raises:
        ValueError: For any other input, including dicts, lists and strings.
    """
    from collections.abc import Mapping

    if isinstance(tensor, (np.ndarray, np.generic)):
        return tensor
    # Mappings also expose .get(); calling it without a key would raise a
    # confusing TypeError instead of the documented ValueError.
    if hasattr(tensor, 'get') and not isinstance(tensor, Mapping):  # CuPy
        return tensor.get()
    if hasattr(tensor, 'detach'):  # PyTorch
        return tensor.detach().cpu().numpy()
    if hasattr(tensor, 'numpy') and hasattr(tensor, 'device'):  # TF
        return tensor.numpy()
    if hasattr(tensor, '__array__'):
        return np.asarray(tensor)
    raise ValueError(f"Unsupported tensor type: {type(tensor)}")