Coverage for openhcs/processing/backends/enhance/cupy_clahe.py: 7.0%
296 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1from __future__ import annotations
3import logging
4from typing import Any, List, Optional, Tuple
6from openhcs.core.memory.decorators import cupy as cupy_func
7from openhcs.core.utils import optional_import
# Optional GPU stack: CuPy plus its SciPy-compatible ndimage submodule.
# optional_import presumably returns None when a module cannot be imported;
# the `is not None` guards below rely on that.
cp = optional_import("cupy")

ndimage = None
if cp is not None:
    cupyx_scipy = optional_import("cupyx.scipy")
    ndimage = None if cupyx_scipy is None else cupyx_scipy.ndimage

logger = logging.getLogger(__name__)
@cupy_func
def clahe_2d(
    image: "cp.ndarray",
    clip_limit: float = 2.0,
    tile_grid_size: tuple = None,
    nbins: int = None,
    adaptive_bins: bool = True,
    adaptive_tiles: bool = True
) -> "cp.ndarray":
    """
    Apply 2D CLAHE independently to every slice of an image stack.

    Args:
        image: CuPy array of shape (Z, Y, X); each Z slice is equalized on
            its own. A single (Y, X) image is also accepted and is processed
            as a one-slice stack.
        clip_limit: Contrast-limiting factor relative to a uniform histogram.
        tile_grid_size: (rows, cols) tile grid. When None it is derived from
            the slice size (~80 px tiles, clamped to 2..16 per axis) if
            ``adaptive_tiles`` is true, otherwise (8, 8).
        nbins: Histogram bin count. When None it is sqrt(slice value range)
            clamped to 64..512 if ``adaptive_bins`` is true, otherwise 256.
        adaptive_bins: See ``nbins``.
        adaptive_tiles: See ``tile_grid_size``.

    Returns:
        Array with the same shape and dtype as ``image``.

    Raises:
        ValueError: If ``image`` is neither 2D nor 3D, or a slice is too small
            for the tile grid.
    """
    import math

    if image.ndim not in (2, 3):
        raise ValueError(
            f"clahe_2d expects a 2D image or 3D stack, got {image.ndim}D input"
        )
    single_image = image.ndim == 2
    stack = image[None] if single_image else image
    height, width = stack.shape[1:]

    # The tile grid depends only on the slice shape, so resolve it once.
    if tile_grid_size is not None:
        tile_grid = tile_grid_size
    elif adaptive_tiles:
        target_tile_size = 80
        tile_grid = (
            max(2, min(16, height // target_tile_size)),
            max(2, min(16, width // target_tile_size)),
        )
    else:
        tile_grid = (8, 8)

    result = cp.zeros_like(stack)
    for z in range(stack.shape[0]):
        slice_2d = stack[z]

        # The adaptive bin count follows each slice's own value range.
        if nbins is not None:
            slice_nbins = nbins
        elif adaptive_bins:
            data_range = float(cp.max(slice_2d) - cp.min(slice_2d))
            slice_nbins = min(512, max(64, int(math.sqrt(data_range))))
        else:
            slice_nbins = 256

        result[z] = _clahe_2d_vectorized(slice_2d, clip_limit, tile_grid, slice_nbins)

    return result[0] if single_image else result
def _clahe_2d_vectorized(
    image: "cp.ndarray",
    clip_limit: float,
    tile_grid_size: tuple,
    nbins: int
) -> "cp.ndarray":
    """
    Vectorized CLAHE implementation for a single 2D image.

    Tile histograms are gathered from the largest top-left region that divides
    evenly into ``tile_grid_size``. The per-tile mappings are then applied to
    every pixel of the image. Rows and columns beyond that region are remapped
    through the nearest edge tiles rather than being replaced by copies of
    the last processed row or column.

    Args:
        image: 2D array.
        clip_limit: Contrast-limiting factor relative to a uniform histogram.
        tile_grid_size: (tile_rows, tile_cols).
        nbins: Number of histogram bins.

    Returns:
        Array with the same shape and dtype as ``image``.

    Raises:
        ValueError: If ``image`` is not 2D, the tile grid is not positive, or
            the image is smaller than the tile grid.
    """
    if image.ndim != 2:
        raise ValueError("Input must be 2D array")

    height, width = image.shape
    tile_rows, tile_cols = tile_grid_size
    if tile_rows < 1 or tile_cols < 1:
        raise ValueError(f"Tile grid must be positive, got {tile_rows}x{tile_cols}")

    tile_height = height // tile_rows
    tile_width = width // tile_cols
    if tile_height < 1 or tile_width < 1:
        raise ValueError(f"Image too small for {tile_rows}x{tile_cols} tiles")

    # Histograms are built only on the evenly-tiled region.
    crop_height = tile_height * tile_rows
    crop_width = tile_width * tile_cols
    image_crop = image[:crop_height, :crop_width]

    # Per-bin count limit for each tile histogram (never below one pixel).
    actual_clip_limit = max(1, int(clip_limit * tile_height * tile_width / nbins))

    # The bin range spans the whole image, so pixels outside the cropped
    # region also fall inside [min_val, max_val] when they are remapped.
    min_val = float(cp.min(image))
    max_val = float(cp.max(image))
    if max_val <= min_val:
        return image.astype(image.dtype)  # Constant image: nothing to equalize

    tile_cdfs = _compute_tile_cdfs_2d(
        image_crop, tile_rows, tile_cols, tile_height, tile_width,
        nbins, actual_clip_limit, min_val, max_val
    )

    # Interpolate over the FULL image. Pixels past the last tile centre get
    # clamped weights and therefore use the edge tiles' mappings.
    result = _apply_vectorized_interpolation_2d(
        image, tile_cdfs, tile_rows, tile_cols,
        tile_height, tile_width, nbins, min_val, max_val
    )

    if image.dtype.kind in "iu":
        # Round to nearest instead of truncating when casting back to int.
        result = cp.rint(result)
    return result.astype(image.dtype)
def _compute_tile_cdfs_2d(
    image: "cp.ndarray",
    tile_rows: int,
    tile_cols: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    clip_limit: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """Compute the clipped, intensity-scaled CDF of every tile.

    Args:
        image: 2D array; only the top-left ``tile_rows * tile_height`` by
            ``tile_cols * tile_width`` region is read.
        tile_rows, tile_cols: Number of tiles along y and x.
        tile_height, tile_width: Tile size in pixels.
        nbins: Number of equal-width histogram bins over [min_val, max_val].
        clip_limit: Per-bin count limit passed to ``_clip_histogram_optimized``.
        min_val, max_val: Value range covered by the bins.

    Returns:
        float32 array of shape (tile_rows, tile_cols, nbins). Entry [r, c, k]
        is the output intensity (in [min_val, max_val]) for bin k of tile
        (r, c).
    """
    tile_cdfs = cp.zeros((tile_rows, tile_cols, nbins), dtype=cp.float32)

    # float64 edges: a float32 top edge can round below max_val for float64
    # inputs, which would drop the brightest pixels from the (closed) last bin.
    bin_edges = cp.linspace(min_val, max_val, nbins + 1)
    value_span = max_val - min_val

    for row in range(tile_rows):
        y_start = row * tile_height
        for col in range(tile_cols):
            x_start = col * tile_width
            tile = image[y_start:y_start + tile_height, x_start:x_start + tile_width]

            hist, _ = cp.histogram(tile.ravel(), bins=bin_edges)
            hist = _clip_histogram_optimized(hist, clip_limit)

            # Normalize the CDF to [0, 1], then map it onto the value range.
            cdf = cp.cumsum(hist, dtype=cp.float32)
            total = cdf[-1]
            if total > 0:
                tile_cdfs[row, col, :] = min_val + (cdf / total) * value_span
            else:
                tile_cdfs[row, col, :] = min_val

    return tile_cdfs
def _apply_vectorized_interpolation_2d(
    image: "cp.ndarray",
    tile_cdfs: "cp.ndarray",
    tile_rows: int,
    tile_cols: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """Map every pixel through its neighbouring tile CDFs with bilinear blending.

    Each pixel is assigned the (up to) four tiles whose centres surround it.
    The tile mappings, looked up at the pixel's histogram bin, are blended
    with weights set by the pixel's position between those centres. Pixels
    outside the outermost centres are clamped to the nearest tiles.

    Args:
        image: 2D array of pixel values. It may extend past the tiled region;
            such pixels use the edge tiles.
        tile_cdfs: Array of shape (tile_rows, tile_cols, nbins) holding each
            tile's intensity mapping.
        tile_rows, tile_cols: Number of tiles along y and x (each >= 1).
        tile_height, tile_width: Tile size in pixels.
        nbins: Number of equal-width bins used to build ``tile_cdfs``.
        min_val, max_val: Value range spanned by those bins
            (``max_val > min_val``).

    Returns:
        Array with the same shape as ``image`` holding the remapped values.
    """
    height, width = image.shape

    def bracket(count, size, coords):
        # Return the pair of tile centres bracketing each coordinate, clamped
        # so both indices are valid (a single tile pairs with itself), plus
        # the clamped linear weight of the upper centre.
        centers = cp.arange(count, dtype=cp.float32) * size + size // 2
        low = cp.clip(cp.searchsorted(centers, coords) - 1, 0, max(count - 2, 0))
        high = cp.minimum(low + 1, count - 1)
        span = centers[high] - centers[low]
        safe_span = cp.where(span > 0, span, 1)  # avoid 0/0 before masking
        weight = cp.where(span > 0, (coords - centers[low]) / safe_span, 0.0)
        return low, high, cp.clip(weight, 0.0, 1.0)

    # 1D coordinate vectors shaped (H, 1) and (1, W) broadcast to (H, W)
    # without materialising full meshgrids.
    y_low, y_high, wy = bracket(
        tile_rows, tile_height, cp.arange(height, dtype=cp.float32)[:, None]
    )
    x_low, x_high, wx = bracket(
        tile_cols, tile_width, cp.arange(width, dtype=cp.float32)[None, :]
    )

    # Use the same equal-width bins cp.histogram used for the CDFs. Bin k
    # covers [min + k*w, min + (k+1)*w) with w = (max - min) / nbins, and the
    # top edge belongs to the last bin.
    scaled = (image - min_val) * (nbins / (max_val - min_val))
    bin_indices = cp.clip(cp.floor(scaled), 0, nbins - 1).astype(cp.int64)

    # Advanced indexing broadcasts (H,1), (1,W) and (H,W) to (H,W).
    val_tl = tile_cdfs[y_low, x_low, bin_indices]
    val_tr = tile_cdfs[y_low, x_high, bin_indices]
    val_bl = tile_cdfs[y_high, x_low, bin_indices]
    val_br = tile_cdfs[y_high, x_high, bin_indices]

    val_top = (1 - wx) * val_tl + wx * val_tr
    val_bottom = (1 - wx) * val_bl + wx * val_br
    return (1 - wy) * val_top + wy * val_bottom
def _clip_histogram_optimized(hist: "cp.ndarray", clip_limit: int) -> "cp.ndarray":
    """Cap histogram bins at ``clip_limit`` and spread the surplus over all bins.

    The surplus above the limit is first added uniformly to every bin. Up to
    three refinement passes then re-cap the bins and hand any spill-over to
    the bins that still have headroom, proportionally to that headroom.

    Args:
        hist: 1D histogram counts.
        clip_limit: Maximum count per bin. Non-positive values disable
            clipping, and ``hist`` is returned as-is (same object).

    Returns:
        Clipped histogram cast back to ``hist.dtype`` (fractional parts are
        truncated by the cast).
    """
    if clip_limit <= 0:
        return hist

    counts = hist.astype(cp.float32)
    surplus = cp.sum(cp.maximum(counts - clip_limit, 0))
    counts = cp.minimum(counts, clip_limit)

    if surplus > 0:
        counts += surplus / len(hist)

    for _attempt in range(3):
        spill = cp.sum(cp.maximum(counts - clip_limit, 0))
        if spill < 1e-6:
            break

        counts = cp.minimum(counts, clip_limit)
        below_limit = counts < clip_limit
        if not cp.any(below_limit):
            continue

        headroom = cp.maximum(clip_limit - counts, 0)
        total_headroom = cp.sum(headroom)
        if not total_headroom > 0:
            continue

        share = min(1.0, spill / total_headroom)
        counts += cp.where(below_limit, share * headroom, 0)

    return counts.astype(hist.dtype)
@cupy_func
def clahe_3d(
    stack: "cp.ndarray",
    clip_limit: float = 2.0,
    tile_grid_size_3d: tuple = None,
    nbins: int = None,
    adaptive_bins: bool = True,
    adaptive_tiles: bool = True,
    memory_efficient: bool = True
) -> "cp.ndarray":
    """
    Volumetric CLAHE using 3D tiles and trilinear blending of tile mappings.

    Args:
        stack: 3D CuPy array of shape (Z, Y, X).
        clip_limit: Contrast-limiting factor relative to a uniform histogram.
        tile_grid_size_3d: Number of tiles as (z_tiles, y_tiles, x_tiles).
            When None, it is derived from the volume size (~48 voxel tiles) if
            ``adaptive_tiles`` is true, otherwise (max(1, Z // 8), 4, 4).
        nbins: Histogram bin count. When None, it is cbrt(64 * value range)
            clamped to 128..512 if ``adaptive_bins`` is true, otherwise 256.
        adaptive_bins: See ``nbins``.
        adaptive_tiles: See ``tile_grid_size_3d``.
        memory_efficient: Process volumes larger than 512**3 voxels in Z chunks.

    Returns:
        Enhanced volume with the same shape and dtype as ``stack``.
    """
    depth, height, width = stack.shape

    if nbins is not None:
        bins_to_use = nbins
    elif not adaptive_bins:
        bins_to_use = 256
    else:
        value_range = float(cp.max(stack) - cp.min(stack))
        bins_to_use = min(512, max(128, int(cp.cbrt(value_range * 64))))

    if tile_grid_size_3d is not None:
        grid_3d = tile_grid_size_3d
    elif not adaptive_tiles:
        grid_3d = (max(1, depth // 8), 4, 4)
    else:
        target = 48
        grid_3d = (
            max(1, min(depth // 4, depth // target)),
            max(2, min(8, height // target)),
            max(2, min(8, width // target)),
        )

    # Above ~134M voxels, optionally fall back to Z-chunked processing.
    use_chunks = memory_efficient and depth * height * width > 512**3
    worker = _clahe_3d_chunked if use_chunks else _clahe_3d_vectorized
    return worker(stack, clip_limit, grid_3d, bins_to_use)
def _clahe_3d_vectorized(
    stack: "cp.ndarray",
    clip_limit: float,
    tile_grid_size_3d: tuple,
    nbins: int
) -> "cp.ndarray":
    """
    Full vectorized 3D CLAHE implementation.

    Tile histograms are gathered from the largest corner region that divides
    evenly into whole tiles. The per-tile mappings are then applied to every
    voxel of the stack, so voxels beyond that region are remapped through
    the nearest edge tiles instead of being overwritten by replicated output.

    Args:
        stack: 3D array of shape (Z, Y, X).
        clip_limit: Contrast-limiting factor relative to a uniform histogram.
        tile_grid_size_3d: Requested (z_tiles, y_tiles, x_tiles). Tiles are at
            least 1 slice deep and 4 pixels high/wide, so fewer tiles may be
            used.
        nbins: Number of histogram bins.

    Returns:
        Array with the same shape and dtype as ``stack``.

    Raises:
        ValueError: If the tile grid is not positive or the volume cannot
            hold at least one whole tile along every axis.
    """
    depth, height, width = stack.shape
    tile_z, tile_y, tile_x = tile_grid_size_3d
    if tile_z < 1 or tile_y < 1 or tile_x < 1:
        raise ValueError(f"Tile grid must be positive, got {tile_z}x{tile_y}x{tile_x}")

    tile_depth = max(1, depth // tile_z)
    tile_height = max(4, height // tile_y)
    tile_width = max(4, width // tile_x)

    # Tile counts actually realised with these tile sizes.
    actual_tile_z = depth // tile_depth
    actual_tile_y = height // tile_height
    actual_tile_x = width // tile_width
    # The 4-pixel minimum can exceed a small dimension, leaving no whole tile.
    if actual_tile_z < 1 or actual_tile_y < 1 or actual_tile_x < 1:
        raise ValueError(f"Volume too small for {tile_z}x{tile_y}x{tile_x} tiles")

    crop_depth = tile_depth * actual_tile_z
    crop_height = tile_height * actual_tile_y
    crop_width = tile_width * actual_tile_x
    stack_crop = stack[:crop_depth, :crop_height, :crop_width]

    # Per-bin count limit for each tile histogram (never below one voxel).
    voxels_per_tile = tile_depth * tile_height * tile_width
    actual_clip_limit = max(1, int(clip_limit * voxels_per_tile / nbins))

    # The bin range spans the whole stack so every remapped voxel is in range.
    min_val = float(cp.min(stack))
    max_val = float(cp.max(stack))
    if max_val <= min_val:
        return stack.astype(stack.dtype)  # Constant volume

    tile_cdfs = _compute_tile_cdfs_3d(
        stack_crop, actual_tile_z, actual_tile_y, actual_tile_x,
        tile_depth, tile_height, tile_width,
        nbins, actual_clip_limit, min_val, max_val
    )

    # Interpolate over the FULL stack. Clamped weights make voxels past the
    # last tile centres use the edge tiles.
    result = _apply_vectorized_trilinear_interpolation(
        stack, tile_cdfs, actual_tile_z, actual_tile_y, actual_tile_x,
        tile_depth, tile_height, tile_width, nbins, min_val, max_val
    )

    if stack.dtype.kind in "iu":
        # Round to nearest instead of truncating when casting back to int.
        result = cp.rint(result)
    return result.astype(stack.dtype)
def _compute_tile_cdfs_3d(
    stack: "cp.ndarray",
    tile_z: int,
    tile_y: int,
    tile_x: int,
    tile_depth: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    clip_limit: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """Compute the clipped, intensity-scaled CDF of every 3D tile.

    Args:
        stack: 3D array; only the corner region covered by
            ``tile_z x tile_y x tile_x`` tiles is read.
        tile_z, tile_y, tile_x: Number of tiles along each axis.
        tile_depth, tile_height, tile_width: Tile size in voxels.
        nbins: Number of equal-width histogram bins over [min_val, max_val].
        clip_limit: Per-bin count limit passed to ``_clip_histogram_optimized``.
        min_val, max_val: Value range covered by the bins.

    Returns:
        float32 array of shape (tile_z, tile_y, tile_x, nbins) holding each
        tile's output intensity per bin, in [min_val, max_val].
    """
    tile_cdfs = cp.zeros((tile_z, tile_y, tile_x, nbins), dtype=cp.float32)

    # float64 edges: a float32 top edge can round below max_val for float64
    # inputs, which would drop the brightest voxels from the (closed) last bin.
    bin_edges = cp.linspace(min_val, max_val, nbins + 1)
    value_span = max_val - min_val

    for z_idx in range(tile_z):
        z_start = z_idx * tile_depth
        for y_idx in range(tile_y):
            y_start = y_idx * tile_height
            for x_idx in range(tile_x):
                x_start = x_idx * tile_width
                tile_3d = stack[
                    z_start:z_start + tile_depth,
                    y_start:y_start + tile_height,
                    x_start:x_start + tile_width,
                ]

                hist, _ = cp.histogram(tile_3d.ravel(), bins=bin_edges)
                hist = _clip_histogram_optimized(hist, clip_limit)

                # Normalize the CDF to [0, 1], then map it onto the value range.
                cdf = cp.cumsum(hist, dtype=cp.float32)
                total = cdf[-1]
                if total > 0:
                    tile_cdfs[z_idx, y_idx, x_idx, :] = min_val + (cdf / total) * value_span
                else:
                    tile_cdfs[z_idx, y_idx, x_idx, :] = min_val

    return tile_cdfs
def _apply_vectorized_trilinear_interpolation(
    stack: "cp.ndarray",
    tile_cdfs: "cp.ndarray",
    tile_z: int,
    tile_y: int,
    tile_x: int,
    tile_depth: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """Map every voxel through its neighbouring tile CDFs with trilinear blending.

    Each voxel is assigned the (up to) eight tiles whose centres surround it.
    The tile mappings, looked up at the voxel's histogram bin, are blended
    along x, then y, then z. Voxels outside the outermost centres are clamped
    to the nearest tiles; an axis with a single tile uses that tile only.

    Args:
        stack: 3D array of shape (Z, Y, X). It may extend past the tiled
            region; such voxels use the edge tiles.
        tile_cdfs: Array of shape (tile_z, tile_y, tile_x, nbins) holding
            each tile's intensity mapping.
        tile_z, tile_y, tile_x: Number of tiles along each axis (each >= 1).
        tile_depth, tile_height, tile_width: Tile size in voxels.
        nbins: Number of equal-width bins used to build ``tile_cdfs``.
        min_val, max_val: Value range spanned by those bins
            (``max_val > min_val``).

    Returns:
        Array with the same shape as ``stack`` holding the remapped values.
    """
    depth, height, width = stack.shape

    def bracket(count, size, coords):
        # Return the pair of tile centres bracketing each coordinate, clamped
        # so both indices are valid (a single tile pairs with itself), plus
        # the clamped linear weight of the upper centre.
        centers = cp.arange(count, dtype=cp.float32) * size + size // 2
        low = cp.clip(cp.searchsorted(centers, coords) - 1, 0, max(count - 2, 0))
        high = cp.minimum(low + 1, count - 1)
        span = centers[high] - centers[low]
        safe_span = cp.where(span > 0, span, 1)  # avoid 0/0 before masking
        weight = cp.where(span > 0, (coords - centers[low]) / safe_span, 0.0)
        return low, high, cp.clip(weight, 0.0, 1.0)

    # Per-axis coordinate vectors shaped to broadcast to (Z, Y, X) instead of
    # materialising three full meshgrids plus a stacked copy.
    z_low, z_high, wz = bracket(
        tile_z, tile_depth, cp.arange(depth, dtype=cp.float32).reshape(-1, 1, 1)
    )
    y_low, y_high, wy = bracket(
        tile_y, tile_height, cp.arange(height, dtype=cp.float32).reshape(1, -1, 1)
    )
    x_low, x_high, wx = bracket(
        tile_x, tile_width, cp.arange(width, dtype=cp.float32).reshape(1, 1, -1)
    )

    # Use the same equal-width bins cp.histogram used for the CDFs. Bin k
    # covers [min + k*w, min + (k+1)*w) with w = (max - min) / nbins, and the
    # top edge belongs to the last bin.
    scaled = (stack - min_val) * (nbins / (max_val - min_val))
    bin_indices = cp.clip(cp.floor(scaled), 0, nbins - 1).astype(cp.int64)

    # The eight surrounding tile mappings (broadcast to (Z, Y, X)).
    val_000 = tile_cdfs[z_low, y_low, x_low, bin_indices]
    val_001 = tile_cdfs[z_low, y_low, x_high, bin_indices]
    val_010 = tile_cdfs[z_low, y_high, x_low, bin_indices]
    val_011 = tile_cdfs[z_low, y_high, x_high, bin_indices]
    val_100 = tile_cdfs[z_high, y_low, x_low, bin_indices]
    val_101 = tile_cdfs[z_high, y_low, x_high, bin_indices]
    val_110 = tile_cdfs[z_high, y_high, x_low, bin_indices]
    val_111 = tile_cdfs[z_high, y_high, x_high, bin_indices]

    # Blend along x, then y, then z.
    val_00 = (1 - wx) * val_000 + wx * val_001
    val_01 = (1 - wx) * val_010 + wx * val_011
    val_10 = (1 - wx) * val_100 + wx * val_101
    val_11 = (1 - wx) * val_110 + wx * val_111

    val_0 = (1 - wy) * val_00 + wy * val_01
    val_1 = (1 - wy) * val_10 + wy * val_11

    return (1 - wz) * val_0 + wz * val_1
def _clahe_3d_chunked(
    stack: "cp.ndarray",
    clip_limit: float,
    tile_grid_size_3d: tuple,
    nbins: int,
    chunk_size: int = 128
) -> "cp.ndarray":
    """
    Memory-efficient chunked processing for very large 3D volumes.

    The volume is split along Z into consecutive, non-overlapping output spans
    of ``chunk_size`` slices. Each span is processed together with up to
    ``overlap`` slices of context on either side, where ``overlap`` is half
    a Z tile, capped at ``chunk_size``. Only the span itself is written to
    the result.

    Args:
        stack: 3D array of shape (Z, Y, X).
        clip_limit: Contrast-limiting factor passed to ``_clahe_3d_vectorized``.
        tile_grid_size_3d: (z_tiles, y_tiles, x_tiles) for the whole volume;
            the Z tile count is reduced per chunk to keep the same tile depth.
        nbins: Number of histogram bins.
        chunk_size: Number of output slices produced per chunk.

    Returns:
        Array with the same shape and dtype as ``stack``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    depth = stack.shape[0]
    result = cp.zeros_like(stack)

    tile_z, tile_y, tile_x = tile_grid_size_3d
    tile_depth = max(1, depth // tile_z)
    # Context on each side smooths seams between chunks. The cap keeps each
    # processed chunk at most 3 * chunk_size slices deep.
    overlap = min(tile_depth // 2, chunk_size)

    # Stepping by chunk_size keeps the step positive for any overlap. The old
    # `chunk_size - overlap` step could be <= 0, which skipped the volume.
    for z_start in range(0, depth, chunk_size):
        z_end = min(z_start + chunk_size, depth)

        chunk_start = max(0, z_start - overlap)
        chunk_end = min(depth, z_end + overlap)
        chunk = stack[chunk_start:chunk_end, :, :]

        chunk_tile_z = max(1, min(tile_z, (chunk_end - chunk_start) // tile_depth))
        chunk_result = _clahe_3d_vectorized(
            chunk, clip_limit, (chunk_tile_z, tile_y, tile_x), nbins
        )

        # Keep only the span this iteration owns, discarding the context.
        offset = z_start - chunk_start
        result[z_start:z_end, :, :] = chunk_result[offset:offset + (z_end - z_start), :, :]

    return result
def _fill_3d_boundaries(
    full_result: "cp.ndarray",
    cropped_result: "cp.ndarray",
    crop_depth: int,
    crop_height: int,
    crop_width: int,
    depth: int,
    height: int,
    width: int
) -> None:
    """Pad ``full_result`` in place by repeating the edges of ``cropped_result``.

    ``full_result[:crop_depth, :crop_height, :crop_width]`` is assumed to
    already hold ``cropped_result``. Every region past the crop along any
    axis receives the last cropped slice/row/column (edge replication).
    """
    def axis_plan(crop: int, full: int):
        # (destination, source, is_padding) slice triples for one axis: the
        # cropped span copies itself; any remainder repeats the last index.
        plan = [(slice(None, crop), slice(None), False)]
        if crop < full:
            plan.append((slice(crop, None), slice(-1, None), True))
        return plan

    for z_dst, z_src, z_pad in axis_plan(crop_depth, depth):
        for y_dst, y_src, y_pad in axis_plan(crop_height, height):
            for x_dst, x_src, x_pad in axis_plan(crop_width, width):
                # The all-cropped region is already in place; fill the rest.
                if z_pad or y_pad or x_pad:
                    full_result[z_dst, y_dst, x_dst] = cropped_result[z_src, y_src, x_src]