Coverage for openhcs/processing/backends/enhance/basic_processor_numpy.py: 12.1%
73 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2BaSiC (Background and Shading Correction) Implementation using NumPy
4This module implements the BaSiC algorithm for illumination correction
5using NumPy for CPU processing. The implementation is based on the paper:
6Peng et al., "A BaSiC tool for background and shading correction of optical
7microscopy images", Nature Communications, 2017.
9The algorithm performs low-rank + sparse matrix decomposition to separate
10uneven illumination artifacts from structural features in microscopy images.
12Doctrinal Clauses:
13- Clause 3 — Declarative Primacy: All functions are pure and stateless
14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities
15"""
16from __future__ import annotations
18import logging
19from typing import Any
21import numpy as np
22from scipy import linalg
24# Import decorator directly from core.memory to avoid circular imports
25from openhcs.core.memory import numpy as numpy_func
27logger = logging.getLogger(__name__)
30def _validate_numpy_array(array: Any, name: str = "input") -> None:
31 """
32 Validate that the input is a NumPy array.
34 Args:
35 array: Array to validate
36 name: Name of the array for error messages
38 Raises:
39 TypeError: If the array is not a NumPy array
40 """
41 if not isinstance(array, np.ndarray):
42 raise TypeError(
43 f"{name} must be a NumPy array, got {type(array)}. "
44 f"No automatic conversion is performed to maintain explicit contracts. "
45 f"For GPU arrays, use the CuPy implementation with DLPack support."
46 )
49def _low_rank_approximation(matrix: np.ndarray, rank: int = 3) -> np.ndarray:
50 """
51 Compute a low-rank approximation of a matrix using truncated SVD.
53 Args:
54 matrix: Input matrix to approximate
55 rank: Target rank for the approximation
57 Returns:
58 Low-rank approximation of the input matrix
59 """
60 # Perform SVD
61 U, s, Vh = linalg.svd(matrix, full_matrices=False)
63 # Truncate to the specified rank
64 s[rank:] = 0
66 # Reconstruct the low-rank matrix
67 low_rank = (U * s) @ Vh
69 return low_rank
72def _soft_threshold(matrix: np.ndarray, threshold: float) -> np.ndarray:
73 """
74 Apply soft thresholding (shrinkage operator) to a matrix.
76 Args:
77 matrix: Input matrix
78 threshold: Threshold value for soft thresholding
80 Returns:
81 Soft-thresholded matrix
82 """
83 return np.sign(matrix) * np.maximum(np.abs(matrix) - threshold, 0)
@numpy_func
def basic_flatfield_correction_numpy(
    image: np.ndarray,
    *,
    max_iters: int = 50,
    lambda_sparse: float = 0.01,
    lambda_lowrank: float = 0.1,
    rank: int = 3,
    tol: float = 1e-4,
    correction_mode: str = "divide",
    normalize_output: bool = True,
    verbose: bool = False,
    **kwargs
) -> np.ndarray:
    """
    Perform BaSiC-style illumination correction on a 3D image stack using NumPy.

    Each Z-slice is flattened into one row of a matrix D, which is split by
    alternating minimization into a low-rank part L (treated as the
    background / shading field) and a sparse part S (treated as foreground
    structures). The stack is then corrected by dividing by, or subtracting,
    the low-rank part.

    Args:
        image: 3D NumPy array of shape (Z, Y, X)
        max_iters: Maximum number of iterations for the alternating minimization
        lambda_sparse: Soft-threshold applied when updating the sparse component
        lambda_lowrank: Shrink factor for the low-rank component; when
            positive, L is scaled by ``1 - lambda_lowrank`` every iteration
        rank: Target rank for the low-rank approximation
        tol: Convergence tolerance on the relative residual
            ``||D - L - S||_F / ||D||_F``
        correction_mode: Mode for applying the correction ('divide' or 'subtract')
        normalize_output: Whether to rescale (divide mode) or offset
            (subtract mode) the result by the mean of the low-rank component
        verbose: Whether to log progress information
        **kwargs: Additional parameters (ignored)

    Returns:
        Corrected 3D NumPy array of shape (Z, Y, X) with the input's dtype.
        Values are clipped to ``[0, dtype max]`` for integer dtypes and to
        ``[0, inf)`` otherwise.

    Raises:
        TypeError: If the input is not a NumPy array
        ValueError: If the input is not a 3D array or if correction_mode is invalid
    """
    # Validate input
    _validate_numpy_array(image)

    if image.ndim != 3:
        raise ValueError(f"Input must be a 3D array, got {image.ndim}D")

    if correction_mode not in ["divide", "subtract"]:
        raise ValueError(f"Invalid correction mode: {correction_mode}. "
                         f"Must be 'divide' or 'subtract'")

    # Store original shape and dtype so the result can be cast back
    z, y, x = image.shape
    orig_dtype = image.dtype

    # All arithmetic is done in float32
    image_float = image.astype(np.float32)

    # Flatten each Z-slice into a row vector: D has shape (Z, Y*X)
    D = image_float.reshape(z, y * x)

    # Initialize variables for alternating minimization
    L = np.zeros_like(D)  # Low-rank component (background/illumination)
    S = np.zeros_like(D)  # Sparse component (foreground/structures)

    # Norm of the data, used to express the residual relative to the input
    norm_D = np.linalg.norm(D, 'fro')

    # Alternating minimization loop
    for iteration in range(max_iters):
        # Update low-rank component (L) from the data minus the sparse part
        L = _low_rank_approximation(D - S, rank=rank)

        # Apply regularization to L if needed
        if lambda_lowrank > 0:
            L = L * (1 - lambda_lowrank)

        # Update sparse component (S) from what L does not explain
        S = _soft_threshold(D - L, lambda_sparse)

        # Check convergence. For an all-zero stack norm_D is 0 and the
        # relative residual would be 0/0 = NaN, which never compares below
        # tol and would burn every iteration; fall back to the absolute
        # residual in that case.
        residual_norm = np.linalg.norm(D - L - S, 'fro')
        residual = residual_norm / norm_D if norm_D > 0 else residual_norm

        if verbose and (iteration % 10 == 0 or iteration == max_iters - 1):
            logger.info(
                "Iteration %d/%d, residual: %.6f",
                iteration + 1, max_iters, residual,
            )

        if residual < tol:
            if verbose:
                logger.info("Converged after %d iterations", iteration + 1)
            break

    # Reshape the low-rank component back to 3D
    L_stack = L.reshape(z, y, x)

    # Apply correction
    if correction_mode == "divide":
        # Add small epsilon to avoid division by zero
        eps = 1e-6
        corrected = image_float / (L_stack + eps)

        # Rescale by the mean background to preserve dynamic range
        if normalize_output:
            corrected *= np.mean(L_stack)
    else:  # subtract
        corrected = image_float - L_stack

        # Add the mean background back to preserve dynamic range
        if normalize_output:
            corrected += np.mean(L_stack)

    # Clip to valid range and convert back to original dtype.
    # NOTE(review): astype on an integer dtype truncates toward zero rather
    # than rounding — confirm this is the intended conversion.
    if np.issubdtype(orig_dtype, np.integer):
        max_val = np.iinfo(orig_dtype).max
        corrected = np.clip(corrected, 0, max_val).astype(orig_dtype)
    else:
        corrected = np.clip(corrected, 0, None).astype(orig_dtype)

    return corrected
def basic_flatfield_correction_batch_numpy(
    image_batch: np.ndarray,
    *,
    batch_dim: int = 0,
    **kwargs
) -> np.ndarray:
    """
    Run BaSiC flatfield correction independently on every stack of a batch.

    The 4D input is split along the batch axis into 3D stacks, each stack is
    passed to ``basic_flatfield_correction_numpy``, and the corrected stacks
    are re-assembled along the same axis.

    Args:
        image_batch: 4D NumPy array of shape (B, Z, Y, X) or (Z, B, Y, X)
        batch_dim: Axis that indexes the batch (0 or 1)
        **kwargs: Forwarded unchanged to basic_flatfield_correction_numpy

    Returns:
        Corrected 4D NumPy array of the same shape as input

    Raises:
        TypeError: If the input is not a NumPy array
        ValueError: If the input is not a 4D array or if batch_dim is invalid
    """
    _validate_numpy_array(image_batch)

    if image_batch.ndim != 4:
        raise ValueError(f"Input must be a 4D array, got {image_batch.ndim}D")

    if batch_dim not in (0, 1):
        raise ValueError(f"batch_dim must be 0 or 1, got {batch_dim}")

    # Layout is (B, Z, Y, X) for axis 0 and (Z, B, Y, X) for axis 1.
    axis = 0 if batch_dim == 0 else 1

    # Lazily slice out one 3D stack per batch entry.
    stacks = (
        image_batch[index] if axis == 0 else image_batch[:, index]
        for index in range(image_batch.shape[axis])
    )
    corrected_stacks = [
        basic_flatfield_correction_numpy(stack, **kwargs) for stack in stacks
    ]

    # Re-assemble along the axis the batch came from.
    return np.stack(corrected_stacks, axis=axis)