Coverage for openhcs/processing/backends/enhance/torch_nlm_processor.py: 23.4%

46 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Non-Local Means Denoising Implementation using torch_nlm 

3 

4This module provides OpenHCS-decorated wrapper functions for the torch_nlm library, 

5which implements memory-efficient non-local means denoising with GPU acceleration. 

6 

7Non-local means is an advanced denoising algorithm that preserves fine details 

8and textures by comparing patches across the entire image rather than just 

9local neighborhoods. The torch_nlm implementation provides significant speedup 

10over traditional CPU implementations, especially for large 3D volumes. 

11 

12Doctrinal Clauses: 

13- Clause 3 — Declarative Primacy: All functions are pure and stateless 

14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities 

15- Clause 88 — No Inferred Capabilities: Explicit PyTorch and torch_nlm dependency 

16- Clause 273 — Memory Backend Restrictions: GPU-only implementation 

17""" 

18from __future__ import annotations 

19 

20import logging 

21 

22from openhcs.utils.import_utils import optional_import 

23from openhcs.core.memory.decorators import torch as torch_func 

24 

25# Import torch modules as optional dependencies 

26from openhcs.core.lazy_gpu_imports import torch 

27 

28# Import torch_nlm as optional dependency 

29# Note: The PyPI package is named 'nlm-torch' but imports as 'torch_nlm' 

torch_nlm = optional_import("torch_nlm")

# Bind the 2D and 3D entry points once at import time. When torch_nlm is
# None both names are bound to None, so callers must check availability
# before invoking them.
nlm2d, nlm3d = (
    (torch_nlm.nlm2d, torch_nlm.nlm3d) if torch_nlm is not None else (None, None)
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

39 

40 

41def _validate_3d_array(image: "torch.Tensor") -> None: 

42 """Validate that input is a 3D torch tensor.""" 

43 if torch is None: 

44 raise ImportError("PyTorch is required for torch_nlm functions") 

45 

46 if not isinstance(image, torch.Tensor): 

47 raise TypeError(f"Input must be a torch.Tensor, got {type(image)}") 

48 

49 if image.ndim != 3: 

50 raise ValueError(f"Input must be a 3D tensor (Z, Y, X), got {image.ndim}D tensor") 

51 

52 

@torch_func
def non_local_means_denoise_torch(
    image: "torch.Tensor",
    *,
    kernel_size: int = 11,
    std: float = 1.0,
    kernel_size_mean: int = 3,
    sub_filter_size: int = 32,
    slice_by_slice: bool = True,
    **kwargs
) -> "torch.Tensor":
    """
    Apply Non-Local Means denoising to a 3D image stack using torch_nlm.

    Non-Local Means is an advanced denoising algorithm that preserves fine details
    and textures by comparing patches across the entire image rather than just
    local neighborhoods. This implementation uses torch_nlm for GPU acceleration.

    The input is processed as float32 and the result is cast back to the
    input's original dtype before being returned.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X); must live on a CUDA device
        kernel_size: Size of the neighborhood for patch comparison (default: 11)
        std: Standard deviation for weight calculation (default: 1.0)
        kernel_size_mean: Kernel size for initial mean filtering (default: 3)
        sub_filter_size: Number of neighborhoods computed per iteration for memory efficiency (default: 32)
        slice_by_slice: Process each Z-slice independently using 2D NLM (default: True).
                       If False, uses 3D NLM processing across all dimensions.
        **kwargs: Additional arguments (ignored for compatibility)

    Returns:
        Denoised 3D PyTorch tensor of shape (Z, Y, X)

    Raises:
        ImportError: If PyTorch or torch_nlm is not available
        TypeError: If input is not a torch.Tensor
        ValueError: If input is not 3D
        RuntimeError: If tensor is not on CUDA device

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: True)
        If True, process 3D arrays slice-by-slice using 2D non-local means to avoid
        cross-slice contamination. If False, use 3D non-local means processing.
        Recommended for stitched microscopy data to prevent artifacts at field boundaries.
    """
    _validate_3d_array(image)

    if torch_nlm is None:
        raise ImportError(
            "torch_nlm is required for this function. "
            "Install with: pip install nlm-torch"
        )

    # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
    if image.device.type != "cuda":
        raise RuntimeError(
            f"torch_nlm requires CUDA tensor, got device: {image.device}. "
            "Move tensor to CUDA with: tensor.cuda()"
        )

    # Remember dtype and device so the result can be cast back and the
    # slice utilities can be pointed at the GPU that actually holds the data.
    original_dtype = image.dtype
    device = image.device

    # Processing is done in float32; skip the conversion when already float32.
    image_float = image if image.dtype == torch.float32 else image.float()

    # The 2D and 3D entry points take the same tuning parameters.
    nlm_params = {
        "kernel_size": kernel_size,
        "std": std,
        "kernel_size_mean": kernel_size_mean,
        "sub_filter_size": sub_filter_size,
    }

    if slice_by_slice:
        # Process each Z-slice independently using 2D non-local means,
        # going through the OpenHCS stack utilities.
        from openhcs.core.memory.stack_utils import unstack_slices, stack_slices
        from openhcs.core.memory.converters import detect_memory_type

        memory_type = detect_memory_type(image_float)

        # Use the GPU the input lives on instead of assuming GPU 0, so
        # tensors on cuda:1, cuda:2, ... are not handled against the wrong
        # device. Fall back to 0 only when the device carries no index.
        gpu_id = device.index if device.index is not None else 0

        slices_2d = unstack_slices(image_float, memory_type, gpu_id)
        processed_slices = [
            nlm2d(slice_2d, **nlm_params) for slice_2d in slices_2d
        ]
        result = stack_slices(processed_slices, memory_type, gpu_id)
    else:
        # Denoise the whole volume at once with the 3D implementation.
        result = nlm3d(image_float, **nlm_params)

    # Convert back to original dtype
    return result.to(original_dtype)


# Alias for convenience
torch_nlm_denoise = non_local_means_denoise_torch