Coverage for openhcs/processing/backends/processors/torch_processor.py: 10.5%
268 statements
1"""
2PyTorch Image Processor Implementation
4This module implements the ImageProcessorInterface using PyTorch as the backend.
5It leverages GPU acceleration for image processing operations.
7Doctrinal Clauses:
8- Clause 3 — Declarative Primacy: All functions are pure and stateless
9- Clause 88 — No Inferred Capabilities: Explicit PyTorch dependency
10- Clause 106-A — Declared Memory Types: All methods specify PyTorch tensors
11"""
12from __future__ import annotations
14import logging
15import os
16from typing import Any, List, Optional, Tuple
18from openhcs.core.utils import optional_import
19from openhcs.core.memory.decorators import torch as torch_func
21logger = logging.getLogger(__name__)
23# Check if we're in subprocess runner mode and should skip GPU imports
24if os.getenv('OPENHCS_SUBPROCESS_NO_GPU') == '1': 24 ↛ 26line 24 didn't jump to line 26 because the condition on line 24 was never true
25 # Subprocess runner mode - skip GPU imports
26 torch = None
27 F = None
28 HAS_TORCH = False
29 logger.info("Subprocess runner mode - skipping torch import")
30else:
31 # Normal mode - import PyTorch as an optional dependency
32 torch = optional_import("torch")
33 F = optional_import("torch.nn.functional") if torch is not None else None
34 HAS_TORCH = torch is not None
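
# Illustrative sketch (added for clarity, not part of the original module): the
# OPENHCS_SUBPROCESS_NO_GPU flag must be set by the launching process before this
# module is imported, for example:
#
#     import os
#     os.environ["OPENHCS_SUBPROCESS_NO_GPU"] = "1"  # set before the import below
#     from openhcs.processing.backends.processors import torch_processor
#     assert torch_processor.HAS_TORCH is False      # torch import was skipped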

def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "torch.Tensor":
    """
    Create a 2D weight mask that linearly ramps from 0 at the edges to 1 in the center.

    Args:
        height: Height of the mask
        width: Width of the mask
        margin_ratio: Ratio of the margin to the image size

    Returns:
        2D PyTorch weight mask of shape (height, width)
    """
    if torch is None:
        raise ImportError("PyTorch is required for TorchImageProcessor")

    margin_y = int(torch.floor(torch.tensor(height * margin_ratio)))
    margin_x = int(torch.floor(torch.tensor(width * margin_ratio)))

    weight_y = torch.ones(height, dtype=torch.float32)
    if margin_y > 0:
        ramp_top = torch.linspace(0, 1, margin_y, dtype=torch.float32)
        ramp_bottom = torch.linspace(1, 0, margin_y, dtype=torch.float32)
        weight_y[:margin_y] = ramp_top
        weight_y[-margin_y:] = ramp_bottom

    weight_x = torch.ones(width, dtype=torch.float32)
    if margin_x > 0:
        ramp_left = torch.linspace(0, 1, margin_x, dtype=torch.float32)
        ramp_right = torch.linspace(1, 0, margin_x, dtype=torch.float32)
        weight_x[:margin_x] = ramp_left
        weight_x[-margin_x:] = ramp_right

    # Create 2D weight mask using outer product
    weight_mask = torch.outer(weight_y, weight_x)

    return weight_mask
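
# Illustrative usage sketch (added for clarity, not part of the original module):
# build a small blending mask and check its edge/centre behaviour. Assumes torch
# is installed; create_linear_weight_mask is undecorated, so it can be called directly.
def _example_linear_weight_mask():
    mask = create_linear_weight_mask(height=8, width=8, margin_ratio=0.25)
    # Edges ramp up from 0 toward 1; the interior is exactly 1.
    assert mask.shape == (8, 8)
    assert mask[0, 0] == 0.0 and mask[4, 4] == 1.0
    return mask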

def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Validate that the input is a 3D PyTorch tensor.

    Args:
        array: Array to validate
        name: Name of the array for error messages

    Raises:
        TypeError: If the array is not a PyTorch tensor
        ValueError: If the array is not 3D
        ImportError: If PyTorch is not available
    """
    if torch is None:
        raise ImportError("PyTorch is required for TorchImageProcessor")

    if not isinstance(array, torch.Tensor):
        raise TypeError(f"{name} must be a PyTorch tensor, got {type(array)}. "
                        f"No automatic conversion is performed to maintain explicit contracts.")

    if array.ndim != 3:
        raise ValueError(f"{name} must be a 3D tensor, got {array.ndim}D")

def _gaussian_blur(image: "torch.Tensor", sigma: float) -> "torch.Tensor":
    """
    Apply Gaussian blur to a 2D image.

    Args:
        image: 2D PyTorch tensor of shape (H, W)
        sigma: Standard deviation of the Gaussian kernel

    Returns:
        Blurred 2D PyTorch tensor of shape (H, W)
    """
    # Calculate kernel size based on sigma
    kernel_size = max(3, int(2 * 4 * sigma + 1))
    if kernel_size % 2 == 0:
        kernel_size += 1  # Ensure odd kernel size

    # Create 1D Gaussian kernel
    coords = torch.arange(kernel_size, dtype=torch.float32, device=image.device)
    coords -= (kernel_size - 1) / 2

    # Calculate Gaussian values
    gauss = torch.exp(-(coords**2) / (2 * sigma**2))
    kernel = gauss / gauss.sum()

    # Reshape for 2D convolution
    kernel_x = kernel.view(1, 1, kernel_size, 1)
    kernel_y = kernel.view(1, 1, 1, kernel_size)

    # Add batch and channel dimensions to image
    img = image.unsqueeze(0).unsqueeze(0)

    # Apply separable convolution
    blurred = F.conv2d(img, kernel_x, padding=(kernel_size//2, 0))
    blurred = F.conv2d(blurred, kernel_y, padding=(0, kernel_size//2))

    # Remove batch and channel dimensions
    return blurred.squeeze(0).squeeze(0)

@torch_func
def sharpen(image: "torch.Tensor", radius: float = 1.0, amount: float = 1.0) -> "torch.Tensor":
    """
    Sharpen a 3D image using unsharp masking.

    This applies sharpening to each Z-slice independently.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        radius: Radius of Gaussian blur
        amount: Sharpening strength

    Returns:
        Sharpened 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Store original dtype
    dtype = image.dtype

    # Process each Z-slice independently
    result = torch.zeros_like(image, dtype=torch.float32)

    for z in range(image.shape[0]):
        # Convert to float for processing
        slice_float_raw = image[z].float()
        slice_float = slice_float_raw / torch.max(slice_float_raw)

        # Create blurred version for unsharp mask
        blurred = _gaussian_blur(slice_float, sigma=radius)

        # Apply unsharp mask: original + amount * (original - blurred)
        sharpened = slice_float + amount * (slice_float - blurred)

        # Clip to valid range
        sharpened = torch.clamp(sharpened, 0, 1.0)

        # Scale back to original range
        min_val = torch.min(sharpened)
        max_val = torch.max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535 / (max_val - min_val)

        result[z] = sharpened

    # Convert back to original dtype
    if dtype == torch.uint16:
        result = torch.clamp(result, 0, 65535).to(torch.uint16)
    else:
        result = result.to(dtype)

    return result
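
# Illustrative usage sketch (added for clarity, not part of the original module):
# sharpen a small synthetic Z-stack. Assumes the @torch_func decorator permits
# direct calls with an in-memory torch tensor.
def _example_sharpen():
    stack = torch.rand(3, 64, 64, dtype=torch.float32)  # (Z, Y, X)
    sharpened = sharpen(stack, radius=1.0, amount=1.5)
    assert sharpened.shape == stack.shape
    return sharpened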

@torch_func
def percentile_normalize(
    image: "torch.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "torch.Tensor":
    """
    Normalize a 3D image using percentile-based contrast stretching.

    This applies normalization to each Z-slice independently.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Process each Z-slice independently
    result = torch.zeros_like(image, dtype=torch.float32)

    for z in range(image.shape[0]):
        # Get percentile values for this slice
        # Handle large slices that exceed PyTorch's quantile() size limits
        slice_float = image[z].float()
        slice_elements = slice_float.numel()

        # PyTorch quantile() fails on very large tensors, so we use sampling for large slices
        max_elements_for_quantile = 10_000_000  # ~10M elements, conservative limit for quantile()

        logger.debug(f"🔥 QUANTILE DEBUG: percentile_normalize slice {z} shape {image[z].shape}, {slice_elements:,} elements")

        if slice_elements > max_elements_for_quantile:
            # Use random sampling for large slices to estimate percentiles
            sample_size = min(max_elements_for_quantile, slice_elements // 10)  # Sample 10% or max size
            flat_slice = slice_float.flatten()

            # Generate random indices for sampling (memory efficient)
            # Use torch.randint instead of torch.randperm to avoid creating huge tensors
            indices = torch.randint(0, slice_elements, (sample_size,), device=image.device)
            sampled_values = flat_slice[indices]

            p_low = torch.quantile(sampled_values, low_percentile / 100.0)
            p_high = torch.quantile(sampled_values, high_percentile / 100.0)
        else:
            # Use full slice for smaller slices
            p_low = torch.quantile(slice_float, low_percentile / 100.0)
            p_high = torch.quantile(slice_float, high_percentile / 100.0)

        # Avoid division by zero
        if p_high == p_low:
            result[z] = torch.ones_like(image[z], dtype=torch.float32) * target_min
            continue

        # Clip and normalize to target range
        clipped = torch.clamp(image[z].float(), p_low, p_high)
        scale = (target_max - target_min) / (p_high - p_low)
        normalized = (clipped - p_low) * scale + target_min
        result[z] = normalized

    # Convert to uint16
    result = torch.clamp(result, 0, 65535).to(torch.uint16)

    return result
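
# Illustrative usage sketch (added for clarity, not part of the original module):
# per-slice percentile normalization of a low-range stack. Assumes a PyTorch
# version with torch.uint16 support and that @torch_func allows direct calls.
def _example_percentile_normalize():
    stack = torch.rand(2, 128, 128) * 4096  # values well below the uint16 ceiling
    normalized = percentile_normalize(stack, low_percentile=1.0, high_percentile=99.0)
    # Each slice is stretched independently into 0-65535 and cast to uint16.
    assert normalized.dtype == torch.uint16
    return normalized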

@torch_func
def stack_percentile_normalize(
    stack: "torch.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "torch.Tensor":
    """
    Normalize a stack using global percentile-based contrast stretching.

    This ensures consistent normalization across all Z-slices by computing
    global percentiles across the entire stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # Calculate global percentiles across the entire stack
    # Handle large tensors that exceed PyTorch's quantile() size limits
    stack_float = stack.float()
    total_elements = stack_float.numel()

    # PyTorch quantile() fails on very large tensors, so we use sampling for large stacks
    max_elements_for_quantile = 10_000_000  # ~10M elements, conservative limit for quantile()

    logger.debug(f"🔥 QUANTILE DEBUG: stack_percentile_normalize called with tensor shape {stack.shape}, {total_elements:,} elements")

    if total_elements > max_elements_for_quantile:
        # Use random sampling for large tensors to estimate percentiles
        sample_size = min(max_elements_for_quantile, total_elements // 10)  # Sample 10% or max size
        flat_stack = stack_float.flatten()

        # Generate random indices for sampling (memory efficient)
        # Use torch.randint instead of torch.randperm to avoid creating huge tensors
        indices = torch.randint(0, total_elements, (sample_size,), device=stack.device)
        sampled_values = flat_stack[indices]

        p_low = torch.quantile(sampled_values, low_percentile / 100.0)
        p_high = torch.quantile(sampled_values, high_percentile / 100.0)

        logger.debug(f"Used sampling ({sample_size:,} of {total_elements:,} elements) for percentile calculation due to large tensor size")
    else:
        # Use full tensor for smaller stacks
        p_low = torch.quantile(stack_float, low_percentile / 100.0)
        p_high = torch.quantile(stack_float, high_percentile / 100.0)

    # Avoid division by zero
    if p_high == p_low:
        return torch.ones_like(stack) * target_min

    # Clip and normalize to target range (match NumPy implementation exactly)
    # Clamp the float view so integer stacks are not clipped against float bounds
    clipped = torch.clamp(stack_float, p_low, p_high)
    normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min
    normalized = normalized.to(torch.uint16)

    return normalized
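
# Illustrative sketch (added for clarity, not part of the original module):
# unlike percentile_normalize above, the stack-wise variant preserves relative
# brightness between Z-slices because percentiles are computed globally.
def _example_stack_percentile_normalize():
    stack = torch.rand(4, 64, 64) * 1000
    stack[0] *= 10  # one slice is much brighter than the others
    normalized = stack_percentile_normalize(stack)
    # The bright slice remains brighter after global normalization.
    assert normalized[0].float().mean() > normalized[1].float().mean()
    return normalized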

@torch_func
def create_composite(
    stack: "torch.Tensor", weights: Optional[List[float]] = None
) -> "torch.Tensor":
    """
    Create a composite image from a 3D stack of 2D images.

    Args:
        stack: 3D PyTorch tensor of shape (N, Y, X) where N is number of images
        weights: List of weights for each image. If None, equal weights are used.

    Returns:
        Composite 3D PyTorch tensor of shape (1, Y, X)
    """
    # Validate input is 3D tensor
    _validate_3d_array(stack)

    n_images, height, width = stack.shape

    # Default weights if none provided
    if weights is None:
        # Equal weights for all images
        weights = [1.0 / n_images] * n_images
    elif not isinstance(weights, list):
        raise TypeError("weights must be a list of values")

    # FAIL FAST: No fallback weights - weights must match exactly
    if len(weights) != n_images:
        raise ValueError(
            f"Weights list length ({len(weights)}) must exactly match number of images ({n_images}). "
            f"No automatic padding or truncation allowed."
        )

    dtype = stack.dtype
    device = stack.device

    # Create empty composite
    composite = torch.zeros((height, width), dtype=torch.float32, device=device)
    total_weight = 0.0

    # Add each image with its weight
    for i in range(n_images):
        weight = weights[i]
        if weight <= 0.0:
            continue

        # Add to composite
        composite += stack[i].float() * weight
        total_weight += weight

    # Normalize by total weight
    if total_weight > 0:
        composite /= total_weight

    # Convert back to original dtype (usually uint16)
    if dtype in [torch.uint8, torch.uint16, torch.uint32, torch.int8, torch.int16, torch.int32, torch.int64]:
        # Get the maximum value for the specific integer dtype
        if dtype == torch.uint8:
            max_val = 255
        elif dtype == torch.uint16:
            max_val = 65535
        elif dtype == torch.uint32:
            max_val = 4294967295
        elif dtype == torch.int8:
            max_val = 127
        elif dtype == torch.int16:
            max_val = 32767
        elif dtype == torch.int32:
            max_val = 2147483647
        elif dtype == torch.int64:
            max_val = 9223372036854775807

        composite = torch.clamp(composite, 0, max_val).to(dtype)
    else:
        composite = composite.to(dtype)

    # Return as 3D tensor with shape (1, Y, X)
    return composite.reshape(1, height, width)
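
# Illustrative usage sketch (added for clarity, not part of the original module):
# blend two single-channel images into one composite with explicit weights.
def _example_create_composite():
    channels = torch.rand(2, 32, 32)  # two images stacked along the first axis
    composite = create_composite(channels, weights=[0.7, 0.3])
    assert composite.shape == (1, 32, 32)
    return composite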

@torch_func
def apply_mask(image: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
    """
    Apply a mask to a 3D image while maintaining 3D structure.

    This applies the mask to each Z-slice independently if mask is 2D,
    or applies the 3D mask directly if mask is 3D.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        mask: 3D PyTorch tensor of shape (Z, Y, X) or 2D PyTorch tensor of shape (Y, X)

    Returns:
        Masked 3D PyTorch tensor of shape (Z, Y, X) - dimensionality preserved
    """
    _validate_3d_array(image)

    # Handle 2D mask (apply to each Z-slice)
    if isinstance(mask, torch.Tensor) and mask.ndim == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )

        # Apply 2D mask to each Z-slice
        result = torch.zeros_like(image)
        for z in range(image.shape[0]):
            result[z] = image[z].float() * mask.float()

        return result.to(image.dtype)

    # Handle 3D mask
    if isinstance(mask, torch.Tensor) and mask.ndim == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )

        # Apply 3D mask directly
        masked = image.float() * mask.float()
        return masked.to(image.dtype)

    # If we get here, the mask is neither a 2D nor a 3D PyTorch tensor
    raise TypeError(f"mask must be a 2D or 3D PyTorch tensor, got {type(mask)}")

@torch_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "torch.Tensor":
    """
    Create a weight mask for blending images.

    Args:
        shape: Shape of the mask (height, width)
        margin_ratio: Ratio of image size to use as margin

    Returns:
        2D PyTorch weight mask of shape (Y, X)
    """
    if not isinstance(shape, tuple) or len(shape) != 2:
        raise TypeError("shape must be a tuple of (height, width)")

    height, width = shape
    return create_linear_weight_mask(height, width, margin_ratio)
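
# Illustrative usage sketch (added for clarity, not part of the original module):
# combine create_weight_mask with apply_mask to feather the edges of every Z-slice,
# e.g. before stitching overlapping tiles.
def _example_edge_feathering():
    stack = torch.rand(3, 64, 64)
    mask = create_weight_mask((64, 64), margin_ratio=0.1)  # 2D (Y, X) mask
    feathered = apply_mask(stack, mask)  # the 2D mask is applied to each slice
    assert feathered.shape == stack.shape
    return feathered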

@torch_func
def max_projection(stack: "torch.Tensor") -> "torch.Tensor":
    """
    Create a maximum intensity projection from a Z-stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)

    Returns:
        3D PyTorch tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    # Store original dtype for conversion back
    original_dtype = stack.dtype

    # Convert to float32 if needed for GPU operations
    if stack.dtype == torch.uint16:
        stack_float = stack.float()
    else:
        stack_float = stack

    # Create max projection
    projection_2d = torch.max(stack_float, dim=0)[0]

    # Convert back to original dtype
    projection_2d = projection_2d.to(original_dtype)

    return projection_2d.reshape(1, projection_2d.shape[0], projection_2d.shape[1])

@torch_func
def mean_projection(stack: "torch.Tensor") -> "torch.Tensor":
    """
    Create a mean intensity projection from a Z-stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)

    Returns:
        3D PyTorch tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    # Store original dtype for conversion back
    original_dtype = stack.dtype

    # Convert to float32 for mean calculation (always needed for mean)
    stack_float = stack.float()

    # Create mean projection
    projection_2d = torch.mean(stack_float, dim=0)

    # Convert back to original dtype
    projection_2d = projection_2d.to(original_dtype)

    return projection_2d.reshape(1, projection_2d.shape[0], projection_2d.shape[1])

@torch_func
def stack_equalize_histogram(
    stack: "torch.Tensor",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "torch.Tensor":
    """
    Apply histogram equalization to an entire stack.

    This ensures consistent contrast enhancement across all Z-slices by
    computing a global histogram across the entire stack.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        bins: Number of bins for histogram computation
        range_min: Minimum value for histogram range
        range_max: Maximum value for histogram range

    Returns:
        Equalized 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # PyTorch doesn't have a direct histogram equalization function
    # We'll implement it manually using torch.histc for the histogram

    # Flatten the entire stack to compute the global histogram
    flat_stack = stack.float().flatten()

    # For very large stacks, use sampling to avoid memory issues
    max_elements_for_histogram = 50_000_000  # 50M elements limit
    if flat_stack.numel() > max_elements_for_histogram:
        # Use random sampling for histogram computation
        sample_size = max_elements_for_histogram
        indices = torch.randint(0, flat_stack.numel(), (sample_size,), device=stack.device)
        sampled_stack = flat_stack[indices]
        hist = torch.histc(sampled_stack, bins=bins, min=range_min, max=range_max)
        logger.debug(f"Used sampling ({sample_size:,} of {flat_stack.numel():,} elements) for histogram computation")
    else:
        # Use full stack for smaller stacks
        hist = torch.histc(flat_stack, bins=bins, min=range_min, max=range_max)

    # We don't need bin edges for the lookup table approach

    # Calculate cumulative distribution function (CDF)
    cdf = torch.cumsum(hist, dim=0)

    # Normalize the CDF to the range [0, 65535]
    # Avoid division by zero
    if cdf[-1] > 0:
        cdf = 65535 * cdf / cdf[-1]

    # PyTorch doesn't have a direct equivalent to numpy's interp
    # We'll use a lookup table approach

    # Scale input values to bin indices
    indices = torch.clamp(
        ((flat_stack - range_min) / (range_max - range_min) * (bins - 1)).long(),
        0, bins - 1
    )

    # Look up CDF values
    equalized_flat = torch.gather(cdf, 0, indices)

    # Reshape back to original shape
    equalized_stack = equalized_flat.reshape(stack.shape)

    # Convert to uint16
    return equalized_stack.to(torch.uint16)
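
# Illustrative usage sketch (added for clarity, not part of the original module):
# equalize a low-contrast stack that occupies only a narrow band of the 16-bit
# range. Assumes a PyTorch version with torch.uint16 support.
def _example_stack_equalize_histogram():
    stack = torch.rand(2, 64, 64) * 500 + 1000  # narrow intensity band
    equalized = stack_equalize_histogram(stack)
    # The CDF lookup spreads the values across the full 0-65535 range.
    assert equalized.dtype == torch.uint16
    return equalized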

@torch_func
def create_projection(
    stack: "torch.Tensor", method: str = "max_projection"
) -> "torch.Tensor":
    """
    Create a projection from a stack using the specified method.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        method: Projection method (max_projection, mean_projection)

    Returns:
        3D PyTorch tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    if method == "max_projection":
        return max_projection(stack)

    if method == "mean_projection":
        return mean_projection(stack)

    # FAIL FAST: No fallback projection methods
    raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")
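
# Illustrative usage sketch (added for clarity, not part of the original module):
# dispatch to a max- or mean-intensity projection by name. Assumes the decorated
# projection helpers can be called directly.
def _example_create_projection():
    stack = torch.rand(5, 32, 32) * 65535
    mip = create_projection(stack, method="max_projection")
    mean_proj = create_projection(stack, method="mean_projection")
    assert mip.shape == (1, 32, 32) and mean_proj.shape == (1, 32, 32)
    return mip, mean_proj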

@torch_func
def tophat(
    image: "torch.Tensor",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "torch.Tensor":
    """
    Apply white top-hat filter to a 3D image for background removal.

    This applies the filter to each Z-slice independently using PyTorch's
    native operations.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk
        downsample_factor: Factor by which to downsample the image for processing

    Returns:
        Filtered 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Store device for later use
    device = image.device

    # Process each Z-slice independently
    result = torch.zeros_like(image)

    # We'll create structuring elements for each slice as needed

    for z in range(image.shape[0]):
        # Store original data type
        input_dtype = image[z].dtype

        # 1) Downsample using PyTorch's interpolate function
        # First, add batch and channel dimensions for interpolate
        img_4d = image[z].float().unsqueeze(0).unsqueeze(0)

        # Calculate new dimensions
        new_h = image[z].shape[0] // downsample_factor
        new_w = image[z].shape[1] // downsample_factor

        # Resize using PyTorch's interpolate function
        image_small = F.interpolate(
            img_4d,
            size=(new_h, new_w),
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)

        # 2) Resize the structuring element to match the downsampled image
        small_selem_radius = max(1, selem_radius // downsample_factor)
        small_grid_size = 2 * small_selem_radius + 1
        small_grid_y, small_grid_x = torch.meshgrid(
            torch.arange(small_grid_size, device=device) - small_selem_radius,
            torch.arange(small_grid_size, device=device) - small_selem_radius,
            indexing='ij'
        )
        small_mask = (small_grid_x.pow(2) + small_grid_y.pow(2)) <= small_selem_radius**2
        small_selem = small_mask.float()

        # 3) Apply white top-hat using PyTorch's convolution operations
        # White top-hat is opening subtracted from the original image
        # Opening is erosion followed by dilation

        # Implement erosion using min pooling with custom kernel
        # First, pad the image to handle boundary conditions
        pad_size = small_selem_radius
        padded = F.pad(
            image_small.unsqueeze(0).unsqueeze(0),
            (pad_size, pad_size, pad_size, pad_size),
            mode='reflect'
        )

        # Unfold the padded image into patches
        patches = F.unfold(padded, kernel_size=small_grid_size, stride=1)

        # Reshape patches for processing
        patch_size = small_grid_size * small_grid_size
        patches = patches.reshape(1, patch_size, new_h, new_w)

        # Apply the structuring element as a mask
        masked_patches = patches * small_selem.reshape(-1, 1, 1)

        # Perform erosion (min pooling)
        # Squeeze the leading batch dimension so the eroded image is 2D (new_h, new_w);
        # F.pad/F.unfold below expect a 4D tensor after the two unsqueeze calls
        eroded = torch.min(
            masked_patches + (1 - small_selem.reshape(-1, 1, 1)) * 1e9,
            dim=1
        )[0].squeeze(0)

        # Implement dilation using max pooling with custom kernel
        # Pad the eroded image
        padded_eroded = F.pad(
            eroded.unsqueeze(0).unsqueeze(0),
            (pad_size, pad_size, pad_size, pad_size),
            mode='reflect'
        )

        # Unfold the padded eroded image into patches
        patches_eroded = F.unfold(padded_eroded, kernel_size=small_grid_size, stride=1)

        # Reshape patches for processing
        patch_size = small_grid_size * small_grid_size
        patches_eroded = patches_eroded.reshape(1, patch_size, new_h, new_w)

        # Apply the structuring element as a mask
        masked_patches_eroded = patches_eroded * small_selem.reshape(-1, 1, 1)

        # Perform dilation (max pooling)
        # Squeeze again so the opened image is 2D and matches image_small
        opened = torch.max(masked_patches_eroded, dim=1)[0].squeeze(0)

        # White top-hat is original minus opening
        tophat_small = image_small - opened

        # 4) Calculate background
        background_small = image_small - tophat_small

        # 5) Upscale background to original size
        background_4d = background_small.unsqueeze(0).unsqueeze(0)
        background_large = F.interpolate(
            background_4d,
            size=image[z].shape,
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)

        # 6) Subtract background and clip negative values
        slice_result = torch.clamp(image[z].float() - background_large, min=0.0)

        # 7) Convert back to original data type
        result[z] = slice_result.to(input_dtype)

    return result
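
# Illustrative usage sketch (added for clarity, not part of the original module):
# remove a smooth background ramp from a single-slice stack with the white
# top-hat filter. A small structuring element keeps the example fast on CPU.
def _example_tophat():
    yy, xx = torch.meshgrid(torch.arange(128.0), torch.arange(128.0), indexing="ij")
    background = yy + xx  # smooth intensity gradient
    spots = (torch.rand(128, 128) > 0.99).float() * 500.0  # sparse bright features
    image = (background + spots).unsqueeze(0)  # shape (1, Y, X)
    filtered = tophat(image, selem_radius=16, downsample_factor=4)
    assert filtered.shape == image.shape
    return filtered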