Coverage for openhcs/processing/backends/enhance/jax_nlm_processor.py: 15.2%

103 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2JAX-based Non-Local Means Denoising Implementation 

3 

4This module provides OpenHCS-decorated wrapper functions for non-local means denoising 

5using JAX backend with automatic output rescaling to prevent clipping issues when 

6converting to uint16. 

7 

8Non-local means is an advanced denoising algorithm that preserves fine details 

9and textures by comparing patches across the entire image rather than just 

10local neighborhoods. This JAX implementation provides GPU acceleration with 

11automatic output normalization. 

12 

13Doctrinal Clauses: 

14- Clause 3 — Declarative Primacy: All functions are pure and stateless 

15- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities 

16- Clause 88 — No Inferred Capabilities: Explicit JAX dependency 

17- Clause 273 — Memory Backend Restrictions: JAX-only implementation 

18""" 

19from __future__ import annotations 

20 

21import logging 

22from typing import Optional 

23 

24from openhcs.utils.import_utils import optional_import 

25from openhcs.core.memory.decorators import jax as jax_func 

26 

# JAX is an optional dependency.  The module-level checks below (and in the
# functions of this module) rely on optional_import returning None when the
# package is missing, in which case jnp, lax and tree_util are None as well.
jax = optional_import("jax")
if jax is not None:
    jnp = optional_import("jax.numpy")
    lax = jax.lax
    tree_util = jax.tree_util
else:
    jnp = lax = tree_util = None

logger = logging.getLogger(__name__)

34 

35 

def _validate_jax_array(image: "jnp.ndarray") -> None:
    """Check that *image* is a 2D or 3D ``jax.numpy`` array.

    Raises:
        ImportError: if the optional JAX dependency could not be imported.
        TypeError: if *image* is not a ``jnp.ndarray``.
        ValueError: if *image* is neither 2D nor 3D.
    """
    jax_missing = jax is None or jnp is None
    if jax_missing:
        raise ImportError("JAX is required for JAX NLM functions")

    if not isinstance(image, jnp.ndarray):
        raise TypeError(f"Input must be a jax.numpy.ndarray, got {type(image)}")

    ndim = image.ndim
    if ndim != 2 and ndim != 3:
        raise ValueError(f"Input must be a 2D or 3D array, got {ndim}D array")

46 

47 

def _rescale_to_unit_range(image: "jnp.ndarray") -> "jnp.ndarray":
    """
    Linearly map *image* so its global minimum becomes 0 and its maximum 1.

    The minimum and maximum are taken over the whole array (all slices of a
    stack together), which prevents clipping issues when converting to uint16.
    If the range is not strictly positive (a constant image, or a NaN range),
    the result is all zeros.

    Args:
        image: JAX array (this module passes 2D images; a 3D stack works too).
            Intended for floating-point input.

    Returns:
        Array of the same shape with values in the [0, 1] range.
    """
    lo = jnp.min(image)
    span = jnp.max(image) - lo

    def _stretch(operands):
        arr, offset, scale = operands
        return (arr - offset) / scale

    def _flat(operands):
        # Constant input: nothing to stretch, emit zeros of matching shape/dtype.
        return jnp.zeros_like(operands[0])

    # lax.cond keeps the zero-range test JAX-traceable.
    return lax.cond(span > 0, _stretch, _flat, (image, lo, span))

86 

87 

def _ixs(y_ixs, x_ixs):
    """Return the ``[X, Y]`` coordinate grids of ``jnp.meshgrid`` ('xy' indexing).

    For 1-D inputs both grids have shape ``(len(y_ixs), len(x_ixs))``.
    """
    grids = jnp.meshgrid(x_ixs, y_ixs, indexing="xy")
    return grids

91 

92 

def _vmap_2d(f, y_ixs, x_ixs):
    """Evaluate ``f(y, x)`` at every point of the grid ``y_ixs`` x ``x_ixs``.

    Two nested ``jax.vmap`` calls map the outer axis over rows and the inner
    axis over columns, so for 1-D index vectors every output leaf has leading
    shape ``(len(y_ixs), len(x_ixs))``.
    """
    grid_x, grid_y = _ixs(y_ixs, x_ixs)
    batched_f = jax.vmap(jax.vmap(f))
    return batched_f(grid_y, grid_x)

97 

98 

# ``tree_util.Partial(jax.jit, static_argnums=(1, 2))`` is simply a partial
# application of ``jax.jit``: used as a decorator it yields
# ``jax.jit(_nlm_core, static_argnums=(1, 2))`` (functools.partial would
# behave the same here).  The two radii must be static because they fix array
# shapes (padding width, patch and window sizes); ``h`` and ``sigma`` are
# traced.  Without JAX the identity lambda leaves the function undecorated.
@tree_util.Partial(jax.jit, static_argnums=(1, 2)) if jax is not None and tree_util is not None else lambda f: f
def _nlm_core(img: "jnp.ndarray", search_window_radius: int, filter_radius: int, h: float, sigma: float) -> "jnp.ndarray":
    """
    Core non-local means implementation based on Buades et al.

    This is a vectorized (nested ``jax.vmap``) and JIT-compiled implementation
    adapted from: https://github.com/bhchiang/nlm

    Each output pixel is a weighted mean of the pixels whose surrounding patch
    lies inside its search window.  A patch at mean squared distance ``d2``
    from the pixel's own patch gets weight
    ``exp(-max(d2 - 2 * sigma**2, 0) / h**2)``.  The pixel itself gets the
    largest neighbour weight, or weight 1 when every neighbour weight is 0
    (so the result is never 0/0 = NaN).

    Args:
        img: 2D image array (shape unpacked as ``(H, W)``)
        search_window_radius: Radius of search window (static argument)
        filter_radius: Radius of comparison patches (static argument); must
            satisfy ``0 <= filter_radius <= search_window_radius``, otherwise
            there are no candidate patches
        h: Filter strength parameter; must be > 0 (weights divide by h**2)
        sigma: Noise standard deviation

    Returns:
        Denoised 2D image with the same shape as ``img``
    """
    _h, _w = img.shape
    pad = search_window_radius
    # Reflect-pad by the search radius so every search window fits in img_pad.
    img_pad = jnp.pad(img, pad, mode='reflect')

    filter_length = 2 * filter_radius + 1
    search_window_length = 2 * search_window_radius + 1

    # Top-left offsets (relative to the window's top-left corner) of every
    # comparison patch that fits entirely inside a search window.
    win_y_ixs = win_x_ixs = jnp.arange(search_window_length - filter_length + 1)
    filter_size = (filter_length, filter_length)

    def compute(y, x):
        # (y + pad, x + pad) is pixel (y, x) in padded coordinates and the
        # centre of its search window.
        win_center_y = y + pad
        win_center_x = x + pad

        center_patch = lax.dynamic_slice(
            img_pad,
            (win_center_y - filter_radius, win_center_x - filter_radius),
            filter_size
        )

        def _compare(center):
            # Weight and intensity contributed by the patch centred at `center`.
            center_y, center_x = center
            patch = lax.dynamic_slice(
                img_pad,
                (center_y - filter_radius, center_x - filter_radius),
                filter_size
            )
            # Mean squared difference between the two patches.
            d2 = jnp.sum((patch - center_patch) ** 2) / (filter_length ** 2)
            # Distances up to 2*sigma**2 map to weight 1 (presumably the
            # expected squared difference of two noise-only pixels).
            weight = jnp.exp(-(jnp.maximum(d2 - 2 * (sigma**2), 0) / (h**2)))
            intensity = img_pad[center_y, center_x]
            return (weight, intensity)

        def compare(patch_y, patch_x):
            patch_center_y = patch_y + filter_radius
            patch_center_x = patch_x + filter_radius

            # Skip patches centred in the reflected padding (outside the
            # original image) and the pixel's own patch, which is weighted
            # separately below.
            skip = (lax.lt(patch_center_y, pad) |
                    lax.ge(patch_center_y, _h + pad) |
                    lax.lt(patch_center_x, pad) |
                    lax.ge(patch_center_x, _w + pad) |
                    (lax.eq(patch_center_y, win_center_y) & lax.eq(patch_center_x, win_center_x)))

            return lax.cond(
                skip,
                lambda _: (0., 0.),
                _compare,
                (patch_center_y, patch_center_x)
            )

        weights, intensities = _vmap_2d(compare, y + win_y_ixs, x + win_x_ixs)

        # The pixel's own patch gets the largest neighbour weight.  If every
        # neighbour weight is 0 (e.g. an isolated outlier pixel, or all weights
        # underflowing exp() for a small h), the original code computed 0/0 =
        # NaN, which then made the whole image rescale to zeros.  Fall back to
        # weight 1 so the pixel simply keeps its own value.
        max_weight = jnp.max(weights)
        center_weight = jnp.where(max_weight > 0, max_weight, 1.0)
        total_weight = jnp.sum(weights) + center_weight
        pixel = ((jnp.sum(weights * intensities) +
                  center_weight * img_pad[win_center_y, win_center_x]) / total_weight)

        return pixel

    h_ixs = jnp.arange(_h)
    w_ixs = jnp.arange(_w)
    out = _vmap_2d(compute, h_ixs, w_ixs)

    return out

188 

189 

def _estimate_noise_sigma(image_2d: "jnp.ndarray") -> "jnp.ndarray":
    """Rough noise standard deviation of a 2D image, floored at 0.01.

    Filters the image with the 4-neighbour Laplacian kernel
    ``[[0, -1, 0], [-1, 4, -1], [0, -1, 0]]`` (reflect-padded borders) and
    scales the standard deviation of the response by ``sqrt(2) / 6``.
    """
    padded = jnp.pad(image_2d, 1, mode='reflect')
    laplacian = (
        4.0 * padded[1:-1, 1:-1]
        - padded[:-2, 1:-1]   # neighbour above
        - padded[2:, 1:-1]    # neighbour below
        - padded[1:-1, :-2]   # neighbour left
        - padded[1:-1, 2:]    # neighbour right
    )
    sigma = jnp.sqrt(2) * jnp.std(laplacian) / 6.0
    return jnp.maximum(sigma, 0.01)


@jax_func
def non_local_means_denoise_jax(
    image: "jnp.ndarray",
    *,
    search_window_radius: int = 7,
    filter_radius: int = 1,
    h: Optional[float] = None,
    sigma: Optional[float] = None,
    slice_by_slice: bool = False,
    **kwargs
) -> "jnp.ndarray":
    """
    Apply Non-Local Means denoising to a 2D image using JAX.

    Applies vectorized, JIT-compiled non-local means denoising based on
    Buades et al.  The input is converted to float32 and min-max normalized
    to [0, 1] first, so ``h`` and ``sigma`` are on that [0, 1] intensity
    scale.  The output is rescaled to the [0, 1] range (global min -> 0,
    global max -> 1) to prevent clipping issues when converting to uint16.

    Only 2D input is processed here; 3D input raises NotImplementedError
    regardless of ``slice_by_slice``.  Per the original design notes, the
    OpenHCS decorator/pipeline performs slicing when ``slice_by_slice=True``
    and calls this function once per 2D slice.
    NOTE(review): confirm against openhcs.core.memory.decorators.

    Args:
        image: 2D JAX array of shape (Y, X). 3D arrays of shape (Z, Y, X)
            pass validation but are rejected (see Raises).
        search_window_radius: Radius of search window (default: 7); must be
            >= filter_radius
        filter_radius: Radius of comparison patches (default: 1); must be >= 0
        h: Filter strength parameter (default: 0.75 * sigma); must be > 0
        sigma: Noise standard deviation on the normalized [0, 1] scale
            (default: estimated from the image, floored at 0.01); must be >= 0
        slice_by_slice: Not read by this function (see above).
        **kwargs: Additional arguments (ignored for compatibility)

    Returns:
        Denoised float32 JAX array of shape (Y, X), rescaled to the [0, 1] range

    Raises:
        ImportError: If JAX is not available
        TypeError: If input is not a jax.numpy.ndarray
        ValueError: If input is not 2D or 3D, or a parameter is out of range
        NotImplementedError: If input is 3D (3D processing not yet implemented)

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        Not read by this function: 3D input always raises NotImplementedError
        here.  Per the original design notes, when True the OpenHCS decorator
        processes 3D stacks slice-by-slice and passes 2D slices to this function.
        Note: 3D processing is not yet implemented for JAX backend.
    """
    # ImportError if JAX is missing; TypeError/ValueError for non-JAX or
    # non-2D/3D input.
    _validate_jax_array(image)

    if image.ndim == 3:
        # Fail before doing any work: the 2D kernel cannot handle a stack.
        raise NotImplementedError(
            "3D non-local means processing is not yet implemented for JAX backend. "
            "Use slice_by_slice=True for 2D slice-by-slice processing."
        )

    # filter_radius > search_window_radius leaves no candidate patches, which
    # previously failed deep inside JAX with an empty-reduction error.
    if filter_radius < 0:
        raise ValueError(f"filter_radius must be >= 0, got {filter_radius}")
    if search_window_radius < filter_radius:
        raise ValueError(
            f"search_window_radius ({search_window_radius}) must be >= "
            f"filter_radius ({filter_radius})"
        )
    if sigma is not None and sigma < 0:
        raise ValueError(f"sigma must be >= 0, got {sigma}")

    image_float = image.astype(jnp.float32)

    # Normalize input to [0, 1] for consistent parameter behavior.
    img_min = jnp.min(image_float)
    img_max = jnp.max(image_float)
    if img_max > img_min:
        image_normalized = (image_float - img_min) / (img_max - img_min)
    else:
        image_normalized = jnp.zeros_like(image_float)

    # Auto-estimate parameters if not provided.
    if sigma is None:
        sigma = _estimate_noise_sigma(image_normalized)
    if h is None:
        h = 0.75 * sigma  # Standard relationship
    # Weights divide by h**2, so h == 0 (e.g. sigma=0 with no explicit h)
    # would turn the whole result into NaN.
    if h <= 0:
        raise ValueError(f"h must be > 0, got {h}")

    result = _nlm_core(image_normalized, search_window_radius, filter_radius, h, sigma)

    # Always rescale output to [0, 1] range to prevent uint16 clipping.
    result = _rescale_to_unit_range(result)
    logger.info("Rescaled NLM output to [0, 1] range to prevent uint16 clipping")

    return result


# Alias for convenience
jax_nlm_denoise = non_local_means_denoise_jax