Coverage for openhcs/processing/backends/enhance/basic_processor_numpy.py: 12.1%

73 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2BaSiC (Background and Shading Correction) Implementation using NumPy 

3 

4This module implements the BaSiC algorithm for illumination correction 

5using NumPy for CPU processing. The implementation is based on the paper: 

6Peng et al., "A BaSiC tool for background and shading correction of optical 

7microscopy images", Nature Communications, 2017. 

8 

9The algorithm performs low-rank + sparse matrix decomposition to separate 

10uneven illumination artifacts from structural features in microscopy images. 

11 

12Doctrinal Clauses: 

13- Clause 3 — Declarative Primacy: All functions are pure and stateless 

14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities 

15""" 

16from __future__ import annotations 

17 

18import logging 

19from typing import Any 

20 

21import numpy as np 

22from scipy import linalg 

23 

24# Import decorator directly from core.memory to avoid circular imports 

25from openhcs.core.memory import numpy as numpy_func 

26 

27logger = logging.getLogger(__name__) 

28 

29 

30def _validate_numpy_array(array: Any, name: str = "input") -> None: 

31 """ 

32 Validate that the input is a NumPy array. 

33 

34 Args: 

35 array: Array to validate 

36 name: Name of the array for error messages 

37 

38 Raises: 

39 TypeError: If the array is not a NumPy array 

40 """ 

41 if not isinstance(array, np.ndarray): 

42 raise TypeError( 

43 f"{name} must be a NumPy array, got {type(array)}. " 

44 f"No automatic conversion is performed to maintain explicit contracts. " 

45 f"For GPU arrays, use the CuPy implementation with DLPack support." 

46 ) 

47 

48 

49def _low_rank_approximation(matrix: np.ndarray, rank: int = 3) -> np.ndarray: 

50 """ 

51 Compute a low-rank approximation of a matrix using truncated SVD. 

52 

53 Args: 

54 matrix: Input matrix to approximate 

55 rank: Target rank for the approximation 

56 

57 Returns: 

58 Low-rank approximation of the input matrix 

59 """ 

60 # Perform SVD 

61 U, s, Vh = linalg.svd(matrix, full_matrices=False) 

62 

63 # Truncate to the specified rank 

64 s[rank:] = 0 

65 

66 # Reconstruct the low-rank matrix 

67 low_rank = (U * s) @ Vh 

68 

69 return low_rank 

70 

71 

72def _soft_threshold(matrix: np.ndarray, threshold: float) -> np.ndarray: 

73 """ 

74 Apply soft thresholding (shrinkage operator) to a matrix. 

75 

76 Args: 

77 matrix: Input matrix 

78 threshold: Threshold value for soft thresholding 

79 

80 Returns: 

81 Soft-thresholded matrix 

82 """ 

83 return np.sign(matrix) * np.maximum(np.abs(matrix) - threshold, 0) 

84 

85 

@numpy_func
def basic_flatfield_correction_numpy(
    image: np.ndarray,
    *,
    max_iters: int = 50,
    lambda_sparse: float = 0.01,
    lambda_lowrank: float = 0.1,
    rank: int = 3,
    tol: float = 1e-4,
    correction_mode: str = "divide",
    normalize_output: bool = True,
    verbose: bool = False,
    **kwargs
) -> np.ndarray:
    """
    Perform BaSiC-style illumination correction on a 3D image stack using NumPy.

    Each Z-slice is flattened into one row of a (Z, Y*X) matrix D, which is
    split by alternating updates into a low-rank part L (background /
    shading field) and a sparse part S (e.g., nuclei, structures). L is then
    divided out of, or subtracted from, the input stack.

    Args:
        image: 3D NumPy array of shape (Z, Y, X)
        max_iters: Maximum number of iterations for the alternating minimization
        lambda_sparse: Soft-threshold level used when updating the sparse component
        lambda_lowrank: When > 0, the low-rank component is scaled by
            ``(1 - lambda_lowrank)`` on every iteration
        rank: Target rank for the low-rank approximation
        tol: Iteration stops once the relative residual
            ``||D - L - S||_F / ||D||_F`` falls below this value
        correction_mode: Mode for applying the correction ('divide' or 'subtract')
        normalize_output: Whether to rescale (divide mode) or offset (subtract
            mode) the result by the mean of the estimated background
        verbose: Whether to log progress information at INFO level
        **kwargs: Additional parameters (ignored)

    Returns:
        Corrected 3D NumPy array of shape (Z, Y, X) in the input dtype. Values
        are clipped to ``[0, dtype max]`` for integer dtypes and to ``[0, inf)``
        otherwise before the cast.

    Raises:
        TypeError: If the input is not a NumPy array
        ValueError: If the input is not a 3D array or if correction_mode is invalid
    """
    # Validate input
    _validate_numpy_array(image)

    if image.ndim != 3:
        raise ValueError(f"Input must be a 3D array, got {image.ndim}D")

    if correction_mode not in ["divide", "subtract"]:
        raise ValueError(f"Invalid correction mode: {correction_mode}. "
                         f"Must be 'divide' or 'subtract'")

    # Store original shape and dtype
    z, y, x = image.shape
    orig_dtype = image.dtype

    # Convert to float for processing
    image_float = image.astype(np.float32)

    # Flatten each Z-slice into a row vector
    # D has shape (Z, Y*X)
    D = image_float.reshape(z, y * x)

    # Initialize variables for alternating minimization
    L = np.zeros_like(D)  # Low-rank component (background/illumination)
    S = np.zeros_like(D)  # Sparse component (foreground/structures)

    # Compute initial norm for convergence check
    norm_D = np.linalg.norm(D, 'fro')

    # An all-zero stack has zero norm. Dividing by it would make the residual
    # NaN (0/0), emit a RuntimeWarning and keep the loop running for all
    # max_iters because NaN never compares below tol. Fall back to an
    # absolute residual in that case.
    residual_scale = norm_D if norm_D > 0 else 1.0

    # Alternating minimization loop
    for iteration in range(max_iters):
        # Update low-rank component (L)
        L = _low_rank_approximation(D - S, rank=rank)

        # Apply regularization to L if needed
        if lambda_lowrank > 0:
            L = L * (1 - lambda_lowrank)

        # Update sparse component (S)
        S = _soft_threshold(D - L, lambda_sparse)

        # Check convergence
        residual = np.linalg.norm(D - L - S, 'fro') / residual_scale
        if verbose and (iteration % 10 == 0 or iteration == max_iters - 1):
            logger.info(f"Iteration {iteration+1}/{max_iters}, residual: {residual:.6f}")

        if residual < tol:
            if verbose:
                logger.info(f"Converged after {iteration+1} iterations")
            break

    # Reshape the low-rank component back to 3D
    L_stack = L.reshape(z, y, x)

    # Apply correction
    if correction_mode == "divide":
        # Add small epsilon to avoid division by zero
        eps = 1e-6
        corrected = image_float / (L_stack + eps)

        # Normalize to preserve dynamic range
        if normalize_output:
            corrected *= np.mean(L_stack)
    else:  # subtract
        corrected = image_float - L_stack

        # Normalize to preserve dynamic range
        if normalize_output:
            corrected += np.mean(L_stack)

    # Clip to valid range and convert back to original dtype
    if np.issubdtype(orig_dtype, np.integer):
        max_val = np.iinfo(orig_dtype).max
        corrected = np.clip(corrected, 0, max_val).astype(orig_dtype)
    else:
        corrected = np.clip(corrected, 0, None).astype(orig_dtype)

    return corrected

204 

205 

def basic_flatfield_correction_batch_numpy(
    image_batch: np.ndarray,
    *,
    batch_dim: int = 0,
    **kwargs
) -> np.ndarray:
    """
    Run BaSiC flatfield correction independently on every stack of a batch.

    Each 3D stack is sliced out along the batch axis, corrected with
    ``basic_flatfield_correction_numpy``, and the results are reassembled
    along that same axis.

    Args:
        image_batch: 4D NumPy array of shape (B, Z, Y, X) or (Z, B, Y, X)
        batch_dim: Axis holding the batch index (0 or 1)
        **kwargs: Forwarded unchanged to basic_flatfield_correction_numpy

    Returns:
        Corrected 4D NumPy array with the same layout as the input

    Raises:
        TypeError: If the input is not a NumPy array
        ValueError: If the input is not a 4D array or if batch_dim is invalid
    """
    _validate_numpy_array(image_batch)

    if image_batch.ndim != 4:
        raise ValueError(f"Input must be a 4D array, got {image_batch.ndim}D")

    if batch_dim != 0 and batch_dim != 1:
        raise ValueError(f"batch_dim must be 0 or 1, got {batch_dim}")

    # Anything that passed validation and is not equal to 0 is treated as
    # axis 1, for both slicing and re-stacking.
    axis = 0 if batch_dim == 0 else 1

    # Bring the batch axis to the front (a view, no copy) so that iterating
    # yields one (Z, Y, X) stack per batch entry.
    stacks = np.moveaxis(image_batch, axis, 0)
    corrected_stacks = [
        basic_flatfield_correction_numpy(stack, **kwargs) for stack in stacks
    ]

    # Reassemble along the original batch axis.
    return np.stack(corrected_stacks, axis=axis)