Coverage for openhcs/processing/backends/enhance/jax_nlm_processor.py: 15.2%
103 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2JAX-based Non-Local Means Denoising Implementation
4This module provides OpenHCS-decorated wrapper functions for non-local means denoising
5using JAX backend with automatic output rescaling to prevent clipping issues when
6converting to uint16.
8Non-local means is an advanced denoising algorithm that preserves fine details
9and textures by comparing patches across the entire image rather than just
10local neighborhoods. This JAX implementation provides GPU acceleration with
11automatic output normalization.
13Doctrinal Clauses:
14- Clause 3 — Declarative Primacy: All functions are pure and stateless
15- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities
16- Clause 88 — No Inferred Capabilities: Explicit JAX dependency
17- Clause 273 — Memory Backend Restrictions: JAX-only implementation
18"""
19from __future__ import annotations
21import logging
22from typing import Optional
24from openhcs.utils.import_utils import optional_import
25from openhcs.core.memory.decorators import jax as jax_func
# Import JAX modules as optional dependencies. When optional_import("jax")
# returns None, jnp, lax and tree_util are all set to None as well, so the
# rest of the module can test availability with `jax is None`.
jax = optional_import("jax")
if jax is None:
    jnp = None
    lax = None
    tree_util = None
else:
    jnp = optional_import("jax.numpy")
    lax = jax.lax
    tree_util = jax.tree_util

logger = logging.getLogger(__name__)
def _validate_jax_array(image: "jnp.ndarray") -> None:
    """
    Check that ``image`` is a 2D or 3D ``jax.numpy`` array.

    Args:
        image: Candidate input array.

    Raises:
        ImportError: If JAX (or ``jax.numpy``) is unavailable.
        TypeError: If ``image`` is not a ``jnp.ndarray``.
        ValueError: If ``image`` has a rank other than 2 or 3.
    """
    jax_available = jax is not None and jnp is not None
    if not jax_available:
        raise ImportError("JAX is required for JAX NLM functions")

    if not isinstance(image, jnp.ndarray):
        raise TypeError(f"Input must be a jax.numpy.ndarray, got {type(image)}")

    rank = image.ndim
    if rank not in (2, 3):
        raise ValueError(f"Input must be a 2D or 3D array, got {rank}D array")
def _rescale_to_unit_range(image: "jnp.ndarray") -> "jnp.ndarray":
    """
    Linearly map ``image`` so its global minimum becomes 0 and its global maximum 1.

    The minimum and maximum are taken over all elements, so this works for any
    array shape (this module applies it to 2D images). Rescaling into [0, 1]
    prevents clipping issues when the result is later converted to uint16.

    Args:
        image: JAX array of any shape.

    Returns:
        Array of the same shape with values in [0, 1]. If ``max - min`` is not
        positive (constant image, or NaN present), an all-zero array is returned.
    """
    lowest = jnp.min(image)
    span = jnp.max(image) - lowest

    def _scaled(operands):
        arr, offset, width = operands
        return (arr - offset) / width

    def _flat(operands):
        # Constant input: avoid dividing by a zero span.
        return jnp.zeros_like(operands[0])

    # lax.cond executes only the selected branch, so _scaled never sees span == 0.
    return lax.cond(span > 0, _scaled, _flat, (image, lowest, span))
def _ixs(y_ixs, x_ixs):
    """
    Build 2D coordinate grids from row indices ``y_ixs`` and column indices ``x_ixs``.

    Returns:
        List ``[grid_x, grid_y]`` of arrays shaped ``(len(y_ixs), len(x_ixs))``
        with ``grid_x[i, j] == x_ixs[j]`` and ``grid_y[i, j] == y_ixs[i]``.
    """
    grid_y, grid_x = jnp.meshgrid(y_ixs, x_ixs, indexing="ij")
    return [grid_x, grid_y]
def _vmap_2d(f, y_ixs, x_ixs):
    """
    Evaluate ``f(y, x)`` for every pair drawn from ``y_ixs`` and ``x_ixs``.

    Uses two nested ``jax.vmap`` calls over a coordinate grid. Element ``[i, j]``
    of each output is ``f(y_ixs[i], x_ixs[j])``, so outputs have leading shape
    ``(len(y_ixs), len(x_ixs))``.
    """
    grid_y, grid_x = jnp.meshgrid(y_ixs, x_ixs, indexing="ij")
    over_columns = jax.vmap(f)
    over_rows_and_columns = jax.vmap(over_columns)
    return over_rows_and_columns(grid_y, grid_x)
# Equivalent to jax.jit(_nlm_core, static_argnums=(1, 2)): search_window_radius
# and filter_radius determine array shapes (pad width, jnp.arange lengths,
# lax.dynamic_slice sizes), so they must be compile-time constants; h and sigma
# remain traced values. When JAX is unavailable the function is left
# undecorated (it cannot run without JAX anyway).
@tree_util.Partial(jax.jit, static_argnums=(1, 2)) if jax is not None and tree_util is not None else lambda f: f
def _nlm_core(img: "jnp.ndarray", search_window_radius: int, filter_radius: int, h: float, sigma: float) -> "jnp.ndarray":
    """
    Core non-local means implementation based on Buades et al.

    This is a vectorized (nested ``jax.vmap``) and JIT-compiled implementation
    adapted from:
    https://github.com/bhchiang/nlm

    For each pixel, every patch whose centre lies inside both the search window
    and the image is compared with the pixel's own patch. A patch's weight is
    ``exp(-max(d2 - 2 * sigma**2, 0) / h**2)``, where ``d2`` is the mean squared
    patch difference. The pixel's own patch receives the largest neighbour
    weight. The output pixel is the weighted mean of the patch-centre
    intensities. If every neighbour weight is 0, the input pixel is returned
    unchanged instead of 0/0 = NaN.

    Args:
        img: 2D image array (Y, X)
        search_window_radius: Radius of search window (static under jit)
        filter_radius: Radius of comparison patches (static under jit)
        h: Filter strength parameter; must be non-zero (it divides d2)
        sigma: Noise standard deviation

    Returns:
        Denoised 2D image with the same shape as ``img``

    Raises:
        ValueError: If filter_radius < 0 or filter_radius > search_window_radius
    """
    # Fail loudly instead of letting lax.dynamic_slice silently clamp
    # out-of-range start indices. The centre patch starts at
    # y + search_window_radius - filter_radius in padded coordinates, which is
    # negative when filter_radius exceeds the padding.
    if filter_radius < 0 or filter_radius > search_window_radius:
        raise ValueError(
            "filter_radius must satisfy 0 <= filter_radius <= search_window_radius, "
            f"got filter_radius={filter_radius}, search_window_radius={search_window_radius}"
        )

    _h, _w = img.shape
    pad = search_window_radius
    img_pad = jnp.pad(img, pad, mode='reflect')

    filter_length = 2 * filter_radius + 1
    search_window_length = 2 * search_window_radius + 1

    # Top-left offsets (relative to the search window's top-left corner) of
    # every patch that fits entirely inside the search window.
    win_y_ixs = win_x_ixs = jnp.arange(search_window_length - filter_length + 1)
    filter_size = (filter_length, filter_length)

    def compute(y, x):
        # (y + pad, x + pad) are the center of the current neighborhood
        win_center_y = y + pad
        win_center_x = x + pad
        center_value = img_pad[win_center_y, win_center_x]

        center_patch = lax.dynamic_slice(
            img_pad,
            (win_center_y - filter_radius, win_center_x - filter_radius),
            filter_size
        )

        # Weight and centre intensity for one candidate patch
        def _compare(center):
            center_y, center_x = center
            patch = lax.dynamic_slice(
                img_pad,
                (center_y - filter_radius, center_x - filter_radius),
                filter_size
            )
            d2 = jnp.sum((patch - center_patch) ** 2) / (filter_length ** 2)
            weight = jnp.exp(-(jnp.maximum(d2 - 2 * (sigma**2), 0) / (h**2)))
            intensity = img_pad[center_y, center_x]
            return (weight, intensity)

        def compare(patch_y, patch_x):
            patch_center_y = patch_y + filter_radius
            patch_center_x = patch_x + filter_radius

            # Skip if patch is out of image boundaries or this is the center patch
            skip = (lax.lt(patch_center_y, pad) |
                    lax.ge(patch_center_y, _h + pad) |
                    lax.lt(patch_center_x, pad) |
                    lax.ge(patch_center_x, _w + pad) |
                    (lax.eq(patch_center_y, win_center_y) & lax.eq(patch_center_x, win_center_x)))

            return lax.cond(
                skip,
                lambda _: (0., 0.),
                _compare,
                (patch_center_y, patch_center_x)
            )

        weights, intensities = _vmap_2d(compare, y + win_y_ixs, x + win_x_ixs)

        # Use max weight for the center patch (Buades et al.)
        max_weight = jnp.max(weights)
        total_weight = jnp.sum(weights) + max_weight
        weighted_sum = jnp.sum(weights * intensities) + max_weight * center_value

        # All weights are 0 when every neighbour patch is skipped (always the
        # case for search_window_radius == filter_radius) or when every exp()
        # weight underflows for a pixel unlike all its neighbours. The plain
        # ratio would then be 0/0 = NaN, so fall back to the input pixel.
        has_weight = total_weight > 0
        safe_total = jnp.where(has_weight, total_weight, 1.0)
        return jnp.where(has_weight, weighted_sum / safe_total, center_value)

    h_ixs = jnp.arange(_h)
    w_ixs = jnp.arange(_w)
    return _vmap_2d(compute, h_ixs, w_ixs)
def _estimate_noise_sigma(image_2d: "jnp.ndarray") -> "jnp.ndarray":
    """
    Estimate the noise standard deviation of a 2D float image.

    The image is reflect-padded by one pixel and filtered with the 3x3 Laplacian
    kernel ``[[0, -1, 0], [-1, 4, -1], [0, -1, 0]]``. The standard deviation of
    the response is scaled by ``sqrt(2) / 6`` and clamped to at least 0.01.

    Args:
        image_2d: 2D float image (in this module, the [0, 1]-normalized input).

    Returns:
        Scalar JAX array holding the estimated sigma.
    """
    laplacian_kernel = jnp.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=jnp.float32)
    rows, cols = image_2d.shape
    padded = jnp.pad(image_2d, 1, mode='reflect')
    laplacian = jnp.zeros_like(image_2d)
    for i in range(3):
        for j in range(3):
            laplacian += laplacian_kernel[i, j] * padded[i:i + rows, j:j + cols]
    sigma = jnp.sqrt(2) * jnp.std(laplacian) / 6.0
    return jnp.maximum(sigma, 0.01)  # Minimum sigma


@jax_func
def non_local_means_denoise_jax(
    image: "jnp.ndarray",
    *,
    search_window_radius: int = 7,
    filter_radius: int = 1,
    h: Optional[float] = None,
    sigma: Optional[float] = None,
    slice_by_slice: bool = False,
    **kwargs
) -> "jnp.ndarray":
    """
    Apply Non-Local Means denoising to a 2D image using JAX.

    This function applies vectorized and JIT-compiled non-local means denoising
    based on the implementation by Buades et al. The input is cast to float32
    and normalized to [0, 1] so that ``h`` and ``sigma`` behave consistently
    across input dtypes; a constant image normalizes to all zeros. The output
    is again rescaled to [0, 1] to prevent clipping issues when converting to
    uint16.

    Only 2D input is denoised here. A 3D array passes input validation but is
    then rejected with NotImplementedError.

    Args:
        image: 2D JAX array of shape (Y, X). A 3D array (Z, Y, X) raises
            NotImplementedError.
        search_window_radius: Radius of search window (default: 7)
        filter_radius: Radius of comparison patches (default: 1); must satisfy
            0 <= filter_radius <= search_window_radius
        h: Filter strength parameter (default: 0.75 * sigma); must be positive
        sigma: Noise standard deviation (default: estimated from the normalized
            image with a Laplacian filter, minimum 0.01)
        slice_by_slice: Accepted for OpenHCS compatibility; not read by this
            function.
        **kwargs: Additional arguments (ignored for compatibility)

    Returns:
        Denoised JAX array with the same shape as the 2D input, rescaled to the
        [0, 1] range (all zeros if the denoised result is constant).

    Raises:
        ImportError: If JAX is not available
        TypeError: If input is not a jax.numpy.ndarray
        ValueError: If input is not 2D or 3D, if the radii are invalid, or if
            the effective h is not positive
        NotImplementedError: If a 3D array reaches this function

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        Not read by this function: any 3D array that reaches it raises
        NotImplementedError. NOTE(review): the design assumes the OpenHCS
        decorator splits 3D stacks into 2D slices when this is True — confirm
        in openhcs.core.memory.decorators.
    """
    # Raises ImportError / TypeError / ValueError; also covers JAX availability.
    _validate_jax_array(image)

    # Fail fast, before any normalization or noise-estimation work. Per the
    # OpenHCS design, slice_by_slice=True stacks are expected to be split by the
    # decorator before reaching this point.
    if image.ndim == 3:
        raise NotImplementedError(
            "3D non-local means processing is not yet implemented for JAX backend. "
            "Use slice_by_slice=True for 2D slice-by-slice processing."
        )

    # _nlm_core needs the centre patch to fit inside the padded image.
    if filter_radius < 0 or filter_radius > search_window_radius:
        raise ValueError(
            "filter_radius must satisfy 0 <= filter_radius <= search_window_radius, "
            f"got filter_radius={filter_radius}, search_window_radius={search_window_radius}"
        )

    # Normalize input to [0, 1] for consistent parameter behavior
    image_normalized = _rescale_to_unit_range(image.astype(jnp.float32))

    # Auto-estimate parameters if not provided
    if sigma is None:
        sigma = _estimate_noise_sigma(image_normalized)
    if h is None:
        h = 0.75 * sigma  # Standard relationship

    # h divides the squared patch distance; h == 0 makes every weight NaN, which
    # then turns the whole rescaled output into zeros. `not h > 0` also rejects NaN.
    if not h > 0:
        raise ValueError(
            f"h (filter strength) must be positive, got {h}; pass a positive h or sigma"
        )

    result = _nlm_core(image_normalized, search_window_radius, filter_radius, h, sigma)

    # Always rescale output to [0, 1] range to prevent uint16 clipping
    result = _rescale_to_unit_range(result)
    logger.info("Rescaled NLM output to [0, 1] range to prevent uint16 clipping")

    return result
# Alias for convenience: a shorter public name bound to the same function
# object as non_local_means_denoise_jax (including its @jax_func decoration).
jax_nlm_denoise = non_local_means_denoise_jax