Coverage for openhcs/core/memory/gpu_cleanup.py: 27.4%
208 statements
coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2GPU memory cleanup utilities for different frameworks.
4This module provides unified GPU memory cleanup functions for PyTorch, CuPy,
5TensorFlow, and JAX. The cleanup functions are designed to be called after
6processing steps to free up GPU memory that's no longer needed.
7"""
9import logging
10from typing import Optional
11from openhcs.core.utils import optional_import
12from openhcs.constants.constants import VALID_GPU_MEMORY_TYPES # Import directly if always available
14logger = logging.getLogger(__name__)
16# --- Top-level optional imports for GPU frameworks ---
17torch = optional_import("torch")
18cupy = optional_import("cupy")
19tensorflow = optional_import("tensorflow")
20jax = optional_import("jax")
21pyclesperanto = optional_import("pyclesperanto")
23# --- Cleanup functions ---

def is_gpu_memory_type(memory_type: str) -> bool:
    """
    Check if a memory type is a GPU memory type.

    Args:
        memory_type: Memory type string

    Returns:
        True if it's a GPU memory type, False otherwise
    """
    # VALID_GPU_MEMORY_TYPES comes from the top-level import above; core
    # constants are assumed to always be available, so no try/except fallback
    # is needed here.
    return memory_type in VALID_GPU_MEMORY_TYPES
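

# Illustrative usage (not part of the original module), assuming "cupy" is one
# of the VALID_GPU_MEMORY_TYPES and "numpy" is the CPU memory type used below:
#
#     is_gpu_memory_type("cupy")   # -> True, GPU-backed memory type
#     is_gpu_memory_type("numpy")  # -> False, CPU memory type, no cleanup needed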


def cleanup_pytorch_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up PyTorch GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if torch is None:
        logger.debug("PyTorch not available, skipping PyTorch GPU cleanup")
        return

    try:
        if not torch.cuda.is_available():
            return

        if device_id is not None:
            # Clean specific device
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            logger.debug(f"🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for device {device_id}")
        else:
            # Clean all devices
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            logger.debug("🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for all devices")

    except Exception as e:
        logger.warning(f"Failed to cleanup PyTorch GPU memory: {e}")
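

# Illustrative sketch (not part of the original module): torch.cuda.empty_cache()
# only returns cached blocks whose tensors are no longer referenced, so callers
# should drop their own GPU references before invoking the cleanup. The tensor
# names below are hypothetical placeholders.
def _example_pytorch_step_cleanup() -> None:
    """Sketch: release tensor references, then clear the CUDA caching allocator."""
    if torch is None or not torch.cuda.is_available():
        return
    result = torch.ones((1024, 1024), device="cuda:0")  # stand-in for a step's output
    host_copy = result.cpu()                            # keep what is needed on the host
    del result                                          # drop the GPU reference first
    cleanup_pytorch_gpu(device_id=0)                    # now the cached block can actually be freed
    logger.debug(f"Host copy retained with shape {tuple(host_copy.shape)}")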


def cleanup_cupy_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up CuPy GPU memory with aggressive defragmentation.

    Args:
        device_id: Optional GPU device ID. If None, cleans current device.
    """
    if cupy is None:
        logger.debug("CuPy not available, skipping CuPy GPU cleanup")
        return

    try:
        if device_id is not None:
            # Clean specific device
            with cupy.cuda.Device(device_id):
                # Get memory info before cleanup
                mempool = cupy.get_default_memory_pool()
                used_before = mempool.used_bytes()

                # Aggressive cleanup to defragment memory
                cupy.get_default_memory_pool().free_all_blocks()
                cupy.get_default_pinned_memory_pool().free_all_blocks()

                # Synchronize so the freed blocks are fully released
                cupy.cuda.runtime.deviceSynchronize()

                used_after = mempool.used_bytes()
                freed_mb = (used_before - used_after) / 1e6

                logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for device {device_id}, freed {freed_mb:.1f}MB")
        else:
            # Clean current device
            mempool = cupy.get_default_memory_pool()
            used_before = mempool.used_bytes()

            # Aggressive cleanup to defragment memory
            cupy.get_default_memory_pool().free_all_blocks()
            cupy.get_default_pinned_memory_pool().free_all_blocks()

            # Synchronize so the freed blocks are fully released
            cupy.cuda.runtime.deviceSynchronize()

            used_after = mempool.used_bytes()
            freed_mb = (used_before - used_after) / 1e6

            logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for current device, freed {freed_mb:.1f}MB")

    except Exception as e:
        logger.warning(f"Failed to cleanup CuPy GPU memory: {e}")


def cleanup_tensorflow_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up TensorFlow GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if tensorflow is None:
        logger.debug("TensorFlow not available, skipping TensorFlow GPU cleanup")
        return

    try:
        # Get list of GPU devices
        gpus = tensorflow.config.list_physical_devices('GPU')
        if not gpus:
            return

        if device_id is not None and device_id < len(gpus):
            # Clean specific device - TensorFlow doesn't have per-device cleanup,
            # so we trigger garbage collection, which helps with memory management
            import gc
            gc.collect()
            logger.debug(f"🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPU {device_id}")
        else:
            # Clean all devices - trigger garbage collection
            import gc
            gc.collect()
            logger.debug("🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPUs")

    except Exception as e:
        logger.warning(f"Failed to cleanup TensorFlow GPU memory: {e}")


def cleanup_jax_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up JAX GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if jax is None:
        logger.debug("JAX not available, skipping JAX GPU cleanup")
        return

    try:
        # JAX doesn't have explicit memory cleanup like PyTorch/CuPy,
        # but we can trigger garbage collection and clear the compilation cache
        import gc
        gc.collect()

        # Clear JAX compilation caches, which can hold GPU memory
        jax.clear_caches()

        if device_id is not None:
            logger.debug(f"🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for device {device_id}")
        else:
            logger.debug("🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for all devices")

    except Exception as e:
        logger.warning(f"Failed to cleanup JAX GPU memory: {e}")


def cleanup_pyclesperanto_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up pyclesperanto GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans current device.
    """
    if pyclesperanto is None:
        logger.debug("pyclesperanto not available, skipping pyclesperanto GPU cleanup")
        return

    try:
        import gc

        # pyclesperanto doesn't have explicit memory cleanup like PyTorch/CuPy,
        # but we can trigger garbage collection and clear any cached data

        if device_id is not None:
            # Select the specific device
            devices = pyclesperanto.list_available_devices()
            if device_id < len(devices):
                pyclesperanto.select_device(device_id)
                logger.debug(f"🔥 GPU CLEANUP: Selected pyclesperanto device {device_id}")
            else:
                logger.warning(f"🔥 GPU CLEANUP: Device {device_id} not available in pyclesperanto")

        # Trigger garbage collection to clean up any unreferenced GPU arrays
        collected = gc.collect()

        # pyclesperanto uses OpenCL, which manages memory automatically,
        # but we can help by ensuring Python objects are cleaned up
        if device_id is not None:
            logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto device {device_id}, collected {collected} objects")
        else:
            logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto current device, collected {collected} objects")

    except Exception as e:
        logger.warning(f"Failed to cleanup pyclesperanto GPU memory: {e}")


def cleanup_gpu_memory_by_framework(memory_type: str, device_id: Optional[int] = None) -> None:
    """
    Clean up GPU memory based on the OpenHCS memory type.

    Args:
        memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow",
            "jax", "pyclesperanto", "numpy")
        device_id: Optional GPU device ID
    """
    # Handle exact OpenHCS memory type values
    if memory_type == "torch":
        cleanup_pytorch_gpu(device_id)
    elif memory_type == "cupy":
        cleanup_cupy_gpu(device_id)
    elif memory_type == "tensorflow":
        cleanup_tensorflow_gpu(device_id)
    elif memory_type == "jax":
        cleanup_jax_gpu(device_id)
    elif memory_type == "pyclesperanto":
        cleanup_pyclesperanto_gpu(device_id)
    elif memory_type == "numpy":
        # CPU memory type - no GPU cleanup needed
        logger.debug(f"No GPU cleanup needed for CPU memory type: {memory_type}")
    else:
        # Fallback for unknown types - try pattern matching
        memory_type_lower = memory_type.lower()
        if "torch" in memory_type_lower or "pytorch" in memory_type_lower:
            cleanup_pytorch_gpu(device_id)
        elif "cupy" in memory_type_lower:
            cleanup_cupy_gpu(device_id)
        elif "tensorflow" in memory_type_lower or "tf" in memory_type_lower:
            cleanup_tensorflow_gpu(device_id)
        elif "jax" in memory_type_lower:
            cleanup_jax_gpu(device_id)
        elif "pyclesperanto" in memory_type_lower or "clesperanto" in memory_type_lower:
            cleanup_pyclesperanto_gpu(device_id)
        else:
            logger.debug(f"Unknown memory type for GPU cleanup: {memory_type}")
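

# Illustrative sketch (not part of the original module): dispatching cleanup
# from a processing loop. The list of (memory_type, device_id) pairs is a
# hypothetical stand-in for whatever the pipeline actually tracks.
def _example_cleanup_after_steps() -> None:
    """Sketch: free framework memory according to each step's output type."""
    completed_steps = [("torch", 0), ("cupy", 0), ("numpy", None)]
    for memory_type, device_id in completed_steps:
        cleanup_gpu_memory_by_framework(memory_type, device_id)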


def cleanup_numpy_noop(device_id: Optional[int] = None) -> None:
    """
    No-op cleanup for numpy (CPU memory type).

    Args:
        device_id: Optional GPU device ID (ignored for CPU)
    """
    logger.debug("🔥 GPU CLEANUP: No-op for numpy (CPU memory type)")


def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
    """
    Clean up GPU memory for all available frameworks.

    Args:
        device_id: Optional GPU device ID
    """
    cleanup_pytorch_gpu(device_id)
    cleanup_cupy_gpu(device_id)
    cleanup_tensorflow_gpu(device_id)
    cleanup_jax_gpu(device_id)
    cleanup_pyclesperanto_gpu(device_id)

    # Also trigger Python garbage collection
    import gc
    gc.collect()

    logger.debug("🔥 GPU CLEANUP: Performed comprehensive cleanup for all GPU frameworks")


# Registry mapping memory types to their cleanup functions
MEMORY_TYPE_CLEANUP_REGISTRY = {
    "torch": cleanup_pytorch_gpu,
    "cupy": cleanup_cupy_gpu,
    "tensorflow": cleanup_tensorflow_gpu,
    "jax": cleanup_jax_gpu,
    "pyclesperanto": cleanup_pyclesperanto_gpu,
    "numpy": cleanup_numpy_noop,
}


def cleanup_memory_by_type(memory_type: str, device_id: Optional[int] = None) -> None:
    """
    Clean up memory using the registered cleanup function for the memory type.

    Args:
        memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow",
            "jax", "pyclesperanto", "numpy")
        device_id: Optional GPU device ID
    """
    cleanup_func = MEMORY_TYPE_CLEANUP_REGISTRY.get(memory_type)

    if cleanup_func:
        cleanup_func(device_id)
    else:
        logger.warning(f"No cleanup function registered for memory type: {memory_type}")
        logger.debug(f"Available memory types: {list(MEMORY_TYPE_CLEANUP_REGISTRY.keys())}")
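

# Illustrative sketch (not part of the original module): because dispatch is a
# plain dict lookup, a hypothetical new memory type could be supported by
# registering its cleanup callable. "my_backend" and _cleanup_my_backend are
# made-up names used only for this example.
def _example_register_custom_cleanup() -> None:
    """Sketch: extend MEMORY_TYPE_CLEANUP_REGISTRY with a custom entry."""
    def _cleanup_my_backend(device_id: Optional[int] = None) -> None:
        logger.debug(f"Custom cleanup for my_backend (device {device_id})")

    MEMORY_TYPE_CLEANUP_REGISTRY["my_backend"] = _cleanup_my_backend
    cleanup_memory_by_type("my_backend", device_id=0)  # dispatches to the custom function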


def check_gpu_memory_usage() -> None:
    """
    Check and log current GPU memory usage for all available frameworks.

    This is a utility function for debugging memory issues.
    """
    logger.debug("🔍 GPU Memory Usage Report:")

    # Check PyTorch
    if torch is not None:
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                reserved = torch.cuda.memory_reserved(i) / 1024**3
                logger.debug(f"  PyTorch GPU {i}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
        else:
            logger.debug("  PyTorch: No CUDA available")
    else:
        logger.debug("  PyTorch: Not installed")

    # Check CuPy
    if cupy is not None:
        mempool = cupy.get_default_memory_pool()
        used_bytes = mempool.used_bytes()
        total_bytes = mempool.total_bytes()
        logger.debug(f"  CuPy: {used_bytes / 1024**3:.2f}GB used, {total_bytes / 1024**3:.2f}GB total")
    else:
        logger.debug("  CuPy: Not installed")

    # Note: TensorFlow and JAX don't expose easy memory introspection
    logger.debug("  TensorFlow/JAX: Memory usage not easily queryable. Check if installed:")
    if tensorflow is None:
        logger.debug("  TensorFlow: Not installed")
    if jax is None:
        logger.debug("  JAX: Not installed")


def log_gpu_memory_usage(context: str = "") -> None:
    """
    Log GPU memory usage with a specific context for tracking.

    Args:
        context: Description of when/where this memory check is happening
    """
    context_str = f" ({context})" if context else ""

    if torch is not None:
        try:  # Keep try-except for runtime CUDA availability check
            if torch.cuda.is_available():
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    free_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3 - reserved
                    logger.debug(f"🔍 VRAM{context_str} GPU {i}: {allocated:.2f}GB alloc, {reserved:.2f}GB reserved, {free_memory:.2f}GB free")
            else:
                logger.debug(f"🔍 VRAM{context_str}: No CUDA available")
        except Exception as e:
            logger.warning(f"🔍 VRAM{context_str}: Error checking PyTorch memory - {e}")
    else:
        logger.debug(f"🔍 VRAM{context_str}: PyTorch not available")
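

# Illustrative sketch (not part of the original module): bracketing a step with
# VRAM snapshots to see how much memory the cleanup actually returned. The
# "run_step" argument is a hypothetical callable standing in for real pipeline work.
def _example_bracketed_step(run_step, memory_type: str = "torch", device_id: Optional[int] = 0) -> None:
    """Sketch: log VRAM before the step, after the step, and after cleanup."""
    log_gpu_memory_usage(context="before step")
    run_step()
    log_gpu_memory_usage(context="after step")
    cleanup_memory_by_type(memory_type, device_id)
    log_gpu_memory_usage(context="after cleanup")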


def get_gpu_memory_summary() -> dict:
    """
    Get GPU memory usage as a dictionary for programmatic use.

    Returns:
        Dictionary with memory usage information
    """
    memory_info = {
        "pytorch": {"available": False, "devices": []},
        "cupy": {"available": False, "used_gb": 0, "total_gb": 0}
    }

    # Check PyTorch
    if torch is not None:
        try:  # Keep try-except for runtime CUDA availability check
            if torch.cuda.is_available():
                memory_info["pytorch"]["available"] = True
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    total = torch.cuda.get_device_properties(i).total_memory / 1024**3
                    memory_info["pytorch"]["devices"].append({
                        "device_id": i,
                        "allocated_gb": allocated,
                        "reserved_gb": reserved,
                        "total_gb": total,
                        "free_gb": total - reserved
                    })
        except Exception:
            # CUDA queries can fail at runtime; return whatever was collected
            pass

    # Check CuPy
    if cupy is not None:
        memory_info["cupy"]["available"] = True
        mempool = cupy.get_default_memory_pool()
        memory_info["cupy"]["used_gb"] = mempool.used_bytes() / 1024**3
        memory_info["cupy"]["total_gb"] = mempool.total_bytes() / 1024**3

    return memory_info
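

# Illustrative sketch (not part of the original module): consuming the summary
# dict programmatically, e.g. to warn when a device is nearly full. The 0.9
# threshold is an arbitrary example value.
def _example_warn_on_memory_pressure(threshold: float = 0.9) -> None:
    """Sketch: warn if any PyTorch device has most of its memory reserved."""
    summary = get_gpu_memory_summary()
    for device in summary["pytorch"]["devices"]:
        if device["total_gb"] > 0 and device["reserved_gb"] / device["total_gb"] > threshold:
            logger.warning(
                f"GPU {device['device_id']} is under memory pressure: "
                f"{device['reserved_gb']:.2f}GB of {device['total_gb']:.2f}GB reserved"
            )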


def force_comprehensive_cleanup() -> None:
    """
    Force comprehensive GPU cleanup across all frameworks and trigger garbage collection.

    This is the nuclear option for clearing GPU memory when you suspect leaks.
    """
    logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Starting nuclear cleanup...")

    # Clean all GPU frameworks
    cleanup_all_gpu_frameworks()

    # Multiple rounds of garbage collection
    import gc
    for i in range(3):
        collected = gc.collect()
        logger.debug(f"🧹 Garbage collection round {i+1}: collected {collected} objects")

    # Check memory usage after cleanup
    check_gpu_memory_usage()

    logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Complete")