Coverage for openhcs/processing/backends/processors/torch_processor.py: 9.9% (262 statements)
coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2PyTorch Image Processor Implementation
4This module implements the ImageProcessorInterface using PyTorch as the backend.
5It leverages GPU acceleration for image processing operations.
7Doctrinal Clauses:
8- Clause 3 — Declarative Primacy: All functions are pure and stateless
9- Clause 88 — No Inferred Capabilities: Explicit PyTorch dependency
10- Clause 106-A — Declared Memory Types: All methods specify PyTorch tensors
11"""
12from __future__ import annotations
14import logging
15from typing import Any, List, Optional, Tuple
17from openhcs.core.utils import optional_import
18from openhcs.core.memory.decorators import torch as torch_func
20# Import PyTorch as an optional dependency
21torch = optional_import("torch")
22F = optional_import("torch.nn.functional") if torch is not None else None
23HAS_TORCH = torch is not None
25logger = logging.getLogger(__name__)


def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "torch.Tensor":
    """
    Create a 2D weight mask that linearly ramps from 0 at the edges to 1 in the center.

    Args:
        height: Height of the mask
        width: Width of the mask
        margin_ratio: Ratio of the margin to the image size

    Returns:
        2D PyTorch weight mask of shape (height, width)
    """
    if torch is None:
        raise ImportError("PyTorch is required for TorchImageProcessor")

    margin_y = int(torch.floor(torch.tensor(height * margin_ratio)))
    margin_x = int(torch.floor(torch.tensor(width * margin_ratio)))

    weight_y = torch.ones(height, dtype=torch.float32)
    if margin_y > 0:
        ramp_top = torch.linspace(0, 1, margin_y, dtype=torch.float32)
        ramp_bottom = torch.linspace(1, 0, margin_y, dtype=torch.float32)
        weight_y[:margin_y] = ramp_top
        weight_y[-margin_y:] = ramp_bottom

    weight_x = torch.ones(width, dtype=torch.float32)
    if margin_x > 0:
        ramp_left = torch.linspace(0, 1, margin_x, dtype=torch.float32)
        ramp_right = torch.linspace(1, 0, margin_x, dtype=torch.float32)
        weight_x[:margin_x] = ramp_left
        weight_x[-margin_x:] = ramp_right

    # Create 2D weight mask using outer product
    weight_mask = torch.outer(weight_y, weight_x)

    return weight_mask
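

# Illustrative sketch (not part of the original module): with a 10% margin the
# mask ramps linearly from 0 at each border to 1 in the interior, and the outer
# product of the two 1D ramps yields the 2D blending weights.
#
#   mask = create_linear_weight_mask(512, 512, margin_ratio=0.1)  # (512, 512) float32
#   # mask[0, 0] == 0.0 at the corners, mask[256, 256] == 1.0 in the centre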


def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Validate that the input is a 3D PyTorch tensor.

    Args:
        array: Array to validate
        name: Name of the array for error messages

    Raises:
        TypeError: If the array is not a PyTorch tensor
        ValueError: If the array is not 3D
        ImportError: If PyTorch is not available
    """
    if torch is None:
        raise ImportError("PyTorch is required for TorchImageProcessor")

    if not isinstance(array, torch.Tensor):
        raise TypeError(f"{name} must be a PyTorch tensor, got {type(array)}. "
                        f"No automatic conversion is performed to maintain explicit contracts.")

    if array.ndim != 3:
        raise ValueError(f"{name} must be a 3D tensor, got {array.ndim}D")


def _gaussian_blur(image: "torch.Tensor", sigma: float) -> "torch.Tensor":
    """
    Apply Gaussian blur to a 2D image.

    Args:
        image: 2D PyTorch tensor of shape (H, W)
        sigma: Standard deviation of the Gaussian kernel

    Returns:
        Blurred 2D PyTorch tensor of shape (H, W)
    """
    # Calculate kernel size based on sigma
    kernel_size = max(3, int(2 * 4 * sigma + 1))
    if kernel_size % 2 == 0:
        kernel_size += 1  # Ensure odd kernel size

    # Create 1D Gaussian kernel
    coords = torch.arange(kernel_size, dtype=torch.float32, device=image.device)
    coords -= (kernel_size - 1) / 2

    # Calculate Gaussian values
    gauss = torch.exp(-(coords**2) / (2 * sigma**2))
    kernel = gauss / gauss.sum()

    # Reshape for 2D convolution
    kernel_x = kernel.view(1, 1, kernel_size, 1)
    kernel_y = kernel.view(1, 1, 1, kernel_size)

    # Add batch and channel dimensions to image
    img = image.unsqueeze(0).unsqueeze(0)

    # Apply separable convolution
    blurred = F.conv2d(img, kernel_x, padding=(kernel_size // 2, 0))
    blurred = F.conv2d(blurred, kernel_y, padding=(0, kernel_size // 2))

    # Remove batch and channel dimensions
    return blurred.squeeze(0).squeeze(0)
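

# Illustrative note (not part of the original module): the blur is separable,
# so one 1D Gaussian kernel is applied along rows and then along columns, which
# is equivalent to convolving with the full 2D kernel at lower cost.
#
#   slice_2d = torch.rand(512, 512)
#   blurred = _gaussian_blur(slice_2d, sigma=2.0)  # (512, 512) float32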


@torch_func
def sharpen(image: "torch.Tensor", radius: float = 1.0, amount: float = 1.0) -> "torch.Tensor":
    """
    Sharpen a 3D image using unsharp masking.

    This applies sharpening to each Z-slice independently.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        radius: Radius of Gaussian blur
        amount: Sharpening strength

    Returns:
        Sharpened 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Store original dtype
    dtype = image.dtype

    # Process each Z-slice independently
    result = torch.zeros_like(image, dtype=torch.float32)

    for z in range(image.shape[0]):
        # Convert to float and normalize by the slice maximum
        slice_float_raw = image[z].float()
        slice_float = slice_float_raw / torch.max(slice_float_raw)

        # Create blurred version for unsharp mask
        blurred = _gaussian_blur(slice_float, sigma=radius)

        # Apply unsharp mask: original + amount * (original - blurred)
        sharpened = slice_float + amount * (slice_float - blurred)

        # Clip to valid range
        sharpened = torch.clamp(sharpened, 0, 1.0)

        # Stretch to the full 16-bit range
        min_val = torch.min(sharpened)
        max_val = torch.max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535 / (max_val - min_val)

        result[z] = sharpened

    # Convert back to original dtype
    if dtype == torch.uint16:
        result = torch.clamp(result, 0, 65535).to(torch.uint16)
    else:
        result = result.to(dtype)

    return result
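

# Illustrative usage sketch (hypothetical values; assumes a plain call passes
# through the torch_func decorator unchanged):
#
#   stack = torch.rand(3, 512, 512, dtype=torch.float32) * 65535.0
#   sharpened = sharpen(stack, radius=1.5, amount=1.0)
#   # Each Z-slice is unsharp-masked independently, stretched to 0..65535,
#   # then cast back to the input dtype.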


@torch_func
def percentile_normalize(
    image: "torch.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "torch.Tensor":
    """
    Normalize a 3D image using percentile-based contrast stretching.

    This applies normalization to each Z-slice independently.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Process each Z-slice independently
    result = torch.zeros_like(image, dtype=torch.float32)

    for z in range(image.shape[0]):
        # Get percentile values for this slice
        # Handle large slices that exceed PyTorch's quantile() size limits
        slice_float = image[z].float()
        slice_elements = slice_float.numel()

        # PyTorch quantile() fails on very large tensors, so we use sampling for large slices
        max_elements_for_quantile = 10_000_000  # ~10M elements, conservative limit for quantile()

        logger.debug(f"🔥 QUANTILE DEBUG: percentile_normalize slice {z} shape {image[z].shape}, {slice_elements:,} elements")

        if slice_elements > max_elements_for_quantile:
            # Use random sampling for large slices to estimate percentiles
            sample_size = min(max_elements_for_quantile, slice_elements // 10)  # Sample 10% or max size
            flat_slice = slice_float.flatten()

            # Generate random indices for sampling (memory efficient)
            # Use torch.randint instead of torch.randperm to avoid creating huge tensors
            indices = torch.randint(0, slice_elements, (sample_size,), device=image.device)
            sampled_values = flat_slice[indices]

            p_low = torch.quantile(sampled_values, low_percentile / 100.0)
            p_high = torch.quantile(sampled_values, high_percentile / 100.0)
        else:
            # Use full slice for smaller slices
            p_low = torch.quantile(slice_float, low_percentile / 100.0)
            p_high = torch.quantile(slice_float, high_percentile / 100.0)

        # Avoid division by zero
        if p_high == p_low:
            result[z] = torch.ones_like(image[z], dtype=torch.float32) * target_min
            continue

        # Clip and normalize to target range
        clipped = torch.clamp(image[z].float(), p_low, p_high)
        scale = (target_max - target_min) / (p_high - p_low)
        normalized = (clipped - p_low) * scale + target_min
        result[z] = normalized

    # Convert to uint16
    result = torch.clamp(result, 0, 65535).to(torch.uint16)

    return result
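

# Illustrative usage sketch (hypothetical values): slices above roughly 10M
# elements are randomly sampled before torch.quantile is called, so on very
# large images the percentiles are estimates rather than exact values.
#
#   stack = torch.rand(5, 2048, 2048, dtype=torch.float32) * 65535.0
#   normalized = percentile_normalize(stack, low_percentile=1.0, high_percentile=99.0)
#   # Each slice is clipped to its own percentile window and returned as uint16.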


@torch_func
def stack_percentile_normalize(
    stack: "torch.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "torch.Tensor":
    """
    Normalize a stack using global percentile-based contrast stretching.

    This ensures consistent normalization across all Z-slices by computing
    global percentiles across the entire stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # Calculate global percentiles across the entire stack
    # Handle large tensors that exceed PyTorch's quantile() size limits
    stack_float = stack.float()
    total_elements = stack_float.numel()

    # PyTorch quantile() fails on very large tensors, so we use sampling for large stacks
    max_elements_for_quantile = 10_000_000  # ~10M elements, conservative limit for quantile()

    logger.debug(f"🔥 QUANTILE DEBUG: stack_percentile_normalize called with tensor shape {stack.shape}, {total_elements:,} elements")

    if total_elements > max_elements_for_quantile:
        # Use random sampling for large tensors to estimate percentiles
        sample_size = min(max_elements_for_quantile, total_elements // 10)  # Sample 10% or max size
        flat_stack = stack_float.flatten()

        # Generate random indices for sampling (memory efficient)
        # Use torch.randint instead of torch.randperm to avoid creating huge tensors
        indices = torch.randint(0, total_elements, (sample_size,), device=stack.device)
        sampled_values = flat_stack[indices]

        p_low = torch.quantile(sampled_values, low_percentile / 100.0)
        p_high = torch.quantile(sampled_values, high_percentile / 100.0)

        logger.debug(f"Used sampling ({sample_size:,} of {total_elements:,} elements) for percentile calculation due to large tensor size")
    else:
        # Use full tensor for smaller stacks
        p_low = torch.quantile(stack_float, low_percentile / 100.0)
        p_high = torch.quantile(stack_float, high_percentile / 100.0)

    # Avoid division by zero
    if p_high == p_low:
        return torch.ones_like(stack) * target_min

    # Clip and normalize to target range (match NumPy implementation exactly)
    clipped = torch.clamp(stack, p_low, p_high)
    normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min
    normalized = normalized.to(torch.uint16)

    return normalized
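

# Illustrative usage sketch (hypothetical values): unlike percentile_normalize,
# the low/high percentiles are computed once over the whole stack, so every
# Z-slice receives the same contrast stretch.
#
#   stack = torch.rand(9, 1024, 1024, dtype=torch.float32) * 65535.0
#   normalized = stack_percentile_normalize(stack)  # (9, 1024, 1024), uint16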


@torch_func
def create_composite(
    stack: "torch.Tensor", weights: Optional[List[float]] = None
) -> "torch.Tensor":
    """
    Create a composite image from a 3D stack of 2D images.

    Args:
        stack: 3D PyTorch tensor of shape (N, Y, X) where N is number of images
        weights: List of weights for each image. If None, equal weights are used.

    Returns:
        Composite 3D PyTorch tensor of shape (1, Y, X)
    """
    # Validate input is 3D tensor
    _validate_3d_array(stack)

    n_images, height, width = stack.shape

    # Default weights if none provided
    if weights is None:
        # Equal weights for all images
        weights = [1.0 / n_images] * n_images
    elif not isinstance(weights, list):
        raise TypeError("weights must be a list of values")

    # FAIL FAST: No fallback weights - weights must match exactly
    if len(weights) != n_images:
        raise ValueError(
            f"Weights list length ({len(weights)}) must exactly match number of images ({n_images}). "
            f"No automatic padding or truncation allowed."
        )

    dtype = stack.dtype
    device = stack.device

    # Create empty composite
    composite = torch.zeros((height, width), dtype=torch.float32, device=device)
    total_weight = 0.0

    # Add each image with its weight
    for i in range(n_images):
        weight = weights[i]
        if weight <= 0.0:
            continue

        # Add to composite
        composite += stack[i].float() * weight
        total_weight += weight

    # Normalize by total weight
    if total_weight > 0:
        composite /= total_weight

    # Convert back to original dtype (usually uint16)
    if dtype in [torch.uint8, torch.uint16, torch.uint32, torch.int8, torch.int16, torch.int32, torch.int64]:
        # Get the maximum value for the specific integer dtype
        if dtype == torch.uint8:
            max_val = 255
        elif dtype == torch.uint16:
            max_val = 65535
        elif dtype == torch.uint32:
            max_val = 4294967295
        elif dtype == torch.int8:
            max_val = 127
        elif dtype == torch.int16:
            max_val = 32767
        elif dtype == torch.int32:
            max_val = 2147483647
        elif dtype == torch.int64:
            max_val = 9223372036854775807

        composite = torch.clamp(composite, 0, max_val).to(dtype)
    else:
        composite = composite.to(dtype)

    # Return as 3D tensor with shape (1, Y, X)
    return composite.reshape(1, height, width)
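

# Illustrative usage sketch (hypothetical weights): three channel images are
# blended into a single weighted average; the weight list must match the number
# of slices exactly.
#
#   channels = torch.rand(3, 512, 512, dtype=torch.float32)
#   composite = create_composite(channels, weights=[0.6, 0.3, 0.1])
#   # composite has shape (1, 512, 512) in the input dtype.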


@torch_func
def apply_mask(image: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
    """
    Apply a mask to a 3D image while maintaining 3D structure.

    This applies the mask to each Z-slice independently if mask is 2D,
    or applies the 3D mask directly if mask is 3D.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        mask: 3D PyTorch tensor of shape (Z, Y, X) or 2D PyTorch tensor of shape (Y, X)

    Returns:
        Masked 3D PyTorch tensor of shape (Z, Y, X) - dimensionality preserved
    """
    _validate_3d_array(image)

    # Handle 2D mask (apply to each Z-slice)
    if isinstance(mask, torch.Tensor) and mask.ndim == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )

        # Apply 2D mask to each Z-slice
        result = torch.zeros_like(image)
        for z in range(image.shape[0]):
            result[z] = image[z].float() * mask.float()

        return result.to(image.dtype)

    # Handle 3D mask
    if isinstance(mask, torch.Tensor) and mask.ndim == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )

        # Apply 3D mask directly
        masked = image.float() * mask.float()
        return masked.to(image.dtype)

    # If we get here, the mask is neither a 2D nor a 3D PyTorch tensor
    raise TypeError(f"mask must be a 2D or 3D PyTorch tensor, got {type(mask)}")
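

# Illustrative usage sketch: a 2D mask is re-applied to every Z-slice, while a
# 3D mask must match the image shape exactly.
#
#   image = torch.rand(5, 512, 512, dtype=torch.float32)
#   mask = create_linear_weight_mask(512, 512)
#   masked = apply_mask(image, mask)  # (5, 512, 512), same dtype as image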


@torch_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "torch.Tensor":
    """
    Create a weight mask for blending images.

    Args:
        shape: Shape of the mask (height, width)
        margin_ratio: Ratio of image size to use as margin

    Returns:
        2D PyTorch weight mask of shape (Y, X)
    """
    if not isinstance(shape, tuple) or len(shape) != 2:
        raise TypeError("shape must be a tuple of (height, width)")

    height, width = shape
    return create_linear_weight_mask(height, width, margin_ratio)


@torch_func
def max_projection(stack: "torch.Tensor") -> "torch.Tensor":
    """
    Create a maximum intensity projection from a Z-stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)

    Returns:
        3D PyTorch tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    # Store original dtype for conversion back
    original_dtype = stack.dtype

    # Convert to float32 if needed for GPU operations
    if stack.dtype == torch.uint16:
        stack_float = stack.float()
    else:
        stack_float = stack

    # Create max projection
    projection_2d = torch.max(stack_float, dim=0)[0]

    # Convert back to original dtype
    projection_2d = projection_2d.to(original_dtype)

    return projection_2d.reshape(1, projection_2d.shape[0], projection_2d.shape[1])


@torch_func
def mean_projection(stack: "torch.Tensor") -> "torch.Tensor":
    """
    Create a mean intensity projection from a Z-stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)

    Returns:
        3D PyTorch tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    # Store original dtype for conversion back
    original_dtype = stack.dtype

    # Convert to float32 for mean calculation (always needed for mean)
    stack_float = stack.float()

    # Create mean projection
    projection_2d = torch.mean(stack_float, dim=0)

    # Convert back to original dtype
    projection_2d = projection_2d.to(original_dtype)

    return projection_2d.reshape(1, projection_2d.shape[0], projection_2d.shape[1])
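

# Illustrative usage sketch: both projections collapse the Z axis but keep the
# result 3D with a singleton leading dimension.
#
#   stack = torch.rand(7, 512, 512, dtype=torch.float32)
#   mip = max_projection(stack)   # (1, 512, 512)
#   avg = mean_projection(stack)  # (1, 512, 512)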


@torch_func
def stack_equalize_histogram(
    stack: "torch.Tensor",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "torch.Tensor":
    """
    Apply histogram equalization to an entire stack.

    This ensures consistent contrast enhancement across all Z-slices by
    computing a global histogram across the entire stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        bins: Number of bins for histogram computation
        range_min: Minimum value for histogram range
        range_max: Maximum value for histogram range

    Returns:
        Equalized 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # PyTorch doesn't have a direct histogram equalization function
    # We'll implement it manually using torch.histc for the histogram

    # Flatten the entire stack to compute the global histogram
    flat_stack = stack.float().flatten()

    # For very large stacks, use sampling to avoid memory issues
    max_elements_for_histogram = 50_000_000  # 50M elements limit
    if flat_stack.numel() > max_elements_for_histogram:
        # Use random sampling for histogram computation
        sample_size = max_elements_for_histogram
        indices = torch.randint(0, flat_stack.numel(), (sample_size,), device=stack.device)
        sampled_stack = flat_stack[indices]
        hist = torch.histc(sampled_stack, bins=bins, min=range_min, max=range_max)
        logger.debug(f"Used sampling ({sample_size:,} of {flat_stack.numel():,} elements) for histogram computation")
    else:
        # Use full stack for smaller stacks
        hist = torch.histc(flat_stack, bins=bins, min=range_min, max=range_max)

    # We don't need bin edges for the lookup table approach

    # Calculate cumulative distribution function (CDF)
    cdf = torch.cumsum(hist, dim=0)

    # Normalize the CDF to the range [0, 65535]
    # Avoid division by zero
    if cdf[-1] > 0:
        cdf = 65535 * cdf / cdf[-1]

    # PyTorch doesn't have a direct equivalent to numpy's interp
    # We'll use a lookup table approach

    # Scale input values to bin indices
    indices = torch.clamp(
        ((flat_stack - range_min) / (range_max - range_min) * (bins - 1)).long(),
        0, bins - 1
    )

    # Look up CDF values
    equalized_flat = torch.gather(cdf, 0, indices)

    # Reshape back to original shape
    equalized_stack = equalized_flat.reshape(stack.shape)

    # Convert to uint16
    return equalized_stack.to(torch.uint16)
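

# Illustrative usage sketch (hypothetical values): the global CDF acts as a
# lookup table, so every voxel is remapped with the same transfer function
# regardless of which Z-slice it belongs to.
#
#   stack = torch.rand(5, 1024, 1024, dtype=torch.float32) * 65535.0
#   equalized = stack_equalize_histogram(stack)  # (5, 1024, 1024), uint16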


@torch_func
def create_projection(
    stack: "torch.Tensor", method: str = "max_projection"
) -> "torch.Tensor":
    """
    Create a projection from a stack using the specified method.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        method: Projection method (max_projection, mean_projection)

    Returns:
        3D PyTorch tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    if method == "max_projection":
        return max_projection(stack)

    if method == "mean_projection":
        return mean_projection(stack)

    # FAIL FAST: No fallback projection methods
    raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")


@torch_func
def tophat(
    image: "torch.Tensor",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "torch.Tensor":
    """
    Apply white top-hat filter to a 3D image for background removal.

    This applies the filter to each Z-slice independently using PyTorch's
    native operations.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk
        downsample_factor: Factor by which to downsample the image for processing

    Returns:
        Filtered 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Store device for later use
    device = image.device

    # Process each Z-slice independently
    result = torch.zeros_like(image)

    # We'll create structuring elements for each slice as needed
    for z in range(image.shape[0]):
        # Store original data type
        input_dtype = image[z].dtype

        # 1) Downsample using PyTorch's interpolate function
        # First, add batch and channel dimensions for interpolate
        img_4d = image[z].float().unsqueeze(0).unsqueeze(0)

        # Calculate new dimensions
        new_h = image[z].shape[0] // downsample_factor
        new_w = image[z].shape[1] // downsample_factor

        # Resize using PyTorch's interpolate function
        image_small = F.interpolate(
            img_4d,
            size=(new_h, new_w),
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)

        # 2) Resize the structuring element to match the downsampled image
        small_selem_radius = max(1, selem_radius // downsample_factor)
        small_grid_size = 2 * small_selem_radius + 1
        small_grid_y, small_grid_x = torch.meshgrid(
            torch.arange(small_grid_size, device=device) - small_selem_radius,
            torch.arange(small_grid_size, device=device) - small_selem_radius,
            indexing='ij'
        )
        small_mask = (small_grid_x.pow(2) + small_grid_y.pow(2)) <= small_selem_radius**2
        small_selem = small_mask.float()

        # 3) Apply white top-hat using PyTorch's convolution operations
        # White top-hat is opening subtracted from the original image
        # Opening is erosion followed by dilation

        # Implement erosion using min pooling with custom kernel
        # First, pad the image to handle boundary conditions
        pad_size = small_selem_radius
        padded = F.pad(
            image_small.unsqueeze(0).unsqueeze(0),
            (pad_size, pad_size, pad_size, pad_size),
            mode='reflect'
        )

        # Unfold the padded image into patches
        patches = F.unfold(padded, kernel_size=small_grid_size, stride=1)

        # Reshape patches for processing
        patch_size = small_grid_size * small_grid_size
        patches = patches.reshape(1, patch_size, new_h, new_w)

        # Apply the structuring element as a mask
        masked_patches = patches * small_selem.reshape(-1, 1, 1)

        # Perform erosion (min pooling); positions outside the structuring
        # element are pushed to a large value so they never win the min.
        # Drop the leading batch dimension so the result is a 2D slice.
        eroded = torch.min(
            masked_patches + (1 - small_selem.reshape(-1, 1, 1)) * 1e9,
            dim=1
        )[0].squeeze(0)

        # Implement dilation using max pooling with custom kernel
        # Pad the eroded image
        padded_eroded = F.pad(
            eroded.unsqueeze(0).unsqueeze(0),
            (pad_size, pad_size, pad_size, pad_size),
            mode='reflect'
        )

        # Unfold the padded eroded image into patches
        patches_eroded = F.unfold(padded_eroded, kernel_size=small_grid_size, stride=1)

        # Reshape patches for processing
        patch_size = small_grid_size * small_grid_size
        patches_eroded = patches_eroded.reshape(1, patch_size, new_h, new_w)

        # Apply the structuring element as a mask
        masked_patches_eroded = patches_eroded * small_selem.reshape(-1, 1, 1)

        # Perform dilation (max pooling); drop the leading batch dimension again
        opened = torch.max(masked_patches_eroded, dim=1)[0].squeeze(0)

        # White top-hat is original minus opening
        tophat_small = image_small - opened

        # 4) Calculate background
        background_small = image_small - tophat_small

        # 5) Upscale background to original size
        background_4d = background_small.unsqueeze(0).unsqueeze(0)
        background_large = F.interpolate(
            background_4d,
            size=image[z].shape,
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)

        # 6) Subtract background and clip negative values
        slice_result = torch.clamp(image[z].float() - background_large, min=0.0)

        # 7) Convert back to original data type
        result[z] = slice_result.to(input_dtype)

    return result
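

# Illustrative usage sketch (hypothetical parameters): the white top-hat is
# computed on a downsampled copy of each slice, and the estimated background is
# upsampled and subtracted from the full-resolution slice.
#
#   stack = torch.rand(3, 2048, 2048, dtype=torch.float32) * 65535.0
#   filtered = tophat(stack, selem_radius=50, downsample_factor=4)
#   # filtered has the same shape and dtype as the input stack.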