Coverage for openhcs/processing/backends/processors/jax_processor.py: 13.4%
227 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2JAX Image Processor Implementation
4This module implements the ImageProcessorInterface using JAX as the backend.
5It leverages GPU acceleration for image processing operations.
7Doctrinal Clauses:
8- Clause 3 — Declarative Primacy: All functions are pure and stateless
9- Clause 88 — No Inferred Capabilities: Explicit JAX dependency
10- Clause 106-A — Declared Memory Types: All methods specify JAX arrays
11"""
12from __future__ import annotations
14import logging
15from typing import Any, List, Optional, Tuple
17from openhcs.core.memory.decorators import jax as jax_func
18from openhcs.core.utils import optional_import
20# Import JAX as an optional dependency
21jax = optional_import("jax")
22jnp = optional_import("jax.numpy") if jax is not None else None
23lax = jax.lax if jax is not None else None
25logger = logging.getLogger(__name__)
@jax_func
def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "jnp.ndarray":
    """
    Create a 2D blending weight mask with linear ramps along the borders.

    Each axis gets a 1D profile that is 1.0 in the interior and ramps over a
    margin of ``floor(length * margin_ratio)`` samples at both ends. The 2D mask
    is the outer product of the row profile and the column profile.

    Along each axis the leading ramp is ``linspace(0, 1, margin, endpoint=False)``,
    so the first sample is exactly 0. The trailing ramp is
    ``linspace(1, 0, margin, endpoint=False)``, so the last sample is
    ``1 / margin``, not 0. NOTE(review): this asymmetry is preserved as-is;
    confirm whether it is intentional (e.g. to mirror another backend).

    Args:
        height: Height of the mask
        width: Width of the mask
        margin_ratio: Fraction of each dimension used as the ramp margin

    Returns:
        2D float32 JAX weight mask of shape (height, width)
    """
    def _axis_profile(length: int) -> "jnp.ndarray":
        # Number of samples that ramp at each end of this axis.
        margin = int(jnp.floor(length * margin_ratio))
        profile = jnp.ones(length, dtype=jnp.float32)
        if margin > 0:
            leading = jnp.linspace(0, 1, margin, endpoint=False)
            trailing = jnp.linspace(1, 0, margin, endpoint=False)
            # The trailing write happens second, so it wins if the two ramps overlap.
            profile = profile.at[:margin].set(leading).at[-margin:].set(trailing)
        return profile

    return jnp.outer(_axis_profile(height), _axis_profile(width))
67def _validate_3d_array(array: Any, name: str = "input") -> None:
68 """
69 Validate that the input is a 3D JAX array.
71 Args:
72 array: Array to validate
73 name: Name of the array for error messages
75 Raises:
76 TypeError: If the array is not a JAX array
77 ValueError: If the array is not 3D
78 ImportError: If JAX is not available
79 """
80 # The compiler will ensure this function is only called when JAX is available
81 # No need to check for JAX availability here
83 if not isinstance(array, jnp.ndarray):
84 raise TypeError(f"{name} must be a JAX array, got {type(array)}. "
85 f"No automatic conversion is performed to maintain explicit contracts.")
87 if array.ndim != 3:
88 raise ValueError(f"{name} must be a 3D array, got {array.ndim}D")
@jax_func
def _gaussian_kernel(sigma: float, kernel_size: int) -> "jnp.ndarray":
    """
    Build a normalised, separable 2D Gaussian kernel.

    An even ``kernel_size`` is rounded up to the next odd number so the kernel
    has a centre tap. Callers should therefore read the returned shape rather
    than assume ``kernel_size``.

    Args:
        sigma: Standard deviation of the Gaussian
        kernel_size: Requested kernel width/height (even values are bumped by one)

    Returns:
        2D float32 JAX array of shape (k, k), where k is the odd size actually used
    """
    size = kernel_size if kernel_size % 2 else kernel_size + 1
    half = size // 2

    # 1D profile sampled at integer offsets -half..half, normalised to sum to 1.
    taps = jnp.arange(-half, half + 1, dtype=jnp.float32)
    profile = jnp.exp(-0.5 * (taps / sigma) ** 2)
    profile = profile / jnp.sum(profile)

    # Outer product of the 1D profile with itself gives the 2D kernel.
    return jnp.outer(profile, profile)
@jax_func
def _gaussian_blur(image: "jnp.ndarray", sigma: float) -> "jnp.ndarray":
    """
    Apply Gaussian blur to a 2D image.

    The image is reflect-padded by half the kernel size and convolved with
    ``lax.conv_general_dilated`` using VALID padding, so the output has the
    same shape as the input.

    Args:
        image: 2D JAX array of shape (H, W)
        sigma: Standard deviation of the Gaussian kernel; must be > 0

    Returns:
        Blurred 2D JAX array of shape (H, W)

    Raises:
        ValueError: If ``sigma`` is not positive (sigma == 0 would otherwise
            produce an all-NaN kernel).
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Requested kernel size: roughly +/- 4 sigma, at least 3.
    kernel_size = max(3, int(2 * 4 * sigma + 1))

    kernel = _gaussian_kernel(sigma, kernel_size)

    # _gaussian_kernel bumps even sizes up to the next odd size. Use the size it
    # actually produced; reshaping with the requested (even) size would fail.
    kernel_size = kernel.shape[0]

    # Reflect-pad so that a VALID convolution yields an (H, W) result.
    pad_size = kernel_size // 2
    padded = jnp.pad(image, ((pad_size, pad_size), (pad_size, pad_size)), mode='reflect')

    # lax.conv_general_dilated expects NHWC input and an HWIO kernel.
    kernel_reshaped = kernel.reshape(kernel_size, kernel_size, 1, 1)
    padded_reshaped = padded.reshape(1, padded.shape[0], padded.shape[1], 1)

    result = lax.conv_general_dilated(
        padded_reshaped,
        kernel_reshaped,
        window_strides=(1, 1),
        padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )

    # Drop the singleton batch and channel dimensions.
    return result[0, :, :, 0]
@jax_func
def sharpen(image: "jnp.ndarray", radius: float = 1.0, amount: float = 1.0
) -> "jnp.ndarray":
    """
    Sharpen a 3D image using unsharp masking.

    Each Z-slice is processed independently:
    1. Scale the slice by its own maximum.
    2. Blur it with a Gaussian of standard deviation ``radius``.
    3. Sharpen as ``original + amount * (original - blurred)``.
    4. Clip to [0, 1].
    5. Min-max stretch to [0, 65535] when the result is not constant.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        radius: Standard deviation (sigma) of the Gaussian blur
        amount: Sharpening strength

    Returns:
        Sharpened 3D JAX array of shape (Z, Y, X). Integer inputs are returned
        as uint16 (clipped to [0, 65535]); other inputs keep their dtype.

    Raises:
        TypeError: If ``image`` is not a JAX array.
        ValueError: If ``image`` is not 3D.
    """
    _validate_3d_array(image)

    dtype = image.dtype
    result_list = []

    for z in range(image.shape[0]):
        peak = jnp.max(image[z])
        # An all-zero slice would otherwise compute 0 / 0 and turn the whole
        # slice into NaN, which then casts to garbage. Dividing by 1 instead
        # keeps the zeros; slices with a non-zero maximum are unaffected.
        safe_peak = jnp.where(peak != 0, peak, 1)
        slice_float = image[z].astype(jnp.float32) / safe_peak

        # Unsharp mask: original + amount * (original - blurred).
        blurred = _gaussian_blur(slice_float, sigma=radius)
        sharpened = slice_float + amount * (slice_float - blurred)

        sharpened = jnp.clip(sharpened, 0.0, 1.0)

        # Stretch to the 16-bit range; a constant slice is left as-is.
        min_val = jnp.min(sharpened)
        max_val = jnp.max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535.0 / (max_val - min_val)

        result_list.append(sharpened)

    result = jnp.stack(result_list, axis=0)

    if jnp.issubdtype(dtype, jnp.integer):
        result = jnp.clip(result, 0, 65535).astype(jnp.uint16)
    else:
        result = result.astype(dtype)

    return result
@jax_func
def percentile_normalize(
    image: "jnp.ndarray",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "jnp.ndarray":
    """
    Contrast-stretch each Z-slice of a 3D image between its own percentiles.

    For every slice, values are clipped to [p_low, p_high] (the slice's
    ``low_percentile``/``high_percentile`` values) and mapped linearly onto
    [target_min, target_max]. If the two percentiles are numerically equal
    (``jnp.isclose``), the slice is filled with ``target_min`` instead.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        uint16 3D JAX array of shape (Z, Y, X), clipped to [0, 65535]
    """
    _validate_3d_array(image)

    def _stretch(operands):
        lo, hi, plane = operands
        clipped = jnp.clip(plane.astype(jnp.float32), lo, hi)
        gain = (target_max - target_min) / (hi - lo)
        return (clipped - lo) * gain + target_min

    def _flat(operands):
        # Degenerate slice: avoid dividing by (hi - lo) == 0.
        plane = operands[2]
        return jnp.ones_like(plane, dtype=jnp.float32) * target_min

    def _normalize_plane(plane):
        lo = jnp.percentile(plane, low_percentile)
        hi = jnp.percentile(plane, high_percentile)
        return jax.lax.cond(jnp.isclose(hi, lo), _flat, _stretch, (lo, hi, plane))

    planes = [_normalize_plane(image[z]) for z in range(image.shape[0])]
    stacked = jnp.stack(planes, axis=0)
    return jnp.clip(stacked, 0, 65535).astype(jnp.uint16)
@jax_func
def stack_percentile_normalize(
    stack: "jnp.ndarray",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "jnp.ndarray":
    """
    Normalize a stack using global percentile-based contrast stretching.

    This ensures consistent normalization across all Z-slices by computing
    global percentiles across the entire stack. Values are clipped to
    [p_low, p_high] and mapped linearly onto [target_min, target_max].

    Args:
        stack: 3D JAX array of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        uint16 3D JAX array of shape (Z, Y, X). The result is always uint16
        and clipped to [0, 65535], including the degenerate case where both
        percentiles coincide (the whole stack is then ``target_min``).
    """
    _validate_3d_array(stack)

    # Global percentiles across the entire stack.
    p_low = jnp.percentile(stack, low_percentile)
    p_high = jnp.percentile(stack, high_percentile)

    # Degenerate stack: avoid dividing by zero. Previously this branch returned
    # a float array while the main path returned uint16; return uint16 here
    # too, so callers see one dtype.
    if p_high == p_low:
        flat = jnp.full(stack.shape, target_min, dtype=jnp.float32)
        return jnp.clip(flat, 0, 65535).astype(jnp.uint16)

    # Clip and normalize to target range (same arithmetic as the NumPy backend).
    clipped = jnp.clip(stack, p_low, p_high)
    normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min

    # Clip before the cast, as percentile_normalize does. Out-of-range targets
    # must not wrap or hit JAX's implementation-defined float->uint16 conversion.
    return jnp.clip(normalized, 0, 65535).astype(jnp.uint16)
@jax_func
def create_composite(
    stack: "jnp.ndarray", weights: Optional[List[float]] = None
) -> "jnp.ndarray":
    """
    Blend the slices of a 3D stack into a single weighted-average slice.

    The weights are normalised to sum to 1. The weighted sum is computed in
    float32 and then cast back to the input dtype (integer inputs are
    therefore truncated, not rounded).

    Args:
        stack: 3D JAX array of shape (N, Y, X), one channel per slice
        weights: One weight per slice (list or tuple). If None, all slices
            are weighted equally.

    Returns:
        3D JAX array of shape (1, Y, X) with the same dtype as ``stack``

    Raises:
        TypeError: If ``weights`` is neither None nor a list/tuple.
        ValueError: If the number of weights differs from N, or they sum to 0.
    """
    _validate_3d_array(stack)

    n_slices = stack.shape[0]

    if weights is None:
        weights = [1.0 / n_slices] * n_slices
    elif not isinstance(weights, (list, tuple)):
        raise TypeError(f"weights must be a list of values or None, got {type(weights)}: {weights}")
    else:
        weights = list(weights)
        if len(weights) != n_slices:
            raise ValueError(f"Number of weights ({len(weights)}) must match number of slices ({n_slices})")

    total = sum(weights)
    if total == 0:
        raise ValueError("Sum of weights cannot be zero")

    # float32 weights keep fractional values regardless of the stack dtype.
    # Shape (N, 1, 1) broadcasts against (N, Y, X).
    coefficients = jnp.array([w / total for w in weights], dtype=jnp.float32).reshape(n_slices, 1, 1)

    blended = jnp.sum(stack.astype(jnp.float32) * coefficients, axis=0, keepdims=True)
    return blended.astype(stack.dtype)
@jax_func
def apply_mask(image: "jnp.ndarray", mask: "jnp.ndarray") -> "jnp.ndarray":
    """
    Multiply a 3D image by a mask.

    A 2D mask of shape (Y, X) is applied to every Z-slice. A 3D mask must match
    the image shape exactly. The product is computed in float32 and cast back
    to the image dtype.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        mask: 2D JAX array of shape (Y, X) or 3D JAX array of shape (Z, Y, X)

    Returns:
        Masked 3D JAX array of shape (Z, Y, X) with the image's dtype

    Raises:
        ValueError: If the mask shape does not match the image.
        TypeError: If the mask is not a 2D or 3D JAX array.
    """
    _validate_3d_array(image)

    mask_rank = mask.ndim if isinstance(mask, jnp.ndarray) else None

    if mask_rank == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )
        mask_f32 = mask.astype(jnp.float32)
        planes = [image[z].astype(jnp.float32) * mask_f32 for z in range(image.shape[0])]
        return jnp.stack(planes, axis=0).astype(image.dtype)

    if mask_rank == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )
        product = image.astype(jnp.float32) * mask.astype(jnp.float32)
        return product.astype(image.dtype)

    raise TypeError(f"mask must be a 2D or 3D JAX array, got {type(mask)}")
@jax_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "jnp.ndarray":
    """
    Create a border-ramped weight mask for blending images.

    Thin shape-validating wrapper around ``create_linear_weight_mask``.

    Args:
        shape: Mask shape as a ``(height, width)`` tuple
        margin_ratio: Fraction of each dimension used as the ramp margin

    Returns:
        2D JAX weight mask of shape (height, width)

    Raises:
        TypeError: If ``shape`` is not a 2-element tuple.
    """
    if not (isinstance(shape, tuple) and len(shape) == 2):
        raise TypeError("shape must be a tuple of (height, width)")

    return create_linear_weight_mask(shape[0], shape[1], margin_ratio)
@jax_func
def max_projection(stack: "jnp.ndarray") -> "jnp.ndarray":
    """
    Collapse a Z-stack to its per-pixel maximum.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)

    Returns:
        3D JAX array of shape (1, Y, X) with the input dtype
    """
    _validate_3d_array(stack)

    # keepdims=True keeps the reduced Z axis as a leading length-1 dimension.
    return jnp.max(stack, axis=0, keepdims=True)
@jax_func
def mean_projection(stack: "jnp.ndarray") -> "jnp.ndarray":
    """
    Collapse a Z-stack to its per-pixel mean.

    The mean is accumulated in float32 and cast back to the input dtype
    (integer inputs are truncated).

    Args:
        stack: 3D JAX array of shape (Z, Y, X)

    Returns:
        3D JAX array of shape (1, Y, X) with the input dtype
    """
    _validate_3d_array(stack)

    averaged = jnp.mean(stack.astype(jnp.float32), axis=0, keepdims=True)
    return averaged.astype(stack.dtype)
@jax_func
def stack_equalize_histogram(
    stack: "jnp.ndarray",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "jnp.ndarray":
    """
    Histogram-equalize a whole stack with one global lookup table.

    A single histogram is built over every voxel. Its cumulative sum, scaled so
    that the last entry maps to 65535, is the lookup table. Each voxel is then
    replaced by the table entry of its bin. Voxels outside
    [range_min, range_max] are excluded from the histogram and mapped to the
    nearest edge bin.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)
        bins: Number of histogram bins
        range_min: Lower edge of the histogram range
        range_max: Upper edge of the histogram range

    Returns:
        uint16 3D JAX array of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    voxels = stack.ravel()

    counts, _ = jnp.histogram(voxels, bins=bins, range=(range_min, range_max))
    cumulative = jnp.cumsum(counts)
    total = jnp.max(cumulative)

    def _rescale(c):
        return 65535.0 * c / total

    def _passthrough(c):
        return c

    # An empty histogram (total == 0) is left unscaled to avoid 0 / 0.
    lookup = jax.lax.cond(total > 0, _rescale, _passthrough, cumulative)

    step = (range_max - range_min) / bins
    bin_index = jnp.floor((voxels - range_min) / step).astype(jnp.int32)
    bin_index = jnp.clip(bin_index, 0, bins - 1)

    return lookup[bin_index].reshape(stack.shape).astype(jnp.uint16)
@jax_func
def create_projection(
    stack: "jnp.ndarray", method: str = "max_projection"
) -> "jnp.ndarray":
    """
    Project a stack along Z using a named method.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)
        method: Either ``"max_projection"`` or ``"mean_projection"``

    Returns:
        3D JAX array of shape (1, Y, X)

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    _validate_3d_array(stack)

    for name, projector in (
        ("max_projection", max_projection),
        ("mean_projection", mean_projection),
    ):
        if method == name:
            return projector(stack)

    # FAIL FAST: No fallback projection methods
    raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")
def _tophat_disk_offsets(radius: int) -> List[Tuple[int, int]]:
    """Return the (dy, dx) offsets covered by a flat disk of the given radius."""
    return [
        (dy, dx)
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
        if dy * dy + dx * dx <= radius * radius
    ]


def _tophat_disk_filter(image: "jnp.ndarray", offsets, radius: int, combine) -> "jnp.ndarray":
    """Flat-disk erosion (combine=jnp.minimum) or dilation (combine=jnp.maximum) of a 2D image."""
    height, width = image.shape
    # Reflect padding by the disk radius, as in the original per-pixel version.
    padded = jnp.pad(image, radius, mode='reflect')
    result = None
    # One shifted view per disk offset, reduced elementwise. This is exactly
    # the per-pixel min/max over the disk, but vectorised.
    for dy, dx in offsets:
        window = padded[radius + dy:radius + dy + height, radius + dx:radius + dx + width]
        result = window if result is None else combine(result, window)
    return result


def _tophat_slice(slice_data: "jnp.ndarray", factor: int, radius: int, offsets) -> "jnp.ndarray":
    """Remove the morphological-opening background from one 2D slice."""
    input_dtype = slice_data.dtype
    height, width = slice_data.shape

    # 1) Downsample by block averaging. Dimensions that are not a multiple of
    #    the factor are edge-padded to one (ceil division) instead of crashing
    #    in reshape. Divisible shapes are unchanged.
    small_h = -(-height // factor)
    small_w = -(-width // factor)
    data = slice_data.astype(jnp.float32)
    pad_h = small_h * factor - height
    pad_w = small_w * factor - width
    if pad_h or pad_w:
        data = jnp.pad(data, ((0, pad_h), (0, pad_w)), mode='edge')
    image_small = jnp.mean(data.reshape(small_h, factor, small_w, factor), axis=(1, 3))

    # 2) Opening = erosion followed by dilation with the flat disk. The opening
    #    is the background estimate; the white top-hat is image - opening.
    eroded = _tophat_disk_filter(image_small, offsets, radius, jnp.minimum)
    background_small = _tophat_disk_filter(eroded, offsets, radius, jnp.maximum)

    # 3) Nearest-neighbour upscale, then crop away any padding added in step 1.
    background = jnp.repeat(
        jnp.repeat(background_small, factor, axis=0), factor, axis=1
    )[:height, :width]

    # 4) Subtract the background, clip negatives, restore the input dtype.
    slice_result = jnp.maximum(slice_data.astype(jnp.float32) - background, 0)
    return slice_result.astype(input_dtype)


@jax_func
def tophat(
    image: "jnp.ndarray",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "jnp.ndarray":
    """
    Apply white top-hat filter to a 3D image for background removal.

    Each Z-slice is processed independently:
    1. Downsample by block averaging by ``downsample_factor``.
    2. Estimate the background with a morphological opening using a flat disk
       of radius ``max(1, selem_radius // downsample_factor)``.
    3. Upscale the background by nearest-neighbour repetition.
    4. Subtract it from the slice, clipping negatives to 0.

    This approximates scikit-image's white top-hat on a downsampled image;
    results are not expected to match it exactly.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk at full resolution
        downsample_factor: Integer factor (>= 1) by which to downsample for processing

    Returns:
        Filtered 3D JAX array of shape (Z, Y, X) with the input dtype

    Raises:
        ValueError: If ``downsample_factor`` is smaller than 1.
    """
    _validate_3d_array(image)

    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    # The disk depends only on the parameters, so it is built once for all slices.
    small_selem_radius = max(1, selem_radius // downsample_factor)
    offsets = _tophat_disk_offsets(small_selem_radius)

    result_list = [
        _tophat_slice(image[z], downsample_factor, small_selem_radius, offsets)
        for z in range(image.shape[0])
    ]

    return jnp.stack(result_list, axis=0)