Coverage for openhcs/processing/backends/enhance/torch_nlm_processor.py: 25.0%
46 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Non-Local Means Denoising Implementation using torch_nlm
4This module provides OpenHCS-decorated wrapper functions for the torch_nlm library,
5which implements memory-efficient non-local means denoising with GPU acceleration.
7Non-local means is an advanced denoising algorithm that preserves fine details
8and textures by comparing patches across the entire image rather than just
9local neighborhoods. The torch_nlm implementation provides significant speedup
10over traditional CPU implementations, especially for large 3D volumes.
12Doctrinal Clauses:
13- Clause 3 — Declarative Primacy: All functions are pure and stateless
14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities
15- Clause 88 — No Inferred Capabilities: Explicit PyTorch and torch_nlm dependency
16- Clause 273 — Memory Backend Restrictions: GPU-only implementation
17"""
18from __future__ import annotations
20import logging
21from typing import Optional
23from openhcs.utils.import_utils import optional_import
24from openhcs.core.memory.decorators import torch as torch_func
# PyTorch is loaded through optional_import so this module can still be
# imported when it is missing; the functions below check for None and raise.
torch = optional_import("torch")

# torch_nlm is optional as well.
# Note: the PyPI distribution is named 'nlm-torch' but the module imports as 'torch_nlm'.
torch_nlm = optional_import("torch_nlm")

# Bind the 2D and 3D denoisers, or None placeholders when torch_nlm is absent.
nlm2d = None if torch_nlm is None else torch_nlm.nlm2d
nlm3d = None if torch_nlm is None else torch_nlm.nlm3d

logger = logging.getLogger(__name__)
def _validate_3d_array(image: "torch.Tensor") -> None:
    """
    Check that ``image`` is a torch tensor with exactly three dimensions.

    Returns None when the input is acceptable.

    Raises:
        ImportError: If the module-level ``torch`` handle is None
        TypeError: If ``image`` is not a ``torch.Tensor``
        ValueError: If ``image.ndim`` is not 3
    """
    # Without PyTorch the isinstance check below cannot even be expressed.
    if torch is None:
        raise ImportError("PyTorch is required for torch_nlm functions")

    if isinstance(image, torch.Tensor):
        if image.ndim == 3:
            return
        raise ValueError(f"Input must be a 3D tensor (Z, Y, X), got {image.ndim}D tensor")

    raise TypeError(f"Input must be a torch.Tensor, got {type(image)}")
@torch_func
def non_local_means_denoise_torch(
    image: "torch.Tensor",
    *,
    kernel_size: int = 11,
    std: float = 1.0,
    kernel_size_mean: int = 3,
    sub_filter_size: int = 32,
    slice_by_slice: bool = True,
    **kwargs
) -> "torch.Tensor":
    """
    Apply Non-Local Means denoising to a 3D image stack using torch_nlm.

    Non-Local Means is an advanced denoising algorithm that preserves fine details
    and textures by comparing patches across the entire image rather than just
    local neighborhoods. This implementation uses torch_nlm for GPU acceleration.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X); must live on a CUDA device
        kernel_size: Size of the neighborhood for patch comparison (default: 11)
        std: Standard deviation for weight calculation (default: 1.0)
        kernel_size_mean: Kernel size for initial mean filtering (default: 3)
        sub_filter_size: Number of neighborhoods computed per iteration for memory efficiency (default: 32)
        slice_by_slice: Process each Z-slice independently using 2D NLM (default: True).
                       If False, uses 3D NLM processing across all dimensions.
        **kwargs: Additional arguments (ignored for compatibility)

    Returns:
        Denoised 3D PyTorch tensor of shape (Z, Y, X), cast back to the input dtype

    Raises:
        ImportError: If PyTorch or torch_nlm is not available
        TypeError: If input is not a torch.Tensor
        ValueError: If input is not 3D
        RuntimeError: If tensor is not on CUDA device

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: True)
        If True, process 3D arrays slice-by-slice using 2D non-local means to avoid
        cross-slice contamination. If False, use 3D non-local means processing.
        Recommended for stitched microscopy data to prevent artifacts at field boundaries.
    """
    _validate_3d_array(image)

    if torch_nlm is None:
        raise ImportError(
            "torch_nlm is required for this function. "
            "Install with: pip install nlm-torch"
        )

    # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
    if image.device.type != "cuda":
        raise RuntimeError(
            f"torch_nlm requires CUDA tensor, got device: {image.device}. "
            "Move tensor to CUDA with: tensor.cuda()"
        )

    # Remember the caller's dtype so the result can be cast back at the end.
    original_dtype = image.dtype

    # Processing is done in float32; skip the conversion when already float32.
    image_float = image if image.dtype == torch.float32 else image.float()

    # The 2D and 3D entry points are called with the same tuning parameters.
    nlm_params = {
        "kernel_size": kernel_size,
        "std": std,
        "kernel_size_mean": kernel_size_mean,
        "sub_filter_size": sub_filter_size,
    }

    if slice_by_slice:
        # Process each Z-slice independently using 2D non-local means
        from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

        memory_type = _detect_memory_type(image_float)

        # Use the GPU the input actually lives on instead of assuming GPU 0,
        # so tensors on e.g. cuda:1 are not handed to the stack utilities with
        # the wrong device id. Fall back to 0 only if the index is unset.
        device_index = image.device.index
        gpu_id = device_index if device_index is not None else 0

        # Unstack 3D array into 2D slices, denoise each, and restack to 3D
        slices_2d = unstack_slices(image_float, memory_type, gpu_id)
        processed_slices = [nlm2d(slice_2d, **nlm_params) for slice_2d in slices_2d]
        result = stack_slices(processed_slices, memory_type, gpu_id)
    else:
        # Explicitly requested full 3D processing across all dimensions
        result = nlm3d(image_float, **nlm_params)

    # Convert back to original dtype
    return result.to(original_dtype)


# Alias for convenience
torch_nlm_denoise = non_local_means_denoise_torch