Coverage for openhcs/processing/backends/enhance/torch_nlm_processor.py: 25.0%

46 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Non-Local Means Denoising Implementation using torch_nlm 

3 

4This module provides OpenHCS-decorated wrapper functions for the torch_nlm library, 

5which implements memory-efficient non-local means denoising with GPU acceleration. 

6 

7Non-local means is an advanced denoising algorithm that preserves fine details 

8and textures by comparing patches across the entire image rather than just 

9local neighborhoods. The torch_nlm implementation provides significant speedup 

10over traditional CPU implementations, especially for large 3D volumes. 

11 

12Doctrinal Clauses: 

13- Clause 3 — Declarative Primacy: All functions are pure and stateless 

14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities 

15- Clause 88 — No Inferred Capabilities: Explicit PyTorch and torch_nlm dependency 

16- Clause 273 — Memory Backend Restrictions: GPU-only implementation 

17""" 

18from __future__ import annotations 

19 

20import logging 

21from typing import Optional 

22 

23from openhcs.utils.import_utils import optional_import 

24from openhcs.core.memory.decorators import torch as torch_func 

25 

26# Import torch modules as optional dependencies 

27torch = optional_import("torch") 

28 

# torch_nlm is an optional dependency.
# Note: The PyPI package is named 'nlm-torch' but imports as 'torch_nlm'
torch_nlm = optional_import("torch_nlm")

# Bind the 2D/3D entry points when the package is present; otherwise leave
# them as None so the public wrapper can raise a clear ImportError when called.
nlm2d = torch_nlm.nlm2d if torch_nlm is not None else None
nlm3d = torch_nlm.nlm3d if torch_nlm is not None else None

38 

39logger = logging.getLogger(__name__) 

40 

41 

42def _validate_3d_array(image: "torch.Tensor") -> None: 

43 """Validate that input is a 3D torch tensor.""" 

44 if torch is None: 

45 raise ImportError("PyTorch is required for torch_nlm functions") 

46 

47 if not isinstance(image, torch.Tensor): 

48 raise TypeError(f"Input must be a torch.Tensor, got {type(image)}") 

49 

50 if image.ndim != 3: 

51 raise ValueError(f"Input must be a 3D tensor (Z, Y, X), got {image.ndim}D tensor") 

52 

53 

@torch_func
def non_local_means_denoise_torch(
    image: "torch.Tensor",
    *,
    kernel_size: int = 11,
    std: float = 1.0,
    kernel_size_mean: int = 3,
    sub_filter_size: int = 32,
    slice_by_slice: bool = True,
    **kwargs
) -> "torch.Tensor":
    """
    Apply Non-Local Means denoising to a 3D image stack using torch_nlm.

    Non-Local Means is an advanced denoising algorithm that preserves fine details
    and textures by comparing patches across the entire image rather than just
    local neighborhoods. This implementation uses torch_nlm for GPU acceleration.

    The input is converted to float32 for processing and the result is cast back
    to the input dtype. For integer input dtypes the float result is rounded to
    the nearest integer before the cast, so values are not truncated toward zero.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X); must live on a CUDA device
        kernel_size: Size of the neighborhood for patch comparison (default: 11)
        std: Standard deviation for weight calculation (default: 1.0)
        kernel_size_mean: Kernel size for initial mean filtering (default: 3)
        sub_filter_size: Number of neighborhoods computed per iteration for memory efficiency (default: 32)
        slice_by_slice: Process each Z-slice independently using 2D NLM (default: True)
            to avoid cross-slice contamination. If False, uses 3D NLM processing
            across all dimensions. Slice-by-slice is recommended for stitched
            microscopy data to prevent artifacts at field boundaries.
        **kwargs: Additional arguments (ignored for compatibility)

    Returns:
        Denoised 3D PyTorch tensor of shape (Z, Y, X) with the same dtype as ``image``

    Raises:
        ImportError: If PyTorch or torch_nlm is not available
        TypeError: If input is not a torch.Tensor
        ValueError: If input is not 3D
        RuntimeError: If tensor is not on CUDA device
    """
    _validate_3d_array(image)

    if torch_nlm is None:
        raise ImportError(
            "torch_nlm is required for this function. "
            "Install with: pip install nlm-torch"
        )

    # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
    if image.device.type != "cuda":
        raise RuntimeError(
            f"torch_nlm requires CUDA tensor, got device: {image.device}. "
            "Move tensor to CUDA with: tensor.cuda()"
        )

    # Remember the input dtype so the result can be converted back at the end
    original_dtype = image.dtype

    # Tensor.float() returns the tensor itself when it is already float32,
    # so no explicit dtype branch is needed here.
    image_float = image.float()

    # Keyword arguments shared by the 2D and 3D torch_nlm entry points
    nlm_params = dict(
        kernel_size=kernel_size,
        std=std,
        kernel_size_mean=kernel_size_mean,
        sub_filter_size=sub_filter_size,
    )

    if slice_by_slice:
        # Process each Z-slice independently using 2D non-local means
        from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

        memory_type = _detect_memory_type(image_float)

        # Use the GPU the input actually lives on instead of assuming GPU 0,
        # so stacks on e.g. cuda:1 are not handled against the wrong device.
        # Fall back to 0 only if the device carries no explicit index.
        device_index = image.device.index
        gpu_id = device_index if device_index is not None else 0

        slices_2d = unstack_slices(image_float, memory_type, gpu_id)
        processed_slices = [nlm2d(slice_2d, **nlm_params) for slice_2d in slices_2d]
        result = stack_slices(processed_slices, memory_type, gpu_id)
    else:
        # Use 3D processing across the whole volume
        result = nlm3d(image_float, **nlm_params)

    # A plain float -> integer cast truncates toward zero, which biases the
    # denoised intensities downward; round first for integer target dtypes.
    targets_integer_dtype = not (
        original_dtype.is_floating_point
        or original_dtype.is_complex
        or original_dtype == torch.bool
    )
    if targets_integer_dtype and result.is_floating_point():
        result = result.round()

    # Convert back to original dtype
    return result.to(original_dtype)


# Alias for convenience
torch_nlm_denoise = non_local_means_denoise_torch