Coverage for openhcs/processing/backends/enhance/basic_processor_cupy.py: 6.3%

274 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2BaSiC (Background and Shading Correction) Implementation using CuPy 

3 

4This module implements the BaSiC algorithm for illumination correction 

5using CuPy for GPU acceleration. The implementation is based on the paper: 

6Peng et al., "A BaSiC tool for background and shading correction of optical 

7microscopy images", Nature Communications, 2017. 

8 

9The algorithm performs low-rank + sparse matrix decomposition to separate 

10uneven illumination artifacts from structural features in microscopy images. 

11 

12Doctrinal Clauses: 

13- Clause 3 — Declarative Primacy: All functions are pure and stateless 

14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities 

15- Clause 88 — No Inferred Capabilities: Explicit CuPy dependency 

16- Clause 273 — Memory Backend Restrictions: GPU-only implementation 

17""" 

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

# Import decorator directly from core.memory to avoid circular imports
from openhcs.core.memory.decorators import cupy as cupy_func
from openhcs.core.utils import optional_import

# For type checking only
if TYPE_CHECKING:
    import cupy as cp
    from cupyx.scipy import linalg

# Import CuPy as an optional dependency
cp = optional_import("cupy")
cupyx_scipy = None
if cp is not None:
    cupyx_scipy = optional_import("cupyx.scipy")

logger = logging.getLogger(__name__)


def _validate_cupy_array(array: Any, name: str = "input") -> None:
    """
    Validate that the input is a CuPy array.

    Args:
        array: Array to validate
        name: Name of the array for error messages

    Raises:
        TypeError: If the array is not a CuPy array
        ValueError: If the array doesn't support DLPack
    """
    # The compiler guarantees this function is only called when CuPy is
    # available, so no ImportError check is needed here.
    if not isinstance(array, cp.ndarray):
        raise TypeError(
            f"{name} must be a CuPy array, got {type(array)}. "
            f"No automatic conversion is performed to maintain explicit contracts. "
            f"Use DLPack for zero-copy GPU-to-GPU transfers."
        )

    # Ensure the array supports DLPack
    if not hasattr(array, "__dlpack__") and not hasattr(array, "toDlpack"):
        raise ValueError(
            f"{name} does not support DLPack protocol. "
            f"DLPack is required for GPU memory conversions."
        )
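
# Usage sketch (illustrative): callers hand in CuPy arrays directly, or convert
# other GPU tensors zero-copy via DLPack first. The torch import below is a
# hypothetical caller-side example, not a dependency of this module.
#
#     >>> _validate_cupy_array(cp.zeros((4, 4)))              # passes silently
#     >>> import torch                                        # hypothetical caller
#     >>> arr = cp.from_dlpack(torch.rand(8, 64, 64, device="cuda"))
#     >>> _validate_cupy_array(arr)                           # zero-copy view passes too
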

def _low_rank_approximation(matrix: "cp.ndarray", rank: int = 3, max_memory_gb: float = 1.0) -> "cp.ndarray":
    """
    Compute a low-rank approximation of a matrix using SVD with memory optimization.

    Args:
        matrix: Input matrix to approximate
        rank: Target rank for the approximation
        max_memory_gb: Maximum memory to use for SVD (in GB)

    Returns:
        Low-rank approximation of the input matrix
    """
    # Estimate memory usage for SVD
    matrix_size_gb = matrix.nbytes / (1024**3)
    svd_memory_estimate = matrix_size_gb * 3  # U, s, Vh matrices

    if svd_memory_estimate > max_memory_gb:
        # Use chunked processing for large matrices
        logger.info(f"🔧 MEMORY OPTIMIZATION: Matrix too large ({svd_memory_estimate:.2f}GB > {max_memory_gb}GB), using chunked SVD")
        return _chunked_low_rank_approximation(matrix, rank, max_memory_gb)
    else:
        # Use standard SVD for smaller matrices
        try:
            # Perform SVD using CuPy's built-in linalg
            U, s, Vh = cp.linalg.svd(matrix, full_matrices=False)

            # Zero out singular values beyond the target rank
            s[rank:] = 0

            # Reconstruct the low-rank matrix
            low_rank = (U * s) @ Vh

            return low_rank
        except (cp.cuda.memory.OutOfMemoryError,
                cp.cuda.runtime.CUDARuntimeError,
                cp.cuda.cusolver.CUSOLVERError,
                cp.cuda.cublas.CUBLASError):
            # Fall back to chunked processing if standard SVD fails
            logger.warning("🔧 MEMORY OPTIMIZATION: Standard SVD failed, falling back to chunked processing")
            return _chunked_low_rank_approximation(matrix, rank, max_memory_gb)
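
# Example (illustrative, assumes a CUDA device): the approximation keeps the
# input shape but zeroes every singular value beyond `rank`.
#
#     >>> m = cp.random.rand(16, 4096).astype(cp.float32)
#     >>> approx = _low_rank_approximation(m, rank=3)
#     >>> approx.shape == m.shape
#     True
#     >>> int(cp.linalg.matrix_rank(approx))                  # numerically ~3
#     3
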

def _chunked_low_rank_approximation(matrix: "cp.ndarray", rank: int, max_memory_gb: float) -> "cp.ndarray":
    """
    Compute low-rank approximation using adaptive dynamic chunking.

    Automatically adjusts chunk sizes based on available GPU memory and
    allocation success/failure patterns for optimal performance.

    Args:
        matrix: Input matrix to approximate (Z, Y*X)
        rank: Target rank for the approximation
        max_memory_gb: Maximum memory to use per chunk (fallback limit)

    Returns:
        Low-rank approximation of the input matrix
    """
    Z, YX = matrix.shape

    # 🔧 ADAPTIVE CHUNKING: Query available GPU memory for dynamic sizing
    try:
        free_memory, total_memory = cp.cuda.runtime.memGetInfo()
        available_gb = free_memory / (1024**3)
        logger.debug(f"🔧 ADAPTIVE CHUNKING: {available_gb:.2f}GB free GPU memory")
    except Exception:
        # Fall back to a conservative estimate if the memory query fails
        available_gb = max_memory_gb * 0.5
        logger.debug(f"🔧 ADAPTIVE CHUNKING: Memory query failed, using conservative {available_gb:.2f}GB")

    # Calculate the initial chunk size based on available memory
    bytes_per_element = matrix.dtype.itemsize
    svd_overhead = 8  # Conservative estimate for SVD workspace requirements

    # Use 10% of available memory for the initial chunk size
    usable_memory_gb = min(available_gb * 0.1, max_memory_gb)
    max_elements_per_chunk = int((usable_memory_gb * 1024**3) / (bytes_per_element * svd_overhead))
    initial_chunk_size = min(YX, max_elements_per_chunk // Z)

    # Enforce reasonable bounds
    min_chunk_size = 100
    max_chunk_size = YX // 4  # Don't make chunks larger than 25% of the total
    current_chunk_size = max(min_chunk_size, min(initial_chunk_size, max_chunk_size))

    logger.debug(f"🔧 ADAPTIVE CHUNKING: Starting with chunk size {current_chunk_size:,} (of {YX:,} total)")

    # Adaptive feedback variables
    success_count = 0
    low_rank_chunks = []
    current_pos = 0

    # Process the matrix with adaptive chunking
    while current_pos < YX:
        end_pos = min(current_pos + current_chunk_size, YX)
        chunk = matrix[:, current_pos:end_pos]

        try:
            # Try to process this chunk on the GPU
            U, s, Vh = cp.linalg.svd(chunk, full_matrices=False)

            # Truncate to the specified rank
            s_truncated = s.copy()
            s_truncated[rank:] = 0

            # Reconstruct the low-rank chunk
            low_rank_chunk = (U * s_truncated) @ Vh
            low_rank_chunks.append(low_rank_chunk)

            # 🔧 SUCCESS: Move to the next chunk and track the success
            current_pos = end_pos
            success_count += 1

            # Grow the chunk size after 3 consecutive successes
            if success_count >= 3:
                old_size = current_chunk_size
                current_chunk_size = min(int(current_chunk_size * 1.5), max_chunk_size)
                if current_chunk_size > old_size:
                    logger.debug(f"🔧 ADAPTIVE CHUNKING: Growing chunk size {old_size:,} → {current_chunk_size:,}")
                success_count = 0

        except (cp.cuda.memory.OutOfMemoryError,
                cp.cuda.runtime.CUDARuntimeError,
                cp.cuda.cusolver.CUSOLVERError,
                cp.cuda.cublas.CUBLASError):
            # 🔧 FAILURE: Shrink the chunk size and retry
            old_size = current_chunk_size
            current_chunk_size = max(int(current_chunk_size * 0.5), min_chunk_size)
            success_count = 0

            if current_chunk_size < old_size:
                logger.info(f"🔧 ADAPTIVE CHUNKING: OOM, shrinking chunk size {old_size:,} → {current_chunk_size:,}")
                # Don't advance the position; retry with a smaller chunk
                continue
            else:
                # The chunk size can't be reduced further; fall back to the CPU for this chunk
                logger.warning(f"🔧 ADAPTIVE CHUNKING: Minimum chunk size reached, using CPU for chunk {current_pos}:{end_pos}")

                try:
                    # Process on the CPU
                    chunk_cpu = chunk.get()
                    import numpy as np
                    U_cpu, s_cpu, Vh_cpu = np.linalg.svd(chunk_cpu, full_matrices=False)
                    s_cpu[rank:] = 0
                    low_rank_chunk_cpu = (U_cpu * s_cpu) @ Vh_cpu
                    low_rank_chunk = cp.asarray(low_rank_chunk_cpu)
                    low_rank_chunks.append(low_rank_chunk)

                    logger.debug("🔧 ADAPTIVE CHUNKING: Successfully processed chunk on CPU")
                    current_pos = end_pos

                except Exception as cpu_error:
                    logger.error(f"🔧 ADAPTIVE CHUNKING: CPU fallback failed: {cpu_error}")
                    # Last resort: pass the chunk through unchanged (no correction for this chunk)
                    logger.warning(f"🔧 ADAPTIVE CHUNKING: Passing chunk {current_pos}:{end_pos} through uncorrected")
                    low_rank_chunk = chunk.copy()
                    low_rank_chunks.append(low_rank_chunk)
                    current_pos = end_pos

    # Concatenate all processed chunks
    low_rank = cp.concatenate(low_rank_chunks, axis=1)

    logger.debug(f"🔧 ADAPTIVE CHUNKING: Completed processing {len(low_rank_chunks)} chunks")
    return low_rank

def _soft_threshold(matrix: "cp.ndarray", threshold: float) -> "cp.ndarray":
    """
    Apply soft thresholding (shrinkage operator) to a matrix.

    Args:
        matrix: Input matrix
        threshold: Threshold value for soft thresholding

    Returns:
        Soft-thresholded matrix
    """
    return cp.sign(matrix) * cp.maximum(cp.abs(matrix) - threshold, 0)
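
# Worked example (illustrative): soft thresholding shrinks every entry toward
# zero by `threshold` and zeroes anything smaller in magnitude.
#
#     >>> _soft_threshold(cp.array([-2.0, -0.5, 0.5, 2.0]), 1.0)
#     array([-1.,  0.,  0.,  1.])
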

@cupy_func
def basic_flatfield_correction_cupy(
    image: "cp.ndarray",
    *,
    max_iters: int = 50,
    lambda_sparse: float = 0.01,
    lambda_lowrank: float = 0.1,
    rank: int = 3,
    tol: float = 1e-4,
    correction_mode: str = "divide",
    normalize_output: bool = True,
    verbose: bool = False,
    max_memory_gb: float = 1.0,
    **kwargs
) -> "cp.ndarray":
    """
    Perform BaSiC-style illumination correction on a 3D image stack using CuPy.

    This function implements the BaSiC algorithm for illumination correction
    using low-rank + sparse matrix decomposition. It models the background
    (shading field) as a low-rank matrix across slices and the residuals
    (e.g., nuclei, structures) as sparse features.

    This memory-optimized version automatically switches to chunked processing
    for large images to prevent CUDA out-of-memory errors.

    Args:
        image: 3D CuPy array of shape (Z, Y, X)
        max_iters: Maximum number of iterations for the alternating minimization
        lambda_sparse: Regularization parameter for the sparse component
        lambda_lowrank: Regularization parameter for the low-rank component
        rank: Target rank for the low-rank approximation
        tol: Tolerance for convergence
        correction_mode: Mode for applying the correction ('divide' or 'subtract')
        normalize_output: Whether to normalize the output to preserve dynamic range
        verbose: Whether to print progress information
        max_memory_gb: Maximum memory to use for SVD operations (in GB)
        **kwargs: Additional parameters (ignored)

    Returns:
        Corrected 3D CuPy array of shape (Z, Y, X)

    Raises:
        ImportError: If CuPy is not available
        TypeError: If the input is not a CuPy array
        ValueError: If the input is not a 3D array, if correction_mode is
            invalid, or if the input array doesn't support DLPack
    """
    # Validate input
    _validate_cupy_array(image)

    if image.ndim != 3:
        raise ValueError(f"Input must be a 3D array, got {image.ndim}D")

    if correction_mode not in ["divide", "subtract"]:
        raise ValueError(f"Invalid correction mode: {correction_mode}. "
                         f"Must be 'divide' or 'subtract'")

    # Store original shape and dtype
    z, y, x = image.shape
    orig_dtype = image.dtype

    # 🔍 MEMORY ESTIMATION: Check whether the image is likely to cause memory issues
    image_size_gb = image.nbytes / (1024**3)
    estimated_peak_memory = image_size_gb * 4  # Original + float + L + S matrices

    # Try GPU processing first; fall back to the CPU on OOM
    try:
        return _gpu_flatfield_correction(
            image, max_iters, lambda_sparse, lambda_lowrank, rank, tol,
            correction_mode, normalize_output, verbose, max_memory_gb,
            image_size_gb, estimated_peak_memory, z, y, x, orig_dtype
        )
    except (cp.cuda.memory.OutOfMemoryError,
            cp.cuda.runtime.CUDARuntimeError,
            cp.cuda.cusolver.CUSOLVERError,
            cp.cuda.cublas.CUBLASError) as oom_error:
        logger.warning(f"🔧 GPU OOM: {oom_error}")
        logger.info(f"🔧 CPU FALLBACK: GPU processing failed, switching to CPU for {z}×{y}×{x} image")

        # 🔧 CRITICAL: Release memory held by the failed GPU processing
        logger.debug("🔧 CPU FALLBACK: Cleaning up intermediate GPU variables...")
        try:
            # The intermediate arrays live in _gpu_flatfield_correction's scope
            # (which drops them itself), so do aggressive pool cleanup here.
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()

            # Force garbage collection
            import gc
            gc.collect()

            # Check memory after cleanup
            free_after_cleanup, total = cp.cuda.runtime.memGetInfo()
            logger.info(f"🔧 CPU FALLBACK: After cleanup: {free_after_cleanup / 1e9:.2f}GB free of {total / 1e9:.2f}GB total")

        except Exception as cleanup_error:
            logger.warning(f"🔧 CPU FALLBACK: Cleanup warning: {cleanup_error}")

        # Fall back to CPU processing
        return _cpu_fallback_flatfield_correction(
            image, max_iters, lambda_sparse, lambda_lowrank, rank, tol,
            correction_mode, normalize_output, verbose
        )
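
# Usage sketch (illustrative): correct a synthetic stack with a smooth radial
# shading field. The data below is hypothetical example input, not a fixture
# from this codebase.
#
#     >>> z, y, x = 8, 256, 256
#     >>> yy, xx = cp.meshgrid(cp.linspace(-1, 1, y), cp.linspace(-1, 1, x), indexing="ij")
#     >>> shading = 1.0 - 0.5 * (yy**2 + xx**2)               # brighter centre, darker edges
#     >>> stack = (cp.random.rand(z, y, x) * shading).astype(cp.float32)
#     >>> corrected = basic_flatfield_correction_cupy(stack, max_iters=20)
#     >>> corrected.shape, corrected.dtype
#     ((8, 256, 256), dtype('float32'))
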

def _gpu_flatfield_correction(
    image: "cp.ndarray", max_iters: int, lambda_sparse: float, lambda_lowrank: float,
    rank: int, tol: float, correction_mode: str, normalize_output: bool, verbose: bool,
    max_memory_gb: float, image_size_gb: float, estimated_peak_memory: float, z: int, y: int, x: int, orig_dtype
) -> "cp.ndarray":
    """GPU-based flatfield correction implementation."""

    if estimated_peak_memory > max_memory_gb * 2:
        logger.warning(f"⚠️ Large image detected: {z}×{y}×{x} ({image_size_gb:.2f}GB). "
                       f"Estimated peak memory: {estimated_peak_memory:.2f}GB. "
                       f"Consider reducing image size or increasing max_memory_gb parameter.")

    logger.debug(f"🔧 MEMORY INFO: Image size {z}×{y}×{x}, {image_size_gb:.2f}GB, "
                 f"max_memory_gb={max_memory_gb}, estimated peak={estimated_peak_memory:.2f}GB")

    # Initialize variables to None for proper cleanup
    image_float = None
    D = None
    L = None
    S = None
    L_stack = None
    corrected = None

    try:
        # Convert to float for processing
        image_float = image.astype(cp.float32)

        # Flatten each Z-slice into a row vector; D has shape (Z, Y*X)
        D = image_float.reshape(z, y * x)

        # Initialize variables for alternating minimization
        L = cp.zeros_like(D)  # Low-rank component (background/illumination)
        S = cp.zeros_like(D)  # Sparse component (foreground/structures)

        # Compute the initial norm for the convergence check
        norm_D = cp.linalg.norm(D, 'fro')

        # Track convergence for early termination
        prev_residual = float('inf')
        stagnation_count = 0
        max_stagnation = 5  # Stop if there is no improvement for 5 iterations

        # Alternating minimization loop
        for iteration in range(max_iters):
            # Update the low-rank component (L) with memory optimization
            L = _low_rank_approximation(D - S, rank=rank, max_memory_gb=max_memory_gb)

            # Apply regularization to L if needed
            if lambda_lowrank > 0:
                L = L * (1 - lambda_lowrank)

            # Update the sparse component (S)
            S = _soft_threshold(D - L, lambda_sparse)

            # Check convergence
            residual = cp.linalg.norm(D - L - S, 'fro') / norm_D
            if verbose and (iteration % 10 == 0 or iteration == max_iters - 1):
                logger.info(f"Iteration {iteration+1}/{max_iters}, residual: {residual:.6f}")

            # Early termination conditions
            if residual < tol:
                if verbose:
                    logger.info(f"Converged after {iteration+1} iterations (residual < {tol})")
                break

            # Check for stagnation (no significant improvement)
            improvement = prev_residual - residual
            if improvement < tol * 0.1:  # Improvement below 10% of the tolerance
                stagnation_count += 1
                if stagnation_count >= max_stagnation:
                    if verbose:
                        logger.info(f"Early termination after {iteration+1} iterations (stagnation)")
                    break
            else:
                stagnation_count = 0  # Reset the counter when we see improvement

            prev_residual = residual

        # Reshape the low-rank component back to 3D
        L_stack = L.reshape(z, y, x)

        # Apply the correction
        if correction_mode == "divide":
            # Add a small epsilon to avoid division by zero
            eps = 1e-6
            corrected = image_float / (L_stack + eps)

            # Normalize to preserve dynamic range
            if normalize_output:
                corrected *= cp.mean(L_stack)
        else:  # subtract
            corrected = image_float - L_stack

            # Normalize to preserve dynamic range
            if normalize_output:
                corrected += cp.mean(L_stack)

        # Clip to the valid range and convert back to the original dtype
        if cp.issubdtype(orig_dtype, cp.integer):
            max_val = cp.iinfo(orig_dtype).max
            corrected = cp.clip(corrected, 0, max_val).astype(orig_dtype)
        else:
            corrected = cp.clip(corrected, 0, None).astype(orig_dtype)

        return corrected

    except (cp.cuda.memory.OutOfMemoryError,
            cp.cuda.runtime.CUDARuntimeError,
            cp.cuda.cusolver.CUSOLVERError,
            cp.cuda.cublas.CUBLASError) as gpu_error:

        # 🔧 CRITICAL: Clean up ALL intermediate arrays before re-raising
        logger.debug("🔧 GPU PROCESSING: Cleaning up intermediate variables after OOM...")
        try:
            # Drop references to all local intermediate arrays. They were all
            # initialized to None above, so they are always bound here.
            del image_float, D, L, S, L_stack, corrected

            # Force garbage collection
            import gc
            gc.collect()

            logger.debug("🔧 GPU PROCESSING: Intermediate variables cleaned up")

        except Exception as cleanup_error:
            logger.warning(f"🔧 GPU PROCESSING: Variable cleanup warning: {cleanup_error}")

        # Re-raise the original GPU error for the outer exception handler
        raise gpu_error

def basic_flatfield_correction_batch_cupy(
    image_batch: "cp.ndarray",
    *,
    batch_dim: int = 0,
    **kwargs
) -> "cp.ndarray":
    """
    Apply BaSiC flatfield correction to a batch of 3D image stacks.

    This function applies the BaSiC algorithm to each 3D stack in a batch.

    Args:
        image_batch: 4D CuPy array of shape (B, Z, Y, X) or (Z, B, Y, X)
        batch_dim: Dimension along which the batch is organized (0 or 1)
        **kwargs: Additional parameters passed to basic_flatfield_correction_cupy

    Returns:
        Corrected 4D CuPy array of the same shape as the input

    Raises:
        ImportError: If CuPy is not available
        TypeError: If the input is not a CuPy array
        ValueError: If the input is not a 4D array, if batch_dim is invalid,
            or if the input array doesn't support DLPack
    """
    # Validate input
    _validate_cupy_array(image_batch)

    if image_batch.ndim != 4:
        raise ValueError(f"Input must be a 4D array, got {image_batch.ndim}D")

    if batch_dim not in [0, 1]:
        raise ValueError(f"batch_dim must be 0 or 1, got {batch_dim}")

    # Process each 3D stack in the batch
    result_list = []

    if batch_dim == 0:
        # Batch is organized as (B, Z, Y, X)
        for b in range(image_batch.shape[0]):
            corrected = basic_flatfield_correction_cupy(image_batch[b], **kwargs)
            result_list.append(corrected)

        # Stack along the batch dimension
        return cp.stack(result_list, axis=0)

    # Batch is organized as (Z, B, Y, X)
    for b in range(image_batch.shape[1]):
        corrected = basic_flatfield_correction_cupy(image_batch[:, b], **kwargs)
        result_list.append(corrected)

    # Stack along the batch dimension
    return cp.stack(result_list, axis=1)
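
# Usage sketch (illustrative): each 3D stack in a 4D batch is corrected
# independently; the batch below is hypothetical example input.
#
#     >>> batch = cp.random.rand(4, 8, 128, 128).astype(cp.float32)  # (B, Z, Y, X)
#     >>> out = basic_flatfield_correction_batch_cupy(batch, batch_dim=0, max_iters=10)
#     >>> out.shape
#     (4, 8, 128, 128)
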

def _cpu_fallback_flatfield_correction(
    image: "cp.ndarray", max_iters: int, lambda_sparse: float, lambda_lowrank: float,
    rank: int, tol: float, correction_mode: str, normalize_output: bool, verbose: bool
) -> "cp.ndarray":
    """CPU fallback for flatfield correction when the GPU runs out of memory."""

    try:
        import numpy as np

        from openhcs.processing.backends.enhance.basic_processor_numpy import basic_flatfield_correction_numpy

        # 🔧 AGGRESSIVE GPU CLEANUP: Free as much GPU memory as possible before the CPU conversion
        logger.info("🔧 CPU FALLBACK: Clearing GPU memory before CPU conversion...")
        try:
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()
        except Exception:
            pass

        # Convert the CuPy array to NumPy in chunks to avoid a large GPU allocation
        logger.info("🔧 CPU FALLBACK: Converting CuPy array to NumPy in chunks...")
        z, y, x = image.shape

        # Convert slice by slice to minimize GPU memory usage
        image_cpu_slices = []
        for i in range(z):
            try:
                slice_cpu = image[i].get()  # Convert one slice at a time
                image_cpu_slices.append(slice_cpu)

                # Clear GPU memory every 10 slices
                if i % 10 == 0:
                    cp.get_default_memory_pool().free_all_blocks()

            except cp.cuda.memory.OutOfMemoryError:
                # If even a single slice fails, try smaller chunks
                logger.warning(f"🔧 CPU FALLBACK: Slice {i} too large, converting in sub-chunks...")
                slice_chunks = []
                chunk_size = y // 4  # Quarter the slice

                for start_y in range(0, y, chunk_size):
                    end_y = min(start_y + chunk_size, y)
                    chunk_cpu = image[i, start_y:end_y].get()
                    slice_chunks.append(chunk_cpu)
                    cp.get_default_memory_pool().free_all_blocks()

                # Reassemble the slice on the CPU
                slice_cpu = np.concatenate(slice_chunks, axis=0)
                image_cpu_slices.append(slice_cpu)

        # Reassemble the full image on the CPU
        image_cpu = np.stack(image_cpu_slices, axis=0)

        logger.info(f"🔧 CPU FALLBACK: Successfully converted to CPU, processing {image_cpu.shape} image")

        # Process on the CPU
        corrected_cpu = basic_flatfield_correction_numpy(
            image_cpu,
            max_iters=max_iters,
            lambda_sparse=lambda_sparse,
            lambda_lowrank=lambda_lowrank,
            rank=rank,
            tol=tol,
            correction_mode=correction_mode,
            normalize_output=normalize_output,
            verbose=verbose
        )

        # 🔧 AGGRESSIVE GPU CLEANUP: Clear ALL intermediate data before converting the result back
        logger.info("🔧 CPU FALLBACK: Clearing all GPU memory before converting result back...")
        try:
            # Clear all GPU memory pools
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()

            # Force GPU synchronization and cleanup
            cp.cuda.runtime.deviceSynchronize()

            # Additional cleanup: clear any cached fused kernels (if available)
            try:
                cp.fuse.clear_memo()
            except AttributeError:
                pass  # Not available in this CuPy version

            # Force garbage collection
            import gc
            gc.collect()

            logger.debug("🔧 CPU FALLBACK: GPU memory cleared, converting result back to CuPy...")

        except Exception as cleanup_error:
            logger.warning(f"🔧 CPU FALLBACK: GPU cleanup warning: {cleanup_error}")

        # Convert the result back to CuPy (there should now be enough memory)
        logger.info("🔧 CPU FALLBACK: Converting result back to CuPy...")
        corrected = cp.asarray(corrected_cpu)

        logger.info("🔧 CPU FALLBACK: Successfully processed on CPU and converted back to GPU")
        return corrected

    except Exception as cpu_error:
        logger.error(f"🔧 CPU FALLBACK: Failed to process on CPU: {cpu_error}")
        raise RuntimeError(f"Both GPU and CPU processing failed. GPU OOM, CPU error: {cpu_error}") from cpu_error