Coverage for openhcs/processing/backends/enhance/basic_processor_cupy.py: 6.3% of 273 statements
1"""
2BaSiC (Background and Shading Correction) Implementation using CuPy
4This module implements the BaSiC algorithm for illumination correction
5using CuPy for GPU acceleration. The implementation is based on the paper:
6Peng et al., "A BaSiC tool for background and shading correction of optical
7microscopy images", Nature Communications, 2017.
9The algorithm performs low-rank + sparse matrix decomposition to separate
10uneven illumination artifacts from structural features in microscopy images.
12Doctrinal Clauses:
13- Clause 3 — Declarative Primacy: All functions are pure and stateless
14- Clause 65 — Fail Loudly: No silent fallbacks or inferred capabilities
15- Clause 88 — No Inferred Capabilities: Explicit CuPy dependency
16- Clause 273 — Memory Backend Restrictions: GPU-only implementation
17"""

from __future__ import annotations

import gc
import logging
from typing import TYPE_CHECKING, Any

import numpy as np

# Import decorator directly from core.memory to avoid circular imports
from openhcs.core.memory.decorators import cupy as cupy_func
from openhcs.core.utils import optional_import

# For type checking only
if TYPE_CHECKING:
    import cupy as cp

# Import CuPy as an optional dependency
cp = optional_import("cupy")
cupyx_scipy = None
if cp is not None:
    cupyx_scipy = optional_import("cupyx.scipy")

logger = logging.getLogger(__name__)


def _validate_cupy_array(array: Any, name: str = "input") -> None:
    """
    Validate that the input is a CuPy array.

    Args:
        array: Array to validate
        name: Name of the array for error messages

    Raises:
        TypeError: If the array is not a CuPy array
        ValueError: If the array doesn't support DLPack
    """
    # The compiler guarantees this function is only called when CuPy is
    # available, so no availability check is needed here.
    if not isinstance(array, cp.ndarray):
        raise TypeError(
            f"{name} must be a CuPy array, got {type(array)}. "
            f"No automatic conversion is performed to maintain explicit contracts. "
            f"Use DLPack for zero-copy GPU-to-GPU transfers."
        )

    # Ensure the array supports DLPack
    if not hasattr(array, "__dlpack__") and not hasattr(array, "toDlpack"):
        raise ValueError(
            f"{name} does not support DLPack protocol. "
            f"DLPack is required for GPU memory conversions."
        )
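
# For example (a sketch; assumes CuPy and NumPy are importable):
#
#     _validate_cupy_array(cp.zeros((4, 4)))            # passes silently
#     _validate_cupy_array(np.zeros((4, 4)), "image")   # raises TypeError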


def _low_rank_approximation(matrix: "cp.ndarray", rank: int = 3, max_memory_gb: float = 1.0) -> "cp.ndarray":
    """
    Compute a low-rank approximation of a matrix using SVD with memory optimization.

    Args:
        matrix: Input matrix to approximate
        rank: Target rank for the approximation
        max_memory_gb: Maximum memory to use for SVD (in GB)

    Returns:
        Low-rank approximation of the input matrix
    """
    # Estimate memory usage for SVD
    matrix_size_gb = matrix.nbytes / (1024**3)
    svd_memory_estimate = matrix_size_gb * 3  # U, s, Vh matrices

    if svd_memory_estimate > max_memory_gb:
        # Use chunked processing for large matrices
        logger.info(f"🔧 MEMORY OPTIMIZATION: Matrix too large ({svd_memory_estimate:.2f}GB > {max_memory_gb}GB), using chunked SVD")
        return _chunked_low_rank_approximation(matrix, rank, max_memory_gb)
    else:
        # Use standard SVD for smaller matrices
        try:
            # Perform SVD using CuPy's built-in linalg
            U, s, Vh = cp.linalg.svd(matrix, full_matrices=False)

            # Truncate to the specified rank by zeroing trailing singular values
            s[rank:] = 0

            # Reconstruct the low-rank matrix
            low_rank = (U * s) @ Vh

            return low_rank
        except (cp.cuda.memory.OutOfMemoryError,
                cp.cuda.runtime.CUDARuntimeError,
                cp.cuda.cusolver.CUSOLVERError,
                cp.cuda.cublas.CUBLASError):
            # Fall back to chunked processing if standard SVD fails
            logger.warning("🔧 MEMORY OPTIMIZATION: Standard SVD failed, falling back to chunked processing")
            return _chunked_low_rank_approximation(matrix, rank, max_memory_gb)
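
# A minimal numeric sketch of the truncation (ours, not from the module):
#
#     M = cp.asarray([[2.0, 0.0], [0.0, 1e-3]])
#     _low_rank_approximation(M, rank=1)
#     # -> approximately [[2., 0.], [0., 0.]]: the smallest singular value is
#     #    zeroed, leaving the best rank-1 approximation in the Frobenius sense.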


def _chunked_low_rank_approximation(matrix: "cp.ndarray", rank: int, max_memory_gb: float) -> "cp.ndarray":
    """
    Compute a low-rank approximation using adaptive dynamic chunking.

    Automatically adjusts chunk sizes based on available GPU memory and
    allocation success/failure patterns for optimal performance.

    Args:
        matrix: Input matrix to approximate (Z, Y*X)
        rank: Target rank for the approximation
        max_memory_gb: Maximum memory to use per chunk (fallback limit)

    Returns:
        Low-rank approximation of the input matrix
    """
    Z, YX = matrix.shape

    # 🔧 ADAPTIVE CHUNKING: Query available GPU memory for dynamic sizing
    try:
        free_memory, total_memory = cp.cuda.runtime.memGetInfo()
        available_gb = free_memory / (1024**3)
        logger.debug(f"🔧 ADAPTIVE CHUNKING: {available_gb:.2f}GB free GPU memory")
    except Exception:
        # Fall back to a conservative estimate if the memory query fails
        available_gb = max_memory_gb * 0.5
        logger.debug(f"🔧 ADAPTIVE CHUNKING: Memory query failed, using conservative {available_gb:.2f}GB")

    # Calculate the initial chunk size based on available memory
    bytes_per_element = matrix.dtype.itemsize
    svd_overhead = 8  # Conservative estimate for SVD workspace requirements

    # Use 10% of available memory for the initial chunk size
    usable_memory_gb = min(available_gb * 0.1, max_memory_gb)
    max_elements_per_chunk = int((usable_memory_gb * 1024**3) / (bytes_per_element * svd_overhead))
    initial_chunk_size = min(YX, max_elements_per_chunk // Z)

    # Enforce reasonable bounds
    min_chunk_size = 100
    max_chunk_size = YX // 4  # Don't make chunks larger than 25% of the total
    current_chunk_size = max(min_chunk_size, min(initial_chunk_size, max_chunk_size))

    logger.debug(f"🔧 ADAPTIVE CHUNKING: Starting with chunk size {current_chunk_size:,} (of {YX:,} total)")

    # Adaptive feedback variables
    success_count = 0
    low_rank_chunks = []
    current_pos = 0

    # Process the matrix with adaptive chunking
    while current_pos < YX:
        end_pos = min(current_pos + current_chunk_size, YX)
        chunk = matrix[:, current_pos:end_pos]

        try:
            # Try to process this chunk on the GPU
            U, s, Vh = cp.linalg.svd(chunk, full_matrices=False)

            # Truncate to the specified rank
            s_truncated = s.copy()
            s_truncated[rank:] = 0

            # Reconstruct the low-rank chunk
            low_rank_chunk = (U * s_truncated) @ Vh
            low_rank_chunks.append(low_rank_chunk)

            # 🔧 SUCCESS: Move to the next chunk and track the success
            current_pos = end_pos
            success_count += 1

            # Grow the chunk size after 3 consecutive successes
            if success_count >= 3:
                old_size = current_chunk_size
                current_chunk_size = min(int(current_chunk_size * 1.5), max_chunk_size)
                if current_chunk_size > old_size:
                    logger.debug(f"🔧 ADAPTIVE CHUNKING: Growing chunk size {old_size:,} → {current_chunk_size:,}")
                success_count = 0

        except (cp.cuda.memory.OutOfMemoryError,
                cp.cuda.runtime.CUDARuntimeError,
                cp.cuda.cusolver.CUSOLVERError,
                cp.cuda.cublas.CUBLASError):
            # 🔧 FAILURE: Shrink the chunk size and retry
            old_size = current_chunk_size
            current_chunk_size = max(int(current_chunk_size * 0.5), min_chunk_size)
            success_count = 0

            if current_chunk_size < old_size:
                logger.info(f"🔧 ADAPTIVE CHUNKING: OOM, shrinking chunk size {old_size:,} → {current_chunk_size:,}")
                # Don't advance the position; retry with a smaller chunk
                continue
            else:
                # The chunk size can't be reduced further; fall back to the CPU for this chunk
                logger.warning(f"🔧 ADAPTIVE CHUNKING: Minimum chunk size reached, using CPU for chunk {current_pos}:{end_pos}")

                try:
                    # Process on the CPU
                    chunk_cpu = chunk.get()
                    U_cpu, s_cpu, Vh_cpu = np.linalg.svd(chunk_cpu, full_matrices=False)
                    s_cpu[rank:] = 0
                    low_rank_chunk_cpu = (U_cpu * s_cpu) @ Vh_cpu
                    low_rank_chunk = cp.asarray(low_rank_chunk_cpu)
                    low_rank_chunks.append(low_rank_chunk)

                    logger.debug("🔧 ADAPTIVE CHUNKING: Successfully processed chunk on CPU")
                    current_pos = end_pos

                except Exception as cpu_error:
                    logger.error(f"🔧 ADAPTIVE CHUNKING: CPU fallback failed: {cpu_error}")
                    # Last resort: pass the chunk through unchanged (no correction for this chunk)
                    logger.warning(f"🔧 ADAPTIVE CHUNKING: Leaving chunk {current_pos}:{end_pos} uncorrected")
                    low_rank_chunk = chunk.copy()
                    low_rank_chunks.append(low_rank_chunk)
                    current_pos = end_pos

    # Concatenate all processed chunks
    low_rank = cp.concatenate(low_rank_chunks, axis=1)

    logger.debug(f"🔧 ADAPTIVE CHUNKING: Completed processing {len(low_rank_chunks)} chunks")
    return low_rank
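
# Note (our observation, not from the original docstring): concatenating
# independent rank-r approximations of column blocks is not a global rank-r
# approximation of the full matrix; the concatenated result can have rank up
# to r * num_chunks. For a smooth illumination field the blockwise result is
# typically close enough, which is the trade-off the chunked path accepts in
# exchange for bounded memory.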


def _soft_threshold(matrix: "cp.ndarray", threshold: float) -> "cp.ndarray":
    """
    Apply soft thresholding (shrinkage operator) to a matrix.

    Args:
        matrix: Input matrix
        threshold: Threshold value for soft thresholding

    Returns:
        Soft-thresholded matrix
    """
    return cp.sign(matrix) * cp.maximum(cp.abs(matrix) - threshold, 0)
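
# For example, _soft_threshold(cp.asarray([-3.0, 0.5, 2.0]), 1.0) evaluates
# to [-2.0, 0.0, 1.0]: every entry is shrunk toward zero by the threshold and
# anything within [-1, 1] is zeroed. This is the proximal operator of the
# L1 norm, which is what makes the S component sparse.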


@cupy_func
def basic_flatfield_correction_cupy(
    image: "cp.ndarray",
    *,
    max_iters: int = 50,
    lambda_sparse: float = 0.01,
    lambda_lowrank: float = 0.1,
    rank: int = 3,
    tol: float = 1e-4,
    correction_mode: str = "divide",
    normalize_output: bool = True,
    verbose: bool = False,
    max_memory_gb: float = 1.0,
    **kwargs
) -> "cp.ndarray":
    """
    Perform BaSiC-style illumination correction on a 3D image stack using CuPy.

    This function implements the BaSiC algorithm for illumination correction
    using low-rank + sparse matrix decomposition. It models the background
    (shading field) as a low-rank matrix across slices and the residuals
    (e.g., nuclei, structures) as sparse features.

    This memory-optimized version automatically uses chunked processing for
    large images to prevent CUDA out-of-memory errors.

    Args:
        image: 3D CuPy array of shape (Z, Y, X)
        max_iters: Maximum number of iterations for the alternating minimization
        lambda_sparse: Regularization parameter for the sparse component
        lambda_lowrank: Regularization parameter for the low-rank component
        rank: Target rank for the low-rank approximation
        tol: Tolerance for convergence
        correction_mode: Mode for applying the correction ('divide' or 'subtract')
        normalize_output: Whether to normalize the output to preserve dynamic range
        verbose: Whether to print progress information
        max_memory_gb: Maximum memory to use for SVD operations (in GB)
        **kwargs: Additional parameters (ignored)

    Returns:
        Corrected 3D CuPy array of shape (Z, Y, X)

    Raises:
        ImportError: If CuPy is not available
        TypeError: If the input is not a CuPy array
        ValueError: If the input is not a 3D array, if correction_mode is
            invalid, or if the input array doesn't support DLPack
    """
    # Validate input
    _validate_cupy_array(image)

    if image.ndim != 3:
        raise ValueError(f"Input must be a 3D array, got {image.ndim}D")

    if correction_mode not in ["divide", "subtract"]:
        raise ValueError(f"Invalid correction mode: {correction_mode}. "
                         f"Must be 'divide' or 'subtract'")

    # Store the original shape and dtype
    z, y, x = image.shape
    orig_dtype = image.dtype

    # 🔍 MEMORY ESTIMATION: Check whether the image is likely to cause memory issues
    image_size_gb = image.nbytes / (1024**3)
    estimated_peak_memory = image_size_gb * 4  # Original + float + L + S matrices

    # Try GPU processing first; fall back to the CPU on OOM
    try:
        return _gpu_flatfield_correction(
            image, max_iters, lambda_sparse, lambda_lowrank, rank, tol,
            correction_mode, normalize_output, verbose, max_memory_gb,
            image_size_gb, estimated_peak_memory, z, y, x, orig_dtype
        )
    except (cp.cuda.memory.OutOfMemoryError,
            cp.cuda.runtime.CUDARuntimeError,
            cp.cuda.cusolver.CUSOLVERError,
            cp.cuda.cublas.CUBLASError) as oom_error:
        logger.warning(f"🔧 GPU OOM: {oom_error}")
        logger.info(f"🔧 CPU FALLBACK: GPU processing failed, switching to CPU for {z}×{y}×{x} image")

        # 🔧 CRITICAL: The intermediate arrays were already deleted inside
        # _gpu_flatfield_correction's exception handler; here we aggressively
        # release pooled GPU memory before the CPU fallback allocates anything.
        logger.debug("🔧 CPU FALLBACK: Cleaning up intermediate GPU variables...")
        try:
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()

            # Force garbage collection
            gc.collect()

            # Check memory after cleanup
            free_after_cleanup, total = cp.cuda.runtime.memGetInfo()
            logger.info(f"🔧 CPU FALLBACK: After cleanup: {free_after_cleanup / 1e9:.2f}GB free of {total / 1e9:.2f}GB total")

        except Exception as cleanup_error:
            logger.warning(f"🔧 CPU FALLBACK: Cleanup warning: {cleanup_error}")

        # Fall back to CPU processing
        return _cpu_fallback_flatfield_correction(
            image, max_iters, lambda_sparse, lambda_lowrank, rank, tol,
            correction_mode, normalize_output, verbose
        )
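
# Example usage (a sketch; the synthetic stack below is ours and assumes a
# working CUDA device):
#
#     stack = cp.random.poisson(100.0, size=(8, 256, 256)).astype(cp.float32)
#     corrected = basic_flatfield_correction_cupy(
#         stack, max_iters=30, lambda_sparse=0.01, correction_mode="divide"
#     )
#     assert corrected.shape == stack.shape and corrected.dtype == stack.dtype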


def _gpu_flatfield_correction(
    image: "cp.ndarray", max_iters: int, lambda_sparse: float, lambda_lowrank: float,
    rank: int, tol: float, correction_mode: str, normalize_output: bool, verbose: bool,
    max_memory_gb: float, image_size_gb: float, estimated_peak_memory: float, z: int, y: int, x: int, orig_dtype
) -> "cp.ndarray":
    """GPU-based flatfield correction implementation."""
    if estimated_peak_memory > max_memory_gb * 2:
        logger.warning(f"⚠️ Large image detected: {z}×{y}×{x} ({image_size_gb:.2f}GB). "
                       f"Estimated peak memory: {estimated_peak_memory:.2f}GB. "
                       f"Consider reducing image size or increasing the max_memory_gb parameter.")

    logger.debug(f"🔧 MEMORY INFO: Image size {z}×{y}×{x}, {image_size_gb:.2f}GB, "
                 f"max_memory_gb={max_memory_gb}, estimated peak={estimated_peak_memory:.2f}GB")

    # Initialize variables to None for proper cleanup
    image_float = None
    D = None
    L = None
    S = None
    L_stack = None
    corrected = None

    try:
        # Convert to float for processing
        image_float = image.astype(cp.float32)

        # Flatten each Z-slice into a row vector: D has shape (Z, Y*X)
        D = image_float.reshape(z, y * x)

        # Initialize variables for the alternating minimization
        L = cp.zeros_like(D)  # Low-rank component (background/illumination)
        S = cp.zeros_like(D)  # Sparse component (foreground/structures)

        # Compute the initial norm for the convergence check
        norm_D = cp.linalg.norm(D, 'fro')

        # Track convergence for early termination
        prev_residual = float('inf')
        stagnation_count = 0
        max_stagnation = 5  # Stop if there is no improvement for 5 iterations

        # Alternating minimization loop
        for iteration in range(max_iters):
            # Update the low-rank component (L) with memory optimization
            L = _low_rank_approximation(D - S, rank=rank, max_memory_gb=max_memory_gb)

            # Apply regularization to L if needed
            if lambda_lowrank > 0:
                L = L * (1 - lambda_lowrank)

            # Update the sparse component (S)
            S = _soft_threshold(D - L, lambda_sparse)

            # Check convergence (float() forces a single device sync per iteration)
            residual = float(cp.linalg.norm(D - L - S, 'fro') / norm_D)
            if verbose and (iteration % 10 == 0 or iteration == max_iters - 1):
                logger.info(f"Iteration {iteration+1}/{max_iters}, residual: {residual:.6f}")

            # Early termination conditions
            if residual < tol:
                if verbose:
                    logger.info(f"Converged after {iteration+1} iterations (residual < {tol})")
                break

            # Check for stagnation (no significant improvement)
            improvement = prev_residual - residual
            if improvement < tol * 0.1:  # Improvement smaller than 10% of the tolerance
                stagnation_count += 1
                if stagnation_count >= max_stagnation:
                    if verbose:
                        logger.info(f"Early termination after {iteration+1} iterations (stagnation)")
                    break
            else:
                stagnation_count = 0  # Reset the counter when we see improvement

            prev_residual = residual

        # Reshape the low-rank component back to 3D
        L_stack = L.reshape(z, y, x)

        # Apply the correction
        if correction_mode == "divide":
            # Add a small epsilon to avoid division by zero
            eps = 1e-6
            corrected = image_float / (L_stack + eps)

            # Normalize to preserve dynamic range
            if normalize_output:
                corrected *= cp.mean(L_stack)
        else:  # subtract
            corrected = image_float - L_stack

            # Normalize to preserve dynamic range
            if normalize_output:
                corrected += cp.mean(L_stack)

        # Clip to the valid range and convert back to the original dtype
        if cp.issubdtype(orig_dtype, cp.integer):
            max_val = cp.iinfo(orig_dtype).max
            corrected = cp.clip(corrected, 0, max_val).astype(orig_dtype)
        else:
            corrected = cp.clip(corrected, 0, None).astype(orig_dtype)

        return corrected

    except (cp.cuda.memory.OutOfMemoryError,
            cp.cuda.runtime.CUDARuntimeError,
            cp.cuda.cusolver.CUSOLVERError,
            cp.cuda.cublas.CUBLASError) as gpu_error:
        # 🔧 CRITICAL: Clean up ALL intermediate variables before re-raising
        logger.debug("🔧 GPU PROCESSING: Cleaning up intermediate variables after OOM...")
        try:
            # All six names were bound to None before the try block, so this
            # del is always safe; it lets the memory pool reclaim the arrays
            del image_float, D, L, S, L_stack, corrected

            # Force garbage collection
            gc.collect()

            logger.debug("🔧 GPU PROCESSING: Intermediate variables cleaned up")

        except Exception as cleanup_error:
            logger.warning(f"🔧 GPU PROCESSING: Variable cleanup warning: {cleanup_error}")

        # Re-raise the original GPU error for the outer exception handler
        raise gpu_error
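
# Correction semantics in brief (our summary of the code above): with
# correction_mode="divide" each voxel becomes I / (L + eps), optionally
# rescaled by mean(L) so the corrected stack keeps roughly its original
# intensity range; with "subtract" it becomes I - L, optionally re-centered
# by adding mean(L). For instance, a voxel with I = 120 under an estimated
# shading of L = 0.8 (with mean(L) = 1.0) divides out to roughly 150.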


def basic_flatfield_correction_batch_cupy(
    image_batch: "cp.ndarray",
    *,
    batch_dim: int = 0,
    **kwargs
) -> "cp.ndarray":
    """
    Apply BaSiC flatfield correction to a batch of 3D image stacks.

    This function applies the BaSiC algorithm to each 3D stack in a batch.

    Args:
        image_batch: 4D CuPy array of shape (B, Z, Y, X) or (Z, B, Y, X)
        batch_dim: Dimension along which the batch is organized (0 or 1)
        **kwargs: Additional parameters passed to basic_flatfield_correction_cupy

    Returns:
        Corrected 4D CuPy array of the same shape as the input

    Raises:
        ImportError: If CuPy is not available
        TypeError: If the input is not a CuPy array
        ValueError: If the input is not a 4D array, if batch_dim is invalid,
            or if the input array doesn't support DLPack
    """
    # Validate input
    _validate_cupy_array(image_batch)

    if image_batch.ndim != 4:
        raise ValueError(f"Input must be a 4D array, got {image_batch.ndim}D")

    if batch_dim not in [0, 1]:
        raise ValueError(f"batch_dim must be 0 or 1, got {batch_dim}")

    # Process each 3D stack in the batch
    result_list = []

    if batch_dim == 0:
        # Batch is organized as (B, Z, Y, X)
        for b in range(image_batch.shape[0]):
            corrected = basic_flatfield_correction_cupy(image_batch[b], **kwargs)
            result_list.append(corrected)

        # Stack along the batch dimension
        return cp.stack(result_list, axis=0)

    # Batch is organized as (Z, B, Y, X)
    for b in range(image_batch.shape[1]):
        corrected = basic_flatfield_correction_cupy(image_batch[:, b], **kwargs)
        result_list.append(corrected)

    # Stack along the batch dimension
    return cp.stack(result_list, axis=1)
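
# Batch usage sketch (ours; assumes a working CUDA device):
#
#     batch = cp.random.poisson(100.0, size=(4, 8, 256, 256)).astype(cp.float32)
#     corrected = basic_flatfield_correction_batch_cupy(batch, batch_dim=0)
#     # Each of the 4 stacks of shape (8, 256, 256) is corrected independently.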


def _cpu_fallback_flatfield_correction(
    image: "cp.ndarray", max_iters: int, lambda_sparse: float, lambda_lowrank: float,
    rank: int, tol: float, correction_mode: str, normalize_output: bool, verbose: bool
) -> "cp.ndarray":
    """CPU fallback for flatfield correction when the GPU runs out of memory."""
    try:
        from openhcs.processing.backends.enhance.basic_processor_numpy import basic_flatfield_correction_numpy

        # 🔧 AGGRESSIVE GPU CLEANUP: Free as much GPU memory as possible before the CPU conversion
        logger.info("🔧 CPU FALLBACK: Clearing GPU memory before CPU conversion...")
        try:
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()
        except Exception:
            pass

        # Convert the CuPy array to NumPy in chunks to avoid a large GPU allocation
        logger.info("🔧 CPU FALLBACK: Converting CuPy array to NumPy in chunks...")
        z, y, x = image.shape

        # Convert slice by slice to minimize GPU memory usage
        image_cpu_slices = []
        for i in range(z):
            try:
                slice_cpu = image[i].get()  # Convert one slice at a time
                image_cpu_slices.append(slice_cpu)

                # Clear GPU memory every 10 slices
                if i % 10 == 0:
                    cp.get_default_memory_pool().free_all_blocks()

            except cp.cuda.memory.OutOfMemoryError:
                # If even a single slice fails, try smaller chunks
                logger.warning(f"🔧 CPU FALLBACK: Slice {i} too large, converting in sub-chunks...")
                slice_chunks = []
                chunk_size = max(1, y // 4)  # Quarter the slice (at least one row)

                for start_y in range(0, y, chunk_size):
                    end_y = min(start_y + chunk_size, y)
                    chunk_cpu = image[i, start_y:end_y].get()
                    slice_chunks.append(chunk_cpu)
                    cp.get_default_memory_pool().free_all_blocks()

                # Reassemble the slice on the CPU
                slice_cpu = np.concatenate(slice_chunks, axis=0)
                image_cpu_slices.append(slice_cpu)

        # Reassemble the full image on the CPU
        image_cpu = np.stack(image_cpu_slices, axis=0)

        logger.info(f"🔧 CPU FALLBACK: Successfully converted to CPU, processing {image_cpu.shape} image")

        # Process on the CPU
        corrected_cpu = basic_flatfield_correction_numpy(
            image_cpu,
            max_iters=max_iters,
            lambda_sparse=lambda_sparse,
            lambda_lowrank=lambda_lowrank,
            rank=rank,
            tol=tol,
            correction_mode=correction_mode,
            normalize_output=normalize_output,
            verbose=verbose
        )

        # 🔧 AGGRESSIVE GPU CLEANUP: Clear ALL intermediate data before converting the result back
        logger.info("🔧 CPU FALLBACK: Clearing all GPU memory before converting result back...")
        try:
            # Clear all GPU memory pools
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()

            # Force GPU synchronization and cleanup
            cp.cuda.runtime.deviceSynchronize()

            # Additional cleanup: clear any cached kernels (if available)
            try:
                cp.fuse.clear_memo()
            except AttributeError:
                pass  # Not available in this CuPy version

            # Force garbage collection
            gc.collect()

            logger.debug("🔧 CPU FALLBACK: GPU memory cleared, converting result back to CuPy...")

        except Exception as cleanup_error:
            logger.warning(f"🔧 CPU FALLBACK: GPU cleanup warning: {cleanup_error}")

        # Convert the result back to CuPy (there should now be enough memory)
        logger.info("🔧 CPU FALLBACK: Converting result back to CuPy...")
        corrected = cp.asarray(corrected_cpu)

        logger.info("🔧 CPU FALLBACK: Successfully processed on CPU and converted back to GPU")
        return corrected

    except Exception as cpu_error:
        logger.error(f"🔧 CPU FALLBACK: Failed to process on CPU: {cpu_error}")
        raise RuntimeError(f"Both GPU and CPU processing failed. GPU OOM, CPU error: {cpu_error}") from cpu_error
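

if __name__ == "__main__":
    # A minimal smoke test (ours, not part of the original module): build a
    # synthetic stack with a smooth multiplicative shading field plus a few
    # sparse bright spots, then check that correction flattens the background.
    # Assumes a working CUDA device.
    if cp is None:
        raise SystemExit("CuPy is not available; smoke test skipped.")

    yy, xx = cp.meshgrid(cp.linspace(0.5, 1.5, 128), cp.linspace(0.5, 1.5, 128), indexing="ij")
    shading = (yy * xx).astype(cp.float32)                  # smooth illumination field
    stack = cp.full((8, 128, 128), 100.0, dtype=cp.float32) * shading
    stack[:, 60:64, 60:64] += 500.0                         # sparse "structures"

    corrected = basic_flatfield_correction_cupy(stack, max_iters=20, verbose=True)
    # The corrected background should vary less than the input's.
    print(f"input  std: {float(stack.std()):.2f}")
    print(f"output std: {float(corrected.std()):.2f}")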