Coverage for openhcs/processing/backends/assemblers/assemble_stack_cupy.py: 6.7%
262 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2CuPy implementation of image assembly functions.
4This module provides GPU-accelerated functions for assembling microscopy images
5using CuPy. It handles subpixel positioning and blending of image tiles.
6"""
7from __future__ import annotations
9import logging
10from typing import TYPE_CHECKING, List, Tuple, Union, List, Tuple, Union
12from openhcs.core.memory.decorators import cupy as cupy_func
13from openhcs.core.pipeline.function_contracts import special_inputs
14from openhcs.core.utils import optional_import
16# For type checking only
17if TYPE_CHECKING: 17 ↛ 18line 17 didn't jump to line 18 because the condition on line 17 was never true
18 import cupy as cp
19 from cupyx.scipy.ndimage import gaussian_filter
20 from cupyx.scipy.ndimage import shift as subpixel_shift
# CuPy is an optional dependency; the checks below treat a None result from
# optional_import as "module not available".
cp = optional_import("cupy")

# Start with both ndimage helpers unbound and fill them in only when CuPy and
# cupyx.scipy.ndimage could both be imported.
gaussian_filter = None
subpixel_shift = None
if cp is not None:
    cupyx_scipy = optional_import("cupyx.scipy.ndimage")
    if cupyx_scipy is not None:
        gaussian_filter = cupyx_scipy.gaussian_filter
        subpixel_shift = cupyx_scipy.shift
38logger = logging.getLogger(__name__)
def _get_all_overlapping_pairs_gpu(positions: "cp.ndarray", tile_shape: tuple) -> list:  # type: ignore
    """
    Detect every overlapping tile pair and the edge(s) of tile_i it touches.

    Overlap extents for all N x N pairs are computed with broadcasting on the
    GPU; only the entries describing actual overlaps are copied to the host to
    build the result list.

    Args:
        positions: CuPy array of shape (N, 2) with (x, y) top-left positions
        tile_shape: (height, width) shared by all tiles

    Returns:
        List of (tile_i, tile_j, edge_direction, pixel_overlap) tuples.
        edge_direction is 'left', 'right', 'top' or 'bottom' relative to
        tile_i. pixel_overlap is the x extent for 'left'/'right' entries and
        the y extent for 'top'/'bottom' entries. A pair that is offset along
        both axes contributes one horizontal and one vertical entry, and each
        pair is reported from both tiles' points of view ((i, j) and (j, i)).
    """
    height, width = tile_shape
    num_tiles = positions.shape[0]

    if num_tiles <= 1:
        return []

    xs = positions[:, 0]
    ys = positions[:, 1]

    # Column vectors (N, 1) against row vectors (1, N) give all N x N pairs.
    x_i, x_j = xs[:, cp.newaxis], xs[cp.newaxis, :]
    y_i, y_j = ys[:, cp.newaxis], ys[cp.newaxis, :]

    # Overlap extent per axis: min(far edges) - max(near edges), floored at 0.
    x_overlap = cp.maximum(0, cp.minimum(x_i + width, x_j + width) - cp.maximum(x_i, x_j))
    y_overlap = cp.maximum(0, cp.minimum(y_i + height, y_j + height) - cp.maximum(y_i, y_j))

    # Two tiles overlap only if they overlap on both axes; drop the diagonal
    # so a tile is never paired with itself.
    not_self = ~cp.eye(num_tiles, dtype=bool)
    overlapping = (x_overlap > 0) & (y_overlap > 0) & not_self

    logger.debug(
        "GPU direct adjacency: checking all %dx%d pairs for overlaps",
        num_tiles, num_tiles,
    )

    idx_i, idx_j = cp.nonzero(overlapping)
    if idx_i.size == 0:
        return []

    def _to_host(device_array):
        # Device -> host copy, converted to plain Python scalars.
        return cp.asnumpy(device_array).tolist()

    # Only the overlapping pairs (not the full N x N tables) leave the GPU.
    rows = zip(
        _to_host(idx_i),
        _to_host(idx_j),
        _to_host(x_overlap[idx_i, idx_j]),
        _to_host(y_overlap[idx_i, idx_j]),
        _to_host(xs[idx_i]),
        _to_host(xs[idx_j]),
        _to_host(ys[idx_i]),
        _to_host(ys[idx_j]),
    )

    edge_pairs = []
    num_pairs = 0
    for i, j, x_extent, y_extent, xi, xj, yi, yj in rows:
        num_pairs += 1

        # Horizontal relation: which side of tile i does tile j sit on?
        if xj < xi:
            edge_pairs.append((i, j, 'left', float(x_extent)))
        elif xj > xi:
            edge_pairs.append((i, j, 'right', float(x_extent)))

        # Vertical relation, reported independently of the horizontal one.
        if yj < yi:
            edge_pairs.append((i, j, 'top', float(y_extent)))
        elif yj > yi:
            edge_pairs.append((i, j, 'bottom', float(y_extent)))

    logger.debug(
        "GPU: found %d total edge overlaps from %d overlapping pairs",
        len(edge_pairs), num_pairs,
    )
    return edge_pairs
151def _create_batch_fixed_masks_gpu(
152 tile_shape: tuple,
153 all_edge_overlaps: list,
154 margin_ratio: float = 0.1
155) -> "cp.ndarray":
156 """
157 VECTORIZED: Create all fixed blend masks at once for 2-3x speedup.
158 Uses batch operations instead of individual mask creation.
159 """
160 height, width = tile_shape
161 num_tiles = len(all_edge_overlaps)
163 # Pre-calculate margin pixels
164 margin_pixels_y = int(height * margin_ratio)
165 margin_pixels_x = int(width * margin_ratio)
167 # Create batch of 1D weights - shape (N, height) and (N, width)
168 y_weights = cp.ones((num_tiles, height), dtype=cp.float32)
169 x_weights = cp.ones((num_tiles, width), dtype=cp.float32)
171 # Pre-generate gradient arrays (reuse for all tiles)
172 if margin_pixels_y > 0:
173 top_gradient = cp.linspace(0, 1, margin_pixels_y, endpoint=False, dtype=cp.float32)
174 bottom_gradient = cp.linspace(1, 0, margin_pixels_y, endpoint=False, dtype=cp.float32)
176 if margin_pixels_x > 0:
177 left_gradient = cp.linspace(0, 1, margin_pixels_x, endpoint=False, dtype=cp.float32)
178 right_gradient = cp.linspace(1, 0, margin_pixels_x, endpoint=False, dtype=cp.float32)
180 # Apply gradients to each tile (vectorized where possible)
181 for i, edge_overlaps in enumerate(all_edge_overlaps):
182 if 'top' in edge_overlaps and margin_pixels_y > 0:
183 y_weights[i, :margin_pixels_y] = top_gradient
185 if 'bottom' in edge_overlaps and margin_pixels_y > 0:
186 y_weights[i, -margin_pixels_y:] = bottom_gradient
188 if 'left' in edge_overlaps and margin_pixels_x > 0:
189 x_weights[i, :margin_pixels_x] = left_gradient
191 if 'right' in edge_overlaps and margin_pixels_x > 0:
192 x_weights[i, -margin_pixels_x:] = right_gradient
194 # Batch outer product using broadcasting: (N, H, 1) * (N, 1, W) = (N, H, W)
195 masks = y_weights[:, :, cp.newaxis] * x_weights[:, cp.newaxis, :]
197 return masks.astype(cp.float32)
200def _create_batch_dynamic_masks_gpu(
201 tile_shape: tuple,
202 all_edge_overlaps: list,
203 overlap_fraction: float = 1.0
204) -> "cp.ndarray":
205 """
206 VECTORIZED: Create all dynamic blend masks at once for 2-3x speedup.
207 """
208 height, width = tile_shape
209 num_tiles = len(all_edge_overlaps)
211 # Create batch of 1D weights
212 y_weights = cp.ones((num_tiles, height), dtype=cp.float32)
213 x_weights = cp.ones((num_tiles, width), dtype=cp.float32)
215 # Apply gradients to each tile
216 for i, edge_overlaps in enumerate(all_edge_overlaps):
217 if 'top' in edge_overlaps:
218 overlap_pixels = int(edge_overlaps['top'] * overlap_fraction)
219 if overlap_pixels > 0:
220 y_weights[i, :overlap_pixels] = cp.linspace(0, 1, overlap_pixels, endpoint=False)
222 if 'bottom' in edge_overlaps:
223 overlap_pixels = int(edge_overlaps['bottom'] * overlap_fraction)
224 if overlap_pixels > 0:
225 y_weights[i, -overlap_pixels:] = cp.linspace(1, 0, overlap_pixels, endpoint=False)
227 if 'left' in edge_overlaps:
228 overlap_pixels = int(edge_overlaps['left'] * overlap_fraction)
229 if overlap_pixels > 0:
230 x_weights[i, :overlap_pixels] = cp.linspace(0, 1, overlap_pixels, endpoint=False)
232 if 'right' in edge_overlaps:
233 overlap_pixels = int(edge_overlaps['right'] * overlap_fraction)
234 if overlap_pixels > 0:
235 x_weights[i, -overlap_pixels:] = cp.linspace(1, 0, overlap_pixels, endpoint=False)
237 # Batch outer product using broadcasting
238 masks = y_weights[:, :, cp.newaxis] * x_weights[:, cp.newaxis, :]
240 return masks.astype(cp.float32)
243def _create_dynamic_blend_mask_gpu(
244 tile_shape: tuple,
245 edge_overlaps: dict,
246 overlap_fraction: float = 1.0
247) -> "cp.ndarray":
248 """
249 GPU version of dynamic blend mask using WORKING logic from CPU version.
250 CRITICAL: Uses endpoint=False and same logic as working CPU version.
251 """
252 height, width = tile_shape
254 # Create 1D weights
255 y_weight = cp.ones(height, dtype=cp.float32)
256 x_weight = cp.ones(width, dtype=cp.float32)
258 # Process each edge based on actual overlap (same as working CPU version)
259 # CRITICAL: endpoint=False (this is what made the CPU version work!)
260 if 'top' in edge_overlaps:
261 overlap_pixels = int(edge_overlaps['top'] * overlap_fraction)
262 if overlap_pixels > 0:
263 y_weight[:overlap_pixels] = cp.linspace(0, 1, overlap_pixels, endpoint=False)
265 if 'bottom' in edge_overlaps:
266 overlap_pixels = int(edge_overlaps['bottom'] * overlap_fraction)
267 if overlap_pixels > 0:
268 y_weight[-overlap_pixels:] = cp.linspace(1, 0, overlap_pixels, endpoint=False)
270 if 'left' in edge_overlaps:
271 overlap_pixels = int(edge_overlaps['left'] * overlap_fraction)
272 if overlap_pixels > 0:
273 x_weight[:overlap_pixels] = cp.linspace(0, 1, overlap_pixels, endpoint=False)
275 if 'right' in edge_overlaps:
276 overlap_pixels = int(edge_overlaps['right'] * overlap_fraction)
277 if overlap_pixels > 0:
278 x_weight[-overlap_pixels:] = cp.linspace(1, 0, overlap_pixels, endpoint=False)
280 # Use outer product (same as working CPU version)
281 mask = cp.outer(y_weight, x_weight)
282 return mask.astype(cp.float32)
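# Removed old complex function. NOTE(review): the previously referenced
# _create_simple_dynamic_mask_gpu does not exist in this module; dynamic masks
# are built by _create_dynamic_blend_mask_gpu / _create_batch_dynamic_masks_gpu.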
def _create_gaussian_blend_mask(tile_shape: tuple, blend_radius: float) -> "cp.ndarray":  # type: ignore
    """
    Backward-compatibility shim for older call sites.

    Forwards to ``_create_blend_mask`` with the "gaussian" blend method; new
    code should call ``_create_blend_mask`` directly.

    Args:
        tile_shape: Passed through unchanged as the first argument
        blend_radius: Passed through unchanged as the third argument

    Returns:
        Whatever ``_create_blend_mask`` returns for the "gaussian" method
    """
    # NOTE(review): _create_blend_mask is not defined in the portion of the
    # module reviewed here - confirm it still exists, otherwise any call to
    # this shim fails with NameError.
    gaussian_mask = _create_blend_mask(tile_shape, "gaussian", blend_radius)
    return gaussian_mask
def _log_assembly_inputs(image_tiles: "cp.ndarray", positions: "cp.ndarray") -> None:  # type: ignore
    """Emit DEBUG-level statistics about the tiles and positions being assembled."""
    # Every statistic below is a GPU reduction followed by a host sync, so
    # skip all of it unless debug logging is actually enabled.
    if not logger.isEnabledFor(logging.DEBUG):
        return

    xs, ys = positions[:, 0], positions[:, 1]
    logger.debug(
        "Assembly: received %d positions for %d tiles",
        positions.shape[0], image_tiles.shape[0],
    )
    logger.debug(
        "Position range: X=[%.1f, %.1f], Y=[%.1f, %.1f]",
        float(xs.min()), float(xs.max()), float(ys.min()), float(ys.max()),
    )
    logger.debug("First 3 positions: %s", positions[:3].tolist())
    logger.debug("Image tiles shape: %s, dtype: %s", image_tiles.shape, image_tiles.dtype)

    total_pixels = image_tiles.size
    if total_pixels == 0:
        # Reductions over zero-size tiles would raise; nothing to report.
        return

    for index in range(min(3, image_tiles.shape[0])):
        tile = image_tiles[index]
        logger.debug(
            "Tile %d: min=%.3f, max=%.3f, mean=%.3f, nonzero=%d",
            index, float(tile.min()), float(tile.max()), float(tile.mean()),
            int(cp.count_nonzero(tile)),
        )

    total_nonzero = int(cp.count_nonzero(image_tiles))
    logger.debug(
        "Total nonzero pixels: %d/%d (%.1f%%)",
        total_nonzero, total_pixels, 100 * total_nonzero / total_pixels,
    )


@special_inputs("positions")  # The input name is "positions"
@cupy_func
def assemble_stack_cupy(
    image_tiles: "cp.ndarray",  # type: ignore
    positions: Union[List[Tuple[float, float]], "cp.ndarray"],  # type: ignore
    blend_method: str = "fixed",
    fixed_margin_ratio: float = 0.1,
    overlap_blend_fraction: float = 1.0
) -> "cp.ndarray":  # type: ignore
    """
    Assemble equally sized tiles into one stitched image on the GPU.

    Each tile is shifted by the fractional part of its position, multiplied
    by its blend mask and accumulated onto a canvas; the accumulated image is
    then divided by the accumulated mask weights.

    Args:
        image_tiles: 3D CuPy array of tiles (N, H, W)
        positions: List of (x, y) tuples or 2D array [N, 2] giving the
            top-left corner of each tile
        blend_method: "none", "fixed", or "dynamic"
        fixed_margin_ratio: Ratio for fixed blending (e.g., 0.1 = 10%)
        overlap_blend_fraction: For dynamic mode, fraction of overlap to blend

    Returns:
        3D CuPy uint16 array (1, H_canvas, W_canvas) with the assembled
        image. An empty array of shape (1, 0, 0) is returned when there are
        no tiles or the computed canvas has a non-positive size.

    Raises:
        TypeError: If image_tiles is not a 3D CuPy array or positions has an
            unsupported type or shape.
        ValueError: If the number of tiles and positions differ, or
            blend_method is not one of the supported values.
    """
    # The compiler will ensure this function is only called when CuPy is available
    # --- 1. Validate and standardize inputs ---
    if not isinstance(image_tiles, cp.ndarray) or image_tiles.ndim != 3:
        raise TypeError("image_tiles must be a 3D CuPy ndarray of shape (N, H, W).")
    if image_tiles.shape[0] == 0:
        logger.warning("image_tiles array is empty (0 tiles). Returning an empty array.")
        # cp.array([[[]]]) would have shape (1, 1, 0); build (1, 0, 0) explicitly.
        return cp.zeros((1, 0, 0), dtype=cp.uint16)

    if isinstance(positions, list):
        if not positions or not isinstance(positions[0], tuple) or len(positions[0]) != 2:
            raise TypeError("positions must be a list of (x, y) tuples.")
        positions = cp.array(positions, dtype=cp.float32)
    else:
        # Array input (backward compatibility)
        if not hasattr(positions, 'ndim') or positions.ndim != 2 or positions.shape[1] != 2:
            raise TypeError("positions must be an array of shape [N, 2] or list of (x, y) tuples.")
        positions = cp.asarray(positions)

    if image_tiles.shape[0] != positions.shape[0]:
        raise ValueError(f"Mismatch between number of image_tiles ({image_tiles.shape[0]}) and positions ({positions.shape[0]}).")

    # Fail fast, before any canvas allocation or overlap detection.
    if blend_method not in ("none", "fixed", "dynamic"):
        raise ValueError(f"Unknown blend_method: {blend_method}")

    _log_assembly_inputs(image_tiles, positions)

    num_tiles, tile_h, tile_w = image_tiles.shape
    first_tile_shape = (tile_h, tile_w)  # All tiles share H, W (single 3D array)

    # --- 2. Compute canvas bounds ---
    # Positions are top-left corners; column 0 is X (width), column 1 is Y
    # (height). The bounds are pulled to the host once as plain ints so the
    # slicing arithmetic below does not trigger a device sync per comparison.
    xs = positions[:, 0]
    ys = positions[:, 1]
    canvas_min_x = int(cp.floor(cp.min(xs)))
    canvas_min_y = int(cp.floor(cp.min(ys)))
    canvas_max_x = int(cp.ceil(cp.max(xs) + tile_w))
    canvas_max_y = int(cp.ceil(cp.max(ys) + tile_h))

    canvas_width = canvas_max_x - canvas_min_x
    canvas_height = canvas_max_y - canvas_min_y

    logger.debug(
        "Canvas: %dx%d pixels, origin=(%d, %d); tile size: %dx%d pixels",
        canvas_width, canvas_height, canvas_min_x, canvas_min_y, tile_w, tile_h,
    )

    if canvas_width <= 0 or canvas_height <= 0:
        logger.warning(
            "Calculated canvas dimensions are non-positive (%dx%d). Check positions and tile sizes.",
            canvas_height, canvas_width,
        )
        # Keep the documented 3D return shape for the degenerate case too.
        return cp.zeros((1, 0, 0), dtype=cp.uint16)

    composite_accum = cp.zeros((canvas_height, canvas_width), dtype=cp.float32)
    weight_accum = cp.zeros((canvas_height, canvas_width), dtype=cp.float32)

    # --- 3. Generate blend masks ---
    if blend_method == "none":
        # The masks are only read, so one all-ones mask can serve every tile.
        uniform_mask = cp.ones(first_tile_shape, dtype=cp.float32)
        blend_masks = [uniform_mask] * num_tiles
    else:
        edge_pairs = _get_all_overlapping_pairs_gpu(positions, first_tile_shape)

        # Per tile: edge direction -> largest overlap seen on that edge
        tile_overlaps = [{} for _ in range(num_tiles)]
        for tile_i, _tile_j, edge_direction, pixel_overlap in edge_pairs:
            per_tile = tile_overlaps[tile_i]
            per_tile[edge_direction] = max(
                per_tile.get(edge_direction, pixel_overlap), pixel_overlap
            )

        if blend_method == "fixed":
            blend_masks = _create_batch_fixed_masks_gpu(
                first_tile_shape,
                tile_overlaps,
                margin_ratio=fixed_margin_ratio
            )
        else:  # "dynamic" (validated above)
            blend_masks = _create_batch_dynamic_masks_gpu(
                first_tile_shape,
                tile_overlaps,
                overlap_fraction=overlap_blend_fraction
            )

    # --- 3.5. Pre-calculate placement data for all tiles at once ---
    canvas_origin = cp.array([canvas_min_x, canvas_min_y], dtype=cp.float32)
    target_canvas_positions = cp.array(positions, dtype=cp.float32) - canvas_origin

    canvas_starts_int = cp.floor(target_canvas_positions).astype(cp.int32)  # (N, 2)
    fractional_parts = target_canvas_positions - canvas_starts_int  # (N, 2)

    # NOTE(review): the applied shift is the NEGATED fractional offset, as in
    # the original. With ndimage.shift a positive shift moves content toward
    # higher indices, so confirm this sign matches the position convention of
    # the stage that produces `positions`.
    # One device -> host transfer each instead of two .item() syncs per tile.
    starts_host = cp.asnumpy(canvas_starts_int)
    shifts_host = cp.asnumpy(-fractional_parts)

    # --- 4. Place tiles with subpixel shifts ---
    for i in range(num_tiles):
        x_start = int(starts_host[i, 0])
        y_start = int(starts_host[i, 1])
        x_end = x_start + tile_w
        y_end = y_start + tile_h

        # Part of the tile to take, in case it extends past the canvas
        src_x_start = max(0, -x_start)
        src_y_start = max(0, -y_start)
        src_x_end = tile_w - max(0, x_end - canvas_width)
        src_y_end = tile_h - max(0, y_end - canvas_height)

        # Entirely off-canvas: skip before doing any GPU work for this tile.
        if src_y_start >= src_y_end or src_x_start >= src_x_end:
            continue

        x_start = max(x_start, 0)
        y_start = max(y_start, 0)
        x_end = min(x_end, canvas_width)
        y_end = min(y_end, canvas_height)

        # Convert one tile at a time so the float32 copy of the whole stack
        # never has to exist alongside the input.
        tile_float = image_tiles[i].astype(cp.float32, copy=False)
        shifted_tile = subpixel_shift(
            tile_float,
            shift=(float(shifts_host[i, 1]), float(shifts_host[i, 0])),  # (y, x)
            order=1,
            mode='constant',
            cval=0.0,
        )

        mask = blend_masks[i]
        blended_tile = shifted_tile * mask

        composite_accum[y_start:y_end, x_start:x_end] += \
            blended_tile[src_y_start:src_y_end, src_x_start:src_x_end]
        weight_accum[y_start:y_end, x_start:x_end] += \
            mask[src_y_start:src_y_end, src_x_start:src_x_end]

    # --- 5. Normalize + cast ---
    epsilon = 1e-7  # To avoid division by zero where no tile contributed
    stitched_image_float = composite_accum / (weight_accum + epsilon)

    # Clip to 0-65535 and cast to uint16
    stitched_image_uint16 = cp.clip(stitched_image_float, 0, 65535).astype(cp.uint16)

    # Return as a 3D array with a single Z-slice
    return stitched_image_uint16.reshape(1, canvas_height, canvas_width)