Coverage for openhcs/processing/backends/enhance/basic_processor_jax.py: 23.8%
60 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2BaSiC (Background and Shading Correction) Implementation using JAX via BaSiCPy
4This module provides OpenHCS-compatible wrapper functions for BaSiCPy's
5JAX-based BaSiC implementation, integrating with OpenHCS memory decorators
6and pipeline system.
8Doctrinal Clauses:
9- Clause 3 — Declarative Primacy: All functions are pure and stateless
10- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities
11- Clause 88 — No Inferred Capabilities: Explicit JAX dependency via BaSiCPy
12- Clause 273 — Memory Backend Restrictions: JAX-only implementation
13"""
14from __future__ import annotations
16import logging
17from typing import TYPE_CHECKING, Any, Optional, Union
19# Import decorator directly from core.memory to avoid circular imports
20from openhcs.core.memory.decorators import jax as jax_func
21from openhcs.core.utils import optional_import
23# For type checking only
24if TYPE_CHECKING: 24 ↛ 25line 24 didn't jump to line 25 because the condition on line 24 was never true
25 import jax.numpy as jnp
27# Import jax.numpy for runtime type hint evaluation
28try:
29 import jax.numpy as jnp
30except ImportError:
31 jnp = None
# BaSiCPy is an optional dependency: resolve it once at import time and expose
# its ``BaSiC`` class, or None when the module handle is None (``optional_import``
# is presumably returning None for a missing package — the checks below rely on it).
basicpy = optional_import("basicpy")
BaSiC = basicpy.BaSiC if basicpy is not None else None

logger = logging.getLogger(__name__)
def _validate_jax_array(array: Any, name: str = "input") -> None:
    """
    Fail fast when BaSiCPy is unavailable or ``array`` is not array-like.

    Args:
        array: Object to check; it must expose both ``shape`` and ``dtype``.
        name: Label used for ``array`` in the error message.

    Raises:
        ImportError: If the module-level ``basicpy`` or ``BaSiC`` handle is None.
        ValueError: If ``array`` lacks a ``shape`` or a ``dtype`` attribute.
    """
    backend_missing = basicpy is None or BaSiC is None
    if backend_missing:
        raise ImportError(
            "BaSiCPy is not available. Please install BaSiCPy for BaSiC correction. "
            "Install with: pip install basicpy"
        )

    # Checked in the same order as before: ``shape`` first, then ``dtype``.
    if all(hasattr(array, attribute) for attribute in ("shape", "dtype")):
        return

    raise ValueError(
        f"{name} must be an array-like object with shape and dtype attributes, "
        f"got {type(array)}."
    )
@jax_func
def basic_flatfield_correction_jax(
    image: "jnp.ndarray",
    max_iters: int = 50,
    lambda_sparse: float = 0.01,
    lambda_lowrank: float = 0.1,
    epsilon: float = 0.1,
    smoothness_flatfield: float = 1.0,
    smoothness_darkfield: float = 1.0,
    sparse_cost_darkfield: float = 0.01,
    get_darkfield: bool = False,
    fitting_mode: str = "ladmap",
    working_size: Optional[Union[int, list]] = 128,
    verbose: bool = False,
    **kwargs
) -> "jnp.ndarray":
    """
    Perform BaSiC-style illumination correction on a 3D image stack via BaSiCPy.

    The stack is converted to NumPy, passed to ``BaSiC.fit_transform`` on a
    freshly constructed ``BaSiC`` model, and the result is returned as a JAX
    array with the input's dtype.

    Args:
        image: 3D array of shape (Z, Y, X).
        max_iters: Forwarded to BaSiC as ``max_iterations``.
        lambda_sparse: Accepted for signature compatibility only; it is NOT
            forwarded to BaSiC and has no effect.
        lambda_lowrank: Accepted for signature compatibility only; it is NOT
            forwarded to BaSiC and has no effect.
        epsilon: Forwarded to BaSiC ``epsilon`` (weight regularization term).
        smoothness_flatfield: Forwarded to BaSiC ``smoothness_flatfield``.
        smoothness_darkfield: Forwarded to BaSiC ``smoothness_darkfield``.
        sparse_cost_darkfield: Forwarded to BaSiC ``sparse_cost_darkfield``.
        get_darkfield: Forwarded to BaSiC ``get_darkfield``.
        fitting_mode: Forwarded to BaSiC ``fitting_mode`` ('ladmap' or
            'approximate').
        working_size: Forwarded to BaSiC ``working_size`` (None means no
            rescaling).
        verbose: Currently unused.
        **kwargs: Ignored; accepted for pipeline compatibility.

    Returns:
        Corrected JAX array with the same dtype as ``image``. For integer
        dtypes, the floating-point result is rounded and clipped to the
        dtype's representable range before casting, so out-of-range values
        saturate instead of wrapping around.

    Raises:
        ImportError: If BaSiCPy is not available.
        ValueError: If ``image`` is not array-like or is not 3D.
        RuntimeError: Wrapping any exception raised while building the BaSiC
            model, fitting/transforming, or converting the result.
    """
    _validate_jax_array(image)

    # len(shape) rather than .ndim: validation only guarantees shape/dtype.
    ndim = len(image.shape)
    if ndim != 3:
        raise ValueError(f"Input must be a 3D array, got {ndim}D")

    logger.debug("BaSiC correction: %s image, mode=%s", image.shape, fitting_mode)

    try:
        import numpy as np
        import jax.numpy as jnp

        # BaSiCPy takes NumPy input and manages JAX internally.
        image_np = np.asarray(image)

        basic = BaSiC(
            # Core algorithm parameters
            max_iterations=max_iters,
            epsilon=epsilon,
            smoothness_flatfield=smoothness_flatfield,
            smoothness_darkfield=smoothness_darkfield,
            sparse_cost_darkfield=sparse_cost_darkfield,
            get_darkfield=get_darkfield,
            fitting_mode=fitting_mode,
            working_size=working_size,
            # Optimization parameters
            optimization_tol=1e-3,
            optimization_tol_diff=1e-2,
            reweighting_tol=1e-2,
            max_reweight_iterations=10,
            # Memory and performance
            resize_mode="jax",
            sort_intensity=False,
        )

        logger.debug("Starting BaSiC fit and transform")
        corrected_np = np.asarray(basic.fit_transform(image_np, timelapse=False))

        target_dtype = np.dtype(image.dtype)
        if np.issubdtype(target_dtype, np.integer):
            # Corrected intensities are floats and may fall outside the
            # integer range (e.g. negatives after darkfield subtraction).
            # A direct float->int cast wraps or is undefined there, so round
            # and saturate first.
            limits = np.iinfo(target_dtype)
            corrected_np = np.clip(np.rint(corrected_np), limits.min, limits.max)

        corrected = jnp.asarray(corrected_np).astype(target_dtype)
        logger.debug("BaSiC correction completed: %s", corrected.shape)
        return corrected

    except Exception as e:
        logger.error("BaSiC correction failed: %s", e)
        raise RuntimeError(f"BaSiC flat field correction failed: {e}") from e
@jax_func
def basic_flatfield_correction_batch_jax(
    image_batch: "jnp.ndarray",
    *,
    batch_dim: int = 0,
    **kwargs
) -> "jnp.ndarray":
    """
    Apply BaSiC flatfield correction independently to each 3D stack of a batch.

    Each stack is corrected by its own call to
    ``basic_flatfield_correction_jax``. The results are re-stacked along
    ``batch_dim``.

    Args:
        image_batch: 4D array of shape (B, Z, Y, X) when ``batch_dim == 0``,
            or (Z, B, Y, X) when ``batch_dim == 1``.
        batch_dim: Axis holding the batch index (0 or 1).
        **kwargs: Forwarded unchanged to ``basic_flatfield_correction_jax``.

    Returns:
        Corrected 4D JAX array with the batch re-stacked along ``batch_dim``.
        An empty batch (size 0 along ``batch_dim``) is returned as an empty
        array of the input's shape instead of raising.

    Raises:
        ImportError: If BaSiCPy is not available.
        ValueError: If the input is not array-like, is not 4D, or
            ``batch_dim`` is not 0 or 1.
        RuntimeError: Propagated from ``basic_flatfield_correction_jax`` when
            correcting a stack fails.
    """
    _validate_jax_array(image_batch)

    ndim = len(image_batch.shape)
    if ndim != 4:
        raise ValueError(f"Input must be a 4D array, got {ndim}D")

    if batch_dim not in (0, 1):
        raise ValueError(f"batch_dim must be 0 or 1, got {batch_dim}")

    logger.debug(
        "BaSiC batch correction: %s, batch_dim=%s", image_batch.shape, batch_dim
    )

    import jax.numpy as jnp

    batch_size = image_batch.shape[batch_dim]
    if batch_size == 0:
        # jnp.stack rejects an empty sequence; there is nothing to correct.
        return jnp.asarray(image_batch)

    # Full slices before batch_dim, so that image_batch[leading + (b,)] is
    # image_batch[b] for batch_dim 0 and image_batch[:, b] for batch_dim 1.
    leading = (slice(None),) * batch_dim
    corrected_stacks = [
        basic_flatfield_correction_jax(image_batch[leading + (b,)], **kwargs)
        for b in range(batch_size)
    ]
    return jnp.stack(corrected_stacks, axis=batch_dim)