Coverage for openhcs/processing/backends/enhance/torch_nlm_processor.py: 23.4%
46 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2Non-Local Means Denoising Implementation using torch_nlm
4This module provides OpenHCS-decorated wrapper functions for the torch_nlm library,
5which implements memory-efficient non-local means denoising with GPU acceleration.
7Non-local means is an advanced denoising algorithm that preserves fine details
8and textures by comparing patches across the entire image rather than just
9local neighborhoods. The torch_nlm implementation provides significant speedup
10over traditional CPU implementations, especially for large 3D volumes.
12Doctrinal Clauses:
13- Clause 3 — Declarative Primacy: All functions are pure and stateless
14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities
15- Clause 88 — No Inferred Capabilities: Explicit PyTorch and torch_nlm dependency
16- Clause 273 — Memory Backend Restrictions: GPU-only implementation
17"""
18from __future__ import annotations
20import logging
22from openhcs.utils.import_utils import optional_import
23from openhcs.core.memory.decorators import torch as torch_func
25# Import torch modules as optional dependencies
26from openhcs.core.lazy_gpu_imports import torch
# torch_nlm is optional. The PyPI distribution is called 'nlm-torch', but the
# importable module name is 'torch_nlm'. When it is missing, both entry points
# are bound to None; callers must check `torch_nlm is None` before use.
torch_nlm = optional_import("torch_nlm")
if torch_nlm is None:
    nlm2d = nlm3d = None
else:
    nlm2d, nlm3d = torch_nlm.nlm2d, torch_nlm.nlm3d

logger = logging.getLogger(__name__)
41def _validate_3d_array(image: "torch.Tensor") -> None:
42 """Validate that input is a 3D torch tensor."""
43 if torch is None:
44 raise ImportError("PyTorch is required for torch_nlm functions")
46 if not isinstance(image, torch.Tensor):
47 raise TypeError(f"Input must be a torch.Tensor, got {type(image)}")
49 if image.ndim != 3:
50 raise ValueError(f"Input must be a 3D tensor (Z, Y, X), got {image.ndim}D tensor")
53@torch_func
54def non_local_means_denoise_torch(
55 image: "torch.Tensor",
56 *,
57 kernel_size: int = 11,
58 std: float = 1.0,
59 kernel_size_mean: int = 3,
60 sub_filter_size: int = 32,
61 slice_by_slice: bool = True,
62 **kwargs
63) -> "torch.Tensor":
64 """
65 Apply Non-Local Means denoising to a 3D image stack using torch_nlm.
67 Non-Local Means is an advanced denoising algorithm that preserves fine details
68 and textures by comparing patches across the entire image rather than just
69 local neighborhoods. This implementation uses torch_nlm for GPU acceleration.
71 Args:
72 image: 3D PyTorch tensor of shape (Z, Y, X)
73 kernel_size: Size of the neighborhood for patch comparison (default: 11)
74 std: Standard deviation for weight calculation (default: 1.0)
75 kernel_size_mean: Kernel size for initial mean filtering (default: 3)
76 sub_filter_size: Number of neighborhoods computed per iteration for memory efficiency (default: 32)
77 slice_by_slice: Process each Z-slice independently using 2D NLM (default: True).
78 If False, uses 3D NLM processing across all dimensions.
79 **kwargs: Additional arguments (ignored for compatibility)
81 Returns:
82 Denoised 3D PyTorch tensor of shape (Z, Y, X)
84 Raises:
85 ImportError: If torch_nlm is not available
86 TypeError: If input is not a torch.Tensor
87 ValueError: If input is not 3D
88 RuntimeError: If tensor is not on CUDA device
90 Additional OpenHCS Parameters
91 -----------------------------
92 slice_by_slice : bool, optional (default: True)
93 If True, process 3D arrays slice-by-slice using 2D non-local means to avoid
94 cross-slice contamination. If False, use 3D non-local means processing.
95 Recommended for stitched microscopy data to prevent artifacts at field boundaries.
96 """
97 _validate_3d_array(image)
99 if torch_nlm is None:
100 raise ImportError(
101 "torch_nlm is required for this function. "
102 "Install with: pip install nlm-torch"
103 )
105 # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
106 if image.device.type != "cuda":
107 raise RuntimeError(
108 f"torch_nlm requires CUDA tensor, got device: {image.device}. "
109 "Move tensor to CUDA with: tensor.cuda()"
110 )
112 # Store original dtype for conversion back
113 original_dtype = image.dtype
114 device = image.device
116 # Convert to float32 for processing if needed
117 if image.dtype != torch.float32:
118 image_float = image.float()
119 else:
120 image_float = image
122 # Handle slice_by_slice processing using OpenHCS pattern
123 if slice_by_slice:
124 # Process each Z-slice independently using 2D non-local means
125 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices
126 from openhcs.core.memory.converters import detect_memory_type
128 # Detect memory type and use proper OpenHCS utilities
129 memory_type = detect_memory_type(image_float)
130 gpu_id = 0 # Default GPU ID for slice processing
132 # Unstack 3D array into 2D slices
133 slices_2d = unstack_slices(image_float, memory_type, gpu_id)
135 # Process each slice
136 processed_slices = []
137 for slice_2d in slices_2d:
138 # Apply 2D non-local means to this slice
139 denoised_slice = nlm2d(
140 slice_2d,
141 kernel_size=kernel_size,
142 std=std,
143 kernel_size_mean=kernel_size_mean,
144 sub_filter_size=sub_filter_size
145 )
146 processed_slices.append(denoised_slice)
148 # Stack results back to 3D
149 result = stack_slices(processed_slices, memory_type, gpu_id)
150 else:
151 # Use 3D processing directly (fallback to nlm3d)
152 result = nlm3d(
153 image_float,
154 kernel_size=kernel_size,
155 std=std,
156 kernel_size_mean=kernel_size_mean,
157 sub_filter_size=sub_filter_size
158 )
160 # Convert back to original dtype
161 result = result.to(original_dtype)
163 return result
166# Alias for convenience
167torch_nlm_denoise = non_local_means_denoise_torch