Coverage for openhcs/processing/backends/enhance/basic_processor_jax.py: 23.8%

60 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2BaSiC (Background and Shading Correction) Implementation using JAX via BaSiCPy 

3 

4This module provides OpenHCS-compatible wrapper functions for BaSiCPy's 

5JAX-based BaSiC implementation, integrating with OpenHCS memory decorators 

6and pipeline system. 

7 

8Doctrinal Clauses: 

9- Clause 3 — Declarative Primacy: All functions are pure and stateless 

10- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities 

11- Clause 88 — No Inferred Capabilities: Explicit JAX dependency via BaSiCPy 

12- Clause 273 — Memory Backend Restrictions: JAX-only implementation 

13""" 

14from __future__ import annotations 

15 

16import logging 

17from typing import TYPE_CHECKING, Any, Optional, Union 

18 

19# Import decorator directly from core.memory to avoid circular imports 

20from openhcs.core.memory.decorators import jax as jax_func 

21from openhcs.core.utils import optional_import 

22 

23# For type checking only 

24if TYPE_CHECKING: 24 ↛ 25line 24 didn't jump to line 25 because the condition on line 24 was never true

25 import jax.numpy as jnp 

26 

27# Import jax.numpy for runtime type hint evaluation 

28try: 

29 import jax.numpy as jnp 

30except ImportError: 

31 jnp = None 

32 

33# Import BaSiCPy as an optional dependency 

34basicpy = optional_import("basicpy") 

35if basicpy is not None: 35 ↛ 38line 35 didn't jump to line 38 because the condition on line 35 was always true

36 BaSiC = basicpy.BaSiC 

37else: 

38 BaSiC = None 

39 

40logger = logging.getLogger(__name__) 

41 

42 

def _validate_jax_array(array: Any, name: str = "input") -> None:
    """
    Check that BaSiCPy was imported and that ``array`` looks like an array.

    Args:
        array: Object to check; it must expose ``shape`` and ``dtype`` attributes
        name: Label used for the object in the ValueError message

    Raises:
        ImportError: If the module-level ``basicpy`` or ``BaSiC`` is None
        ValueError: If ``array`` lacks a ``shape`` or a ``dtype`` attribute
    """
    # Dependency check runs first, so a missing BaSiCPy install is reported
    # even when the input itself is malformed.
    dependency_missing = basicpy is None or BaSiC is None
    if dependency_missing:
        raise ImportError(
            "BaSiCPy is not available. Please install BaSiCPy for BaSiC correction. "
            "Install with: pip install basicpy"
        )

    # Duck-typed check only: no isinstance test against a JAX array type.
    looks_like_array = all(hasattr(array, attr) for attr in ("shape", "dtype"))
    if looks_like_array:
        return

    raise ValueError(
        f"{name} must be an array-like object with shape and dtype attributes, "
        f"got {type(array)}."
    )

66 

67 

@jax_func
def basic_flatfield_correction_jax(
    image: "jnp.ndarray",
    max_iters: int = 50,
    lambda_sparse: float = 0.01,
    lambda_lowrank: float = 0.1,
    epsilon: float = 0.1,
    smoothness_flatfield: float = 1.0,
    smoothness_darkfield: float = 1.0,
    sparse_cost_darkfield: float = 0.01,
    get_darkfield: bool = False,
    fitting_mode: str = "ladmap",
    working_size: Optional[Union[int, list]] = 128,
    verbose: bool = False,
    **kwargs
) -> "jnp.ndarray":
    """
    Perform BaSiC-style illumination correction on a 3D image stack using JAX via BaSiCPy.

    The stack is converted to a NumPy array, fitted and transformed in one
    step by a freshly constructed ``BaSiC`` instance
    (``fit_transform(..., timelapse=False)``), converted back to a JAX array
    and cast to the dtype of the input.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        max_iters: Forwarded to BaSiC as ``max_iterations``
        lambda_sparse: Accepted for signature compatibility only; this
            function does not read it or forward it to BaSiC
        lambda_lowrank: Accepted for signature compatibility only; this
            function does not read it or forward it to BaSiC
        epsilon: Weight regularization term (forwarded to BaSiC)
        smoothness_flatfield: Weight of flatfield term in Lagrangian (forwarded)
        smoothness_darkfield: Weight of darkfield term in Lagrangian (forwarded)
        sparse_cost_darkfield: Weight of darkfield sparse term in Lagrangian (forwarded)
        get_darkfield: Whether to estimate darkfield component (forwarded)
        fitting_mode: Fitting mode ('ladmap' or 'approximate') (forwarded)
        working_size: Size for running computations (None means no rescaling) (forwarded)
        verbose: Accepted for signature compatibility only; currently unused
        **kwargs: Additional parameters (ignored for compatibility)

    Returns:
        Corrected 3D JAX array, cast back to ``image.dtype``

    Raises:
        ImportError: If BaSiCPy is not available
        ValueError: If input is not a 3D array
        RuntimeError: If anything in the convert / fit / transform sequence
            fails; the original exception is chained as ``__cause__``
    """
    # Validate input and dependencies (raised as-is, outside the RuntimeError wrapper)
    _validate_jax_array(image)

    if image.ndim != 3:
        raise ValueError(f"Input must be a 3D array, got {image.ndim}D")

    # Lazy %-style arguments: the message is only formatted when DEBUG is enabled.
    logger.debug("BaSiC correction: %s image, mode=%s", image.shape, fitting_mode)

    try:
        # BaSiCPy is handed a NumPy array here.
        # NOTE(review): the original comment states BaSiCPy handles JAX
        # internally - confirm against the installed BaSiCPy version.
        import numpy as np
        image_np = np.asarray(image)

        # Create BaSiC instance with parameters
        basic = BaSiC(
            # Core algorithm parameters
            max_iterations=max_iters,
            epsilon=epsilon,
            smoothness_flatfield=smoothness_flatfield,
            smoothness_darkfield=smoothness_darkfield,
            sparse_cost_darkfield=sparse_cost_darkfield,
            get_darkfield=get_darkfield,
            fitting_mode=fitting_mode,
            working_size=working_size,

            # Optimization parameters (fixed here, not exposed to callers)
            optimization_tol=1e-3,
            optimization_tol_diff=1e-2,
            reweighting_tol=1e-2,
            max_reweight_iterations=10,

            # Memory and performance
            resize_mode="jax",
            sort_intensity=False,
        )

        # Fit and transform the image
        logger.debug("Starting BaSiC fit and transform")
        corrected_np = basic.fit_transform(image_np, timelapse=False)

        # Convert back to JAX array
        import jax.numpy as jnp
        corrected = jnp.asarray(corrected_np)

        logger.debug("BaSiC correction completed: %s", corrected.shape)
        # Cast so the output dtype matches the input dtype.
        return corrected.astype(image.dtype)

    except Exception as e:
        # Log, then re-raise as RuntimeError with the original exception chained.
        logger.error("BaSiC correction failed: %s", e)
        raise RuntimeError(f"BaSiC flat field correction failed: {e}") from e

164 

165 

@jax_func
def basic_flatfield_correction_batch_jax(
    image_batch: "jnp.ndarray",
    *,
    batch_dim: int = 0,
    **kwargs
) -> "jnp.ndarray":
    """
    Run basic_flatfield_correction_jax on every 3D stack of a 4D batch.

    Each stack is sliced out along ``batch_dim``, corrected independently
    with the same keyword arguments, and the results are stacked back along
    that same axis.

    Args:
        image_batch: 4D JAX array of shape (B, Z, Y, X) or (Z, B, Y, X)
        batch_dim: Axis holding the batch index (0 or 1)
        **kwargs: Forwarded unchanged to basic_flatfield_correction_jax

    Returns:
        4D JAX array built by stacking the corrected stacks along ``batch_dim``

    Raises:
        ImportError: If BaSiCPy is not available
        ValueError: If input is not a 4D array or batch_dim is invalid
    """
    # Dependency / array-likeness check before any shape inspection
    _validate_jax_array(image_batch)

    if image_batch.ndim != 4:
        raise ValueError(f"Input must be a 4D array, got {image_batch.ndim}D")

    if batch_dim not in [0, 1]:
        raise ValueError(f"batch_dim must be 0 or 1, got {batch_dim}")

    logger.debug(f"BaSiC batch correction: {image_batch.shape}, batch_dim={batch_dim}")

    # Anything that compares equal to 0 selects the (B, Z, Y, X) layout;
    # every other accepted value selects (Z, B, Y, X).
    stack_axis = 0 if batch_dim == 0 else 1

    # Index prefix: () picks image_batch[b]; (slice(None),) picks image_batch[:, b].
    leading = (slice(None),) * stack_axis

    corrected_stacks = [
        basic_flatfield_correction_jax(image_batch[leading + (idx,)], **kwargs)
        for idx in range(image_batch.shape[stack_axis])
    ]

    # Reassemble along the batch axis (imported only once all stacks are processed)
    import jax.numpy as jnp
    return jnp.stack(corrected_stacks, axis=stack_axis)