Coverage for openhcs/processing/backends/enhance/basic_processor_cupy.py: 6.3%

273 statements  


1""" 

2BaSiC (Background and Shading Correction) Implementation using CuPy 

3 

4This module implements the BaSiC algorithm for illumination correction 

5using CuPy for GPU acceleration. The implementation is based on the paper: 

6Peng et al., "A BaSiC tool for background and shading correction of optical 

7microscopy images", Nature Communications, 2017. 

8 

9The algorithm performs low-rank + sparse matrix decomposition to separate 

10uneven illumination artifacts from structural features in microscopy images. 

11 

12Doctrinal Clauses: 

13- Clause 3 — Declarative Primacy: All functions are pure and stateless 

14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities 

15- Clause 88 — No Inferred Capabilities: Explicit CuPy dependency 

16- Clause 273 — Memory Backend Restrictions: GPU-only implementation 

17""" 

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

# Import the decorator directly from core.memory to avoid circular imports
from openhcs.core.memory.decorators import cupy as cupy_func
from openhcs.core.utils import optional_import

# For type checking only
if TYPE_CHECKING:
    import cupy as cp

# Import CuPy as an optional dependency
cp = optional_import("cupy")
cupyx_scipy = None
if cp is not None:
    cupyx_scipy = optional_import("cupyx.scipy")

logger = logging.getLogger(__name__)

def _validate_cupy_array(array: Any, name: str = "input") -> None:
    """
    Validate that the input is a CuPy array.

    Args:
        array: Array to validate
        name: Name of the array for error messages

    Raises:
        ImportError: If CuPy is not available
        TypeError: If the array is not a CuPy array
        ValueError: If the array doesn't support DLPack
    """
    # The compiler ensures this function is only called when CuPy is
    # available, so no availability check is needed here.

    if not isinstance(array, cp.ndarray):
        raise TypeError(
            f"{name} must be a CuPy array, got {type(array)}. "
            f"No automatic conversion is performed to maintain explicit contracts. "
            f"Use DLPack for zero-copy GPU-to-GPU transfers."
        )

    # Ensure the array supports DLPack
    if not hasattr(array, "__dlpack__") and not hasattr(array, "toDlpack"):
        raise ValueError(
            f"{name} does not support DLPack protocol. "
            f"DLPack is required for GPU memory conversions."
        )
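# Illustrative sketch (assumption: PyTorch as the source framework, which this
# module does not itself depend on). The error messages above point callers at
# DLPack for zero-copy transfers; a caller-side conversion could look like:
#
#     import cupy as cp
#     import torch
#
#     t = torch.zeros((4, 4), device="cuda")
#     arr = cp.from_dlpack(t)       # zero-copy view of the same GPU buffer
#     _validate_cupy_array(arr)     # passes: a cp.ndarray exposing __dlpack__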


def _low_rank_approximation(matrix: "cp.ndarray", rank: int = 3, max_memory_gb: float = 1.0) -> "cp.ndarray":
    """
    Compute a low-rank approximation of a matrix using SVD with memory optimization.

    Args:
        matrix: Input matrix to approximate
        rank: Target rank for the approximation
        max_memory_gb: Maximum memory to use for SVD (in GB)

    Returns:
        Low-rank approximation of the input matrix
    """
    # Estimate memory usage for SVD
    matrix_size_gb = matrix.nbytes / (1024**3)
    svd_memory_estimate = matrix_size_gb * 3  # U, s, Vh matrices

    if svd_memory_estimate > max_memory_gb:
        # Use chunked processing for large matrices
        logger.info(f"🔧 MEMORY OPTIMIZATION: Matrix too large ({svd_memory_estimate:.2f}GB > {max_memory_gb}GB), using chunked SVD")
        return _chunked_low_rank_approximation(matrix, rank, max_memory_gb)

    # Use standard SVD for smaller matrices
    try:
        # Perform SVD using CuPy's built-in linalg
        U, s, Vh = cp.linalg.svd(matrix, full_matrices=False)

        # Truncate to the specified rank
        s[rank:] = 0

        # Reconstruct the low-rank matrix
        return (U * s) @ Vh
    except (cp.cuda.memory.OutOfMemoryError,
            cp.cuda.runtime.CUDARuntimeError,
            cp.cuda.cusolver.CUSOLVERError,
            cp.cuda.cublas.CUBLASError):
        # Fall back to chunked processing if standard SVD fails
        logger.warning("🔧 MEMORY OPTIMIZATION: Standard SVD failed, falling back to chunked processing")
        return _chunked_low_rank_approximation(matrix, rank, max_memory_gb)
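# Worked sketch of what the helper above computes (shapes are illustrative):
# keeping only the top-`rank` singular values yields the best rank-k
# approximation in the Frobenius norm (Eckart–Young theorem).
#
#     M = cp.random.rand(8, 4096, dtype=cp.float32)   # a (Z, Y*X)-style matrix
#     L3 = _low_rank_approximation(M, rank=3)
#     assert L3.shape == M.shape
#     assert cp.linalg.matrix_rank(L3) <= 3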


def _chunked_low_rank_approximation(matrix: "cp.ndarray", rank: int, max_memory_gb: float) -> "cp.ndarray":
    """
    Compute a low-rank approximation using adaptive dynamic chunking.

    Automatically adjusts chunk sizes based on available GPU memory and
    allocation success/failure patterns for optimal performance.

    Args:
        matrix: Input matrix to approximate (Z, Y*X)
        rank: Target rank for the approximation
        max_memory_gb: Maximum memory to use per chunk (fallback limit)

    Returns:
        Low-rank approximation of the input matrix
    """
    Z, YX = matrix.shape

    # 🔧 ADAPTIVE CHUNKING: Query available GPU memory for dynamic sizing
    try:
        free_memory, total_memory = cp.cuda.runtime.memGetInfo()
        available_gb = free_memory / (1024**3)
        logger.debug(f"🔧 ADAPTIVE CHUNKING: {available_gb:.2f}GB free GPU memory")
    except Exception:
        # Fall back to a conservative estimate if the memory query fails
        available_gb = max_memory_gb * 0.5
        logger.debug(f"🔧 ADAPTIVE CHUNKING: Memory query failed, using conservative {available_gb:.2f}GB")

    # Calculate the initial chunk size based on available memory
    bytes_per_element = matrix.dtype.itemsize
    svd_overhead = 8  # Conservative estimate for SVD workspace requirements

    # Use 10% of available memory for the initial chunk size
    usable_memory_gb = min(available_gb * 0.1, max_memory_gb)
    max_elements_per_chunk = int((usable_memory_gb * 1024**3) / (bytes_per_element * svd_overhead))
    initial_chunk_size = min(YX, max_elements_per_chunk // Z)

    # Enforce reasonable bounds
    min_chunk_size = 100
    max_chunk_size = YX // 4  # Don't make chunks larger than 25% of the total
    current_chunk_size = max(min_chunk_size, min(initial_chunk_size, max_chunk_size))

    logger.debug(f"🔧 ADAPTIVE CHUNKING: Starting with chunk size {current_chunk_size:,} (of {YX:,} total)")

    # Adaptive feedback variables
    success_count = 0
    low_rank_chunks = []
    current_pos = 0

    # Process the matrix with adaptive chunking
    while current_pos < YX:
        end_pos = min(current_pos + current_chunk_size, YX)
        chunk = matrix[:, current_pos:end_pos]

        try:
            # Try to process this chunk on the GPU
            U, s, Vh = cp.linalg.svd(chunk, full_matrices=False)

            # Truncate to the specified rank
            s_truncated = s.copy()
            s_truncated[rank:] = 0

            # Reconstruct the low-rank chunk
            low_rank_chunk = (U * s_truncated) @ Vh
            low_rank_chunks.append(low_rank_chunk)

            # 🔧 SUCCESS: Move to the next chunk and track the success
            current_pos = end_pos
            success_count += 1

            # Grow the chunk size after 3 consecutive successes
            if success_count >= 3:
                old_size = current_chunk_size
                current_chunk_size = min(int(current_chunk_size * 1.5), max_chunk_size)
                if current_chunk_size > old_size:
                    logger.debug(f"🔧 ADAPTIVE CHUNKING: Growing chunk size {old_size:,} → {current_chunk_size:,}")
                success_count = 0

        except (cp.cuda.memory.OutOfMemoryError,
                cp.cuda.runtime.CUDARuntimeError,
                cp.cuda.cusolver.CUSOLVERError,
                cp.cuda.cublas.CUBLASError):
            # 🔧 FAILURE: Shrink the chunk size and retry
            old_size = current_chunk_size
            current_chunk_size = max(int(current_chunk_size * 0.5), min_chunk_size)
            success_count = 0

            if current_chunk_size < old_size:
                logger.info(f"🔧 ADAPTIVE CHUNKING: OOM, shrinking chunk size {old_size:,} → {current_chunk_size:,}")
                # Don't advance the position; retry with a smaller chunk
                continue
            else:
                # The chunk size can't be reduced further; fall back to the CPU for this chunk
                logger.warning(f"🔧 ADAPTIVE CHUNKING: Minimum chunk size reached, using CPU for chunk {current_pos}:{end_pos}")

                try:
                    # Process on the CPU
                    import numpy as np
                    chunk_cpu = chunk.get()
                    U_cpu, s_cpu, Vh_cpu = np.linalg.svd(chunk_cpu, full_matrices=False)
                    s_cpu[rank:] = 0
                    low_rank_chunk = cp.asarray((U_cpu * s_cpu) @ Vh_cpu)
                    low_rank_chunks.append(low_rank_chunk)

                    logger.debug("🔧 ADAPTIVE CHUNKING: Successfully processed chunk on CPU")
                    current_pos = end_pos

                except Exception as cpu_error:
                    logger.error(f"🔧 ADAPTIVE CHUNKING: CPU fallback failed: {cpu_error}")
                    # Last resort: use the chunk itself as its low-rank component
                    # (no decomposition is applied to this chunk)
                    logger.warning(f"🔧 ADAPTIVE CHUNKING: Using undecomposed chunk {current_pos}:{end_pos}")
                    low_rank_chunks.append(chunk.copy())
                    current_pos = end_pos

    # Concatenate all processed chunks
    low_rank = cp.concatenate(low_rank_chunks, axis=1)

    logger.debug(f"🔧 ADAPTIVE CHUNKING: Completed processing {len(low_rank_chunks)} chunks")
    return low_rank


def _soft_threshold(matrix: "cp.ndarray", threshold: float) -> "cp.ndarray":
    """
    Apply soft thresholding (shrinkage operator) to a matrix.

    Args:
        matrix: Input matrix
        threshold: Threshold value for soft thresholding

    Returns:
        Soft-thresholded matrix
    """
    return cp.sign(matrix) * cp.maximum(cp.abs(matrix) - threshold, 0)
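# Worked example of the shrinkage operator above: values are pulled toward
# zero by `threshold`, and anything with magnitude below it becomes zero.
#
#     m = cp.asarray([-2.0, -0.2, 0.5, 1.5])
#     _soft_threshold(m, 1.0)   # → [-1.0, 0.0, 0.0, 0.5] (up to signed zeros)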


@cupy_func
def basic_flatfield_correction_cupy(
    image: "cp.ndarray",
    *,
    max_iters: int = 50,
    lambda_sparse: float = 0.01,
    lambda_lowrank: float = 0.1,
    rank: int = 3,
    tol: float = 1e-4,
    correction_mode: str = "divide",
    normalize_output: bool = True,
    verbose: bool = False,
    max_memory_gb: float = 1.0,
    **kwargs
) -> "cp.ndarray":
    """
    Perform BaSiC-style illumination correction on a 3D image stack using CuPy.

    This function implements the BaSiC algorithm for illumination correction
    using low-rank + sparse matrix decomposition. It models the background
    (shading field) as a low-rank matrix across slices and the residuals
    (e.g., nuclei, structures) as sparse features.

    This memory-optimized version automatically uses chunked processing for
    large images to prevent CUDA out-of-memory errors.

    Args:
        image: 3D CuPy array of shape (Z, Y, X)
        max_iters: Maximum number of iterations for the alternating minimization
        lambda_sparse: Regularization parameter for the sparse component
        lambda_lowrank: Regularization parameter for the low-rank component
        rank: Target rank for the low-rank approximation
        tol: Tolerance for convergence
        correction_mode: Mode for applying the correction ('divide' or 'subtract')
        normalize_output: Whether to normalize the output to preserve dynamic range
        verbose: Whether to print progress information
        max_memory_gb: Maximum memory to use for SVD operations (in GB)
        **kwargs: Additional parameters (ignored)

    Returns:
        Corrected 3D CuPy array of shape (Z, Y, X)

    Raises:
        ImportError: If CuPy is not available
        TypeError: If the input is not a CuPy array
        ValueError: If the input is not a 3D array, if correction_mode is
            invalid, or if the input array doesn't support DLPack
    """
    # Validate input
    _validate_cupy_array(image)

    if image.ndim != 3:
        raise ValueError(f"Input must be a 3D array, got {image.ndim}D")

    if correction_mode not in ["divide", "subtract"]:
        raise ValueError(f"Invalid correction mode: {correction_mode}. "
                         f"Must be 'divide' or 'subtract'")

    # Store the original shape and dtype
    z, y, x = image.shape
    orig_dtype = image.dtype

    # 🔍 MEMORY ESTIMATION: Check if the image is likely to cause memory issues
    image_size_gb = image.nbytes / (1024**3)
    estimated_peak_memory = image_size_gb * 4  # Original + float + L + S matrices

    # Try GPU processing first; fall back to the CPU on OOM
    try:
        return _gpu_flatfield_correction(
            image, max_iters, lambda_sparse, lambda_lowrank, rank, tol,
            correction_mode, normalize_output, verbose, max_memory_gb,
            image_size_gb, estimated_peak_memory, z, y, x, orig_dtype
        )
    except (cp.cuda.memory.OutOfMemoryError,
            cp.cuda.runtime.CUDARuntimeError,
            cp.cuda.cusolver.CUSOLVERError,
            cp.cuda.cublas.CUBLASError) as oom_error:
        logger.warning(f"🔧 GPU OOM: {oom_error}")
        logger.info(f"🔧 CPU FALLBACK: GPU processing failed, switching to CPU for {z}×{y}×{x} image")

        # 🔧 CRITICAL: The intermediate arrays from the failed GPU attempt live
        # in _gpu_flatfield_correction's scope (which deletes them on error),
        # so here we perform aggressive pool-level memory cleanup.
        logger.debug("🔧 CPU FALLBACK: Cleaning up intermediate GPU variables...")
        try:
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()

            # Force garbage collection
            import gc
            gc.collect()

            # Check memory after cleanup
            free_after_cleanup, total = cp.cuda.runtime.memGetInfo()
            logger.info(f"🔧 CPU FALLBACK: After cleanup: {free_after_cleanup / 1e9:.2f}GB free of {total / 1e9:.2f}GB total")

        except Exception as cleanup_error:
            logger.warning(f"🔧 CPU FALLBACK: Cleanup warning: {cleanup_error}")

        # Fall back to CPU processing
        return _cpu_fallback_flatfield_correction(
            image, max_iters, lambda_sparse, lambda_lowrank, rank, tol,
            correction_mode, normalize_output, verbose
        )
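# Minimal usage sketch for basic_flatfield_correction_cupy (assumes a CUDA
# device; the uint16 stack below is illustrative, not taken from OpenHCS
# pipeline code):
#
#     stack = cp.asarray(some_numpy_stack, dtype=cp.uint16)   # (Z, Y, X)
#     corrected = basic_flatfield_correction_cupy(
#         stack, max_iters=30, correction_mode="divide", verbose=True
#     )
#     assert corrected.shape == stack.shape and corrected.dtype == stack.dtype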


def _gpu_flatfield_correction(
    image: "cp.ndarray", max_iters: int, lambda_sparse: float, lambda_lowrank: float,
    rank: int, tol: float, correction_mode: str, normalize_output: bool, verbose: bool,
    max_memory_gb: float, image_size_gb: float, estimated_peak_memory: float, z: int, y: int, x: int, orig_dtype
) -> "cp.ndarray":
    """GPU-based flatfield correction implementation."""

    if estimated_peak_memory > max_memory_gb * 2:
        logger.warning(f"⚠️ Large image detected: {z}×{y}×{x} ({image_size_gb:.2f}GB). "
                       f"Estimated peak memory: {estimated_peak_memory:.2f}GB. "
                       f"Consider reducing image size or increasing the max_memory_gb parameter.")

    logger.debug(f"🔧 MEMORY INFO: Image size {z}×{y}×{x}, {image_size_gb:.2f}GB, "
                 f"max_memory_gb={max_memory_gb}, estimated peak={estimated_peak_memory:.2f}GB")

    # Initialize variables to None for proper cleanup
    image_float = None
    D = None
    L = None
    S = None
    L_stack = None
    corrected = None

    try:
        # Convert to float for processing
        image_float = image.astype(cp.float32)

        # Flatten each Z-slice into a row vector; D has shape (Z, Y*X)
        D = image_float.reshape(z, y * x)

        # Initialize variables for alternating minimization
        L = cp.zeros_like(D)  # Low-rank component (background/illumination)
        S = cp.zeros_like(D)  # Sparse component (foreground/structures)

        # Compute the initial norm for the convergence check
        norm_D = cp.linalg.norm(D, 'fro')

        # Track convergence for early termination
        prev_residual = float('inf')
        stagnation_count = 0
        max_stagnation = 5  # Stop if there is no improvement for 5 iterations

        # Alternating minimization loop
        for iteration in range(max_iters):
            # Update the low-rank component (L) with memory optimization
            L = _low_rank_approximation(D - S, rank=rank, max_memory_gb=max_memory_gb)

            # Apply regularization to L if needed
            if lambda_lowrank > 0:
                L = L * (1 - lambda_lowrank)

            # Update the sparse component (S)
            S = _soft_threshold(D - L, lambda_sparse)

            # Check convergence
            residual = cp.linalg.norm(D - L - S, 'fro') / norm_D
            if verbose and (iteration % 10 == 0 or iteration == max_iters - 1):
                logger.info(f"Iteration {iteration+1}/{max_iters}, residual: {residual:.6f}")

            # Early termination conditions
            if residual < tol:
                if verbose:
                    logger.info(f"Converged after {iteration+1} iterations (residual < {tol})")
                break

            # Check for stagnation (no significant improvement)
            improvement = prev_residual - residual
            if improvement < tol * 0.1:  # Less than 10% of the tolerance
                stagnation_count += 1
                if stagnation_count >= max_stagnation:
                    if verbose:
                        logger.info(f"Early termination after {iteration+1} iterations (stagnation)")
                    break
            else:
                stagnation_count = 0  # Reset the counter when improvement resumes

            prev_residual = residual

        # Reshape the low-rank component back to 3D
        L_stack = L.reshape(z, y, x)

        # Apply the correction
        if correction_mode == "divide":
            # Add a small epsilon to avoid division by zero
            eps = 1e-6
            corrected = image_float / (L_stack + eps)

            # Normalize to preserve dynamic range
            if normalize_output:
                corrected *= cp.mean(L_stack)
        else:  # subtract
            corrected = image_float - L_stack

            # Normalize to preserve dynamic range
            if normalize_output:
                corrected += cp.mean(L_stack)

        # Clip to the valid range and convert back to the original dtype
        if cp.issubdtype(orig_dtype, cp.integer):
            max_val = cp.iinfo(orig_dtype).max
            corrected = cp.clip(corrected, 0, max_val).astype(orig_dtype)
        else:
            corrected = cp.clip(corrected, 0, None).astype(orig_dtype)

        return corrected

    except (cp.cuda.memory.OutOfMemoryError,
            cp.cuda.runtime.CUDARuntimeError,
            cp.cuda.cusolver.CUSOLVERError,
            cp.cuda.cublas.CUBLASError):
        # 🔧 CRITICAL: Clean up ALL intermediate variables before re-raising
        logger.debug("🔧 GPU PROCESSING: Cleaning up intermediate variables after OOM...")
        try:
            # Delete all local intermediate arrays (all pre-initialized above,
            # so the names are guaranteed to exist)
            del image_float, D, L, S, L_stack, corrected

            # Force garbage collection
            import gc
            gc.collect()

            logger.debug("🔧 GPU PROCESSING: Intermediate variables cleaned up")

        except Exception as cleanup_error:
            logger.warning(f"🔧 GPU PROCESSING: Variable cleanup warning: {cleanup_error}")

        # Re-raise the original GPU error for the outer exception handler
        raise
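# The alternating-minimization loop above is a robust-PCA-style scheme:
# decompose D ≈ L + S with L low-rank (shading field) and S sparse
# (structures). One iteration, restated compactly (illustrative; D, rank,
# and lambda_sparse as defined above):
#
#     L = _low_rank_approximation(D - S, rank=rank)     # fit the background
#     S = _soft_threshold(D - L, lambda_sparse)         # shrink the residuals
#     residual = cp.linalg.norm(D - L - S, 'fro') / cp.linalg.norm(D, 'fro')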


def basic_flatfield_correction_batch_cupy(
    image_batch: "cp.ndarray",
    *,
    batch_dim: int = 0,
    **kwargs
) -> "cp.ndarray":
    """
    Apply BaSiC flatfield correction to a batch of 3D image stacks.

    This function applies the BaSiC algorithm to each 3D stack in a batch.

    Args:
        image_batch: 4D CuPy array of shape (B, Z, Y, X) or (Z, B, Y, X)
        batch_dim: Dimension along which the batch is organized (0 or 1)
        **kwargs: Additional parameters passed to basic_flatfield_correction_cupy

    Returns:
        Corrected 4D CuPy array of the same shape as the input

    Raises:
        ImportError: If CuPy is not available
        TypeError: If the input is not a CuPy array
        ValueError: If the input is not a 4D array, if batch_dim is invalid,
            or if the input array doesn't support DLPack
    """
    # Validate input
    _validate_cupy_array(image_batch)

    if image_batch.ndim != 4:
        raise ValueError(f"Input must be a 4D array, got {image_batch.ndim}D")

    if batch_dim not in [0, 1]:
        raise ValueError(f"batch_dim must be 0 or 1, got {batch_dim}")

    # Process each 3D stack in the batch
    result_list = []

    if batch_dim == 0:
        # Batch is organized as (B, Z, Y, X)
        for b in range(image_batch.shape[0]):
            corrected = basic_flatfield_correction_cupy(image_batch[b], **kwargs)
            result_list.append(corrected)

        # Stack along the batch dimension
        return cp.stack(result_list, axis=0)

    # Batch is organized as (Z, B, Y, X)
    for b in range(image_batch.shape[1]):
        corrected = basic_flatfield_correction_cupy(image_batch[:, b], **kwargs)
        result_list.append(corrected)

    # Stack along the batch dimension
    return cp.stack(result_list, axis=1)
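# Usage sketch for the batch variant (shapes illustrative; stack_a and
# stack_b are hypothetical (Z, Y, X) CuPy arrays):
#
#     batch = cp.stack([stack_a, stack_b])                  # (B, Z, Y, X)
#     out = basic_flatfield_correction_batch_cupy(batch, batch_dim=0, rank=3)
#     assert out.shape == batch.shape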


def _cpu_fallback_flatfield_correction(
    image: "cp.ndarray", max_iters: int, lambda_sparse: float, lambda_lowrank: float,
    rank: int, tol: float, correction_mode: str, normalize_output: bool, verbose: bool
) -> "cp.ndarray":
    """CPU fallback for flatfield correction when the GPU runs out of memory."""

    try:
        import numpy as np

        from openhcs.processing.backends.enhance.basic_processor_numpy import basic_flatfield_correction_numpy

        # 🔧 AGGRESSIVE GPU CLEANUP: Free as much GPU memory as possible before the CPU conversion
        logger.info("🔧 CPU FALLBACK: Clearing GPU memory before CPU conversion...")
        try:
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()
        except Exception:
            pass

        # Convert the CuPy array to NumPy in chunks to avoid a large GPU allocation
        logger.info("🔧 CPU FALLBACK: Converting CuPy array to NumPy in chunks...")
        z, y, x = image.shape

        # Convert slice by slice to minimize GPU memory usage
        image_cpu_slices = []
        for i in range(z):
            try:
                slice_cpu = image[i].get()  # Convert one slice at a time
                image_cpu_slices.append(slice_cpu)

                # Clear GPU memory periodically
                if i % 10 == 0:  # Every 10 slices
                    cp.get_default_memory_pool().free_all_blocks()

            except cp.cuda.memory.OutOfMemoryError:
                # If even a single slice fails, try smaller chunks
                logger.warning(f"🔧 CPU FALLBACK: Slice {i} too large, converting in sub-chunks...")
                slice_chunks = []
                chunk_size = max(1, y // 4)  # Quarter the slice (at least one row)

                for start_y in range(0, y, chunk_size):
                    end_y = min(start_y + chunk_size, y)
                    chunk_cpu = image[i, start_y:end_y].get()
                    slice_chunks.append(chunk_cpu)
                    cp.get_default_memory_pool().free_all_blocks()

                # Reassemble the slice on the CPU
                slice_cpu = np.concatenate(slice_chunks, axis=0)
                image_cpu_slices.append(slice_cpu)

        # Reassemble the full image on the CPU
        image_cpu = np.stack(image_cpu_slices, axis=0)

        logger.info(f"🔧 CPU FALLBACK: Successfully converted to CPU, processing {image_cpu.shape} image")

        # Process on the CPU
        corrected_cpu = basic_flatfield_correction_numpy(
            image_cpu,
            max_iters=max_iters,
            lambda_sparse=lambda_sparse,
            lambda_lowrank=lambda_lowrank,
            rank=rank,
            tol=tol,
            correction_mode=correction_mode,
            normalize_output=normalize_output,
            verbose=verbose
        )

        # 🔧 AGGRESSIVE GPU CLEANUP: Clear ALL intermediate data before converting the result back
        logger.info("🔧 CPU FALLBACK: Clearing all GPU memory before converting result back...")
        try:
            # Clear all GPU memory pools
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()

            # Force GPU synchronization and cleanup
            cp.cuda.runtime.deviceSynchronize()

            # Additional cleanup: clear any cached kernels (if available)
            try:
                cp.fuse.clear_memo()
            except AttributeError:
                pass  # Not available in this CuPy version

            # Force garbage collection
            import gc
            gc.collect()

            logger.debug("🔧 CPU FALLBACK: GPU memory cleared, converting result back to CuPy...")

        except Exception as cleanup_error:
            logger.warning(f"🔧 CPU FALLBACK: GPU cleanup warning: {cleanup_error}")

        # Convert the result back to CuPy (there should now be enough memory)
        logger.info("🔧 CPU FALLBACK: Converting result back to CuPy...")
        corrected = cp.asarray(corrected_cpu)

        logger.info("🔧 CPU FALLBACK: Successfully processed on CPU and converted back to GPU")
        return corrected

    except Exception as cpu_error:
        logger.error(f"🔧 CPU FALLBACK: Failed to process on CPU: {cpu_error}")
        raise RuntimeError(f"Both GPU and CPU processing failed. GPU OOM, CPU error: {cpu_error}") from cpu_error
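# Minimal smoke test for the whole GPU-with-CPU-fallback path (illustrative;
# assumes CuPy and a CUDA device are available):
#
#     if cp is not None:
#         stack = cp.random.rand(8, 256, 256, dtype=cp.float32)
#         out = basic_flatfield_correction_cupy(stack, max_iters=10)
#         assert out.shape == stack.shape and out.dtype == stack.dtype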