Coverage for openhcs/core/memory/gpu_cleanup.py: 25.2%
216 statements
coverage.py v7.10.7, created at 2025-10-01 18:33 +0000
1"""
2GPU memory cleanup utilities for different frameworks.
4This module provides unified GPU memory cleanup functions for PyTorch, CuPy,
5TensorFlow, and JAX. The cleanup functions are designed to be called after
6processing steps to free up GPU memory that's no longer needed.
7"""

import logging
import os
from typing import Optional

from openhcs.core.utils import optional_import
from openhcs.constants.constants import VALID_GPU_MEMORY_TYPES  # Assumed always available for core utilities

logger = logging.getLogger(__name__)

# Check if we're in subprocess runner mode and should skip GPU imports
if os.getenv('OPENHCS_SUBPROCESS_NO_GPU') == '1':
    # Subprocess runner mode - skip GPU imports
    torch = None
    cupy = None
    tensorflow = None
    jax = None
    pyclesperanto = None
    logger.info("Subprocess runner mode - skipping GPU library imports in gpu_cleanup")
else:
    # Normal mode - import GPU frameworks as optional dependencies
    torch = optional_import("torch")
    cupy = optional_import("cupy")
    tensorflow = optional_import("tensorflow")
    jax = optional_import("jax")
    pyclesperanto = optional_import("pyclesperanto")
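
# Illustrative only: a parent process can opt a worker out of GPU imports by
# setting the flag in the child's environment before launch, e.g. (hypothetical
# snippet, not an OpenHCS API; "worker_module" is a placeholder name):
#
#     env = dict(os.environ, OPENHCS_SUBPROCESS_NO_GPU="1")
#     subprocess.run([sys.executable, "-m", "worker_module"], env=env)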

# --- Cleanup functions ---


def is_gpu_memory_type(memory_type: str) -> bool:
    """
    Check if a memory type is a GPU memory type.

    Args:
        memory_type: Memory type string

    Returns:
        True if it's a GPU memory type, False otherwise
    """
    # VALID_GPU_MEMORY_TYPES is imported at module level. If the constants
    # module ever becomes optional, this check would need a guarded import;
    # for core utilities it is assumed to always be available.
    return memory_type in VALID_GPU_MEMORY_TYPES
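

# Illustrative only: a minimal sketch of how a caller might gate cleanup on a
# step's memory type. `step_memory_type` and `gpu_id` are hypothetical names,
# not part of this module's API.
def _example_cleanup_after_step(step_memory_type: str, gpu_id: int) -> None:
    if is_gpu_memory_type(step_memory_type):
        cleanup_gpu_memory_by_framework(step_memory_type, device_id=gpu_id)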


def cleanup_pytorch_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up PyTorch GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if torch is None:
        logger.debug("PyTorch not available, skipping PyTorch GPU cleanup")
        return

    try:
        if not torch.cuda.is_available():
            return

        if device_id is not None:
            # Clean specific device
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            logger.debug(f"🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for device {device_id}")
        else:
            # Clean all devices
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            logger.debug("🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for all devices")

    except Exception as e:
        logger.warning(f"Failed to cleanup PyTorch GPU memory: {e}")


def cleanup_cupy_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up CuPy GPU memory with aggressive defragmentation.

    Args:
        device_id: Optional GPU device ID. If None, cleans current device.
    """
    if cupy is None:
        logger.debug("CuPy not available, skipping CuPy GPU cleanup")
        return

    try:
        if device_id is not None:
            # Clean specific device
            with cupy.cuda.Device(device_id):
                # Get memory info before cleanup
                mempool = cupy.get_default_memory_pool()
                used_before = mempool.used_bytes()

                # Aggressive cleanup to defragment memory
                cupy.get_default_memory_pool().free_all_blocks()
                cupy.get_default_pinned_memory_pool().free_all_blocks()

                # Force memory pool reset to defragment
                cupy.cuda.runtime.deviceSynchronize()

                used_after = mempool.used_bytes()
                freed_mb = (used_before - used_after) / 1e6

                logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for device {device_id}, freed {freed_mb:.1f}MB")
        else:
            # Clean current device
            mempool = cupy.get_default_memory_pool()
            used_before = mempool.used_bytes()

            # Aggressive cleanup to defragment memory
            cupy.get_default_memory_pool().free_all_blocks()
            cupy.get_default_pinned_memory_pool().free_all_blocks()

            # Force memory pool reset to defragment
            cupy.cuda.runtime.deviceSynchronize()

            used_after = mempool.used_bytes()
            freed_mb = (used_before - used_after) / 1e6

            logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for current device, freed {freed_mb:.1f}MB")

    except Exception as e:
        logger.warning(f"Failed to cleanup CuPy GPU memory: {e}")


def cleanup_tensorflow_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up TensorFlow GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if tensorflow is None:
        logger.debug("TensorFlow not available, skipping TensorFlow GPU cleanup")
        return

    try:
        # Get list of GPU devices
        gpus = tensorflow.config.list_physical_devices('GPU')
        if not gpus:
            return

        if device_id is not None and device_id < len(gpus):
            # Clean specific device - TensorFlow doesn't have per-device cleanup,
            # so we trigger garbage collection which helps with memory management
            import gc
            gc.collect()
            logger.debug(f"🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPU {device_id}")
        else:
            # Clean all devices - trigger garbage collection
            import gc
            gc.collect()
            logger.debug("🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPUs")

    except Exception as e:
        logger.warning(f"Failed to cleanup TensorFlow GPU memory: {e}")


def cleanup_jax_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up JAX GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if jax is None:
        logger.debug("JAX not available, skipping JAX GPU cleanup")
        return

    try:
        # JAX doesn't have explicit memory cleanup like PyTorch/CuPy,
        # but we can trigger garbage collection and clear the compilation cache
        import gc
        gc.collect()

        # Clear JAX compilation cache which can hold GPU memory
        jax.clear_caches()

        if device_id is not None:
            logger.debug(f"🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for device {device_id}")
        else:
            logger.debug("🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for all devices")

    except Exception as e:
        logger.warning(f"Failed to cleanup JAX GPU memory: {e}")


def cleanup_pyclesperanto_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up pyclesperanto GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans current device.
    """
    if pyclesperanto is None:
        logger.debug("pyclesperanto not available, skipping pyclesperanto GPU cleanup")
        return

    try:
        import gc

        # pyclesperanto doesn't have explicit memory cleanup like PyTorch/CuPy,
        # but we can trigger garbage collection and clear any cached data

        if device_id is not None:
            # Select the specific device
            devices = pyclesperanto.list_available_devices()
            if device_id < len(devices):
                pyclesperanto.select_device(device_id)
                logger.debug(f"🔥 GPU CLEANUP: Selected pyclesperanto device {device_id}")
            else:
                logger.warning(f"🔥 GPU CLEANUP: Device {device_id} not available in pyclesperanto")

        # Trigger garbage collection to clean up any unreferenced GPU arrays
        collected = gc.collect()

        # pyclesperanto uses OpenCL which manages memory automatically,
        # but we can help by ensuring Python objects are cleaned up
        if device_id is not None:
            logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto device {device_id}, collected {collected} objects")
        else:
            logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto current device, collected {collected} objects")

    except Exception as e:
        logger.warning(f"Failed to cleanup pyclesperanto GPU memory: {e}")


def cleanup_gpu_memory_by_framework(memory_type: str, device_id: Optional[int] = None) -> None:
    """
    Clean up GPU memory based on the OpenHCS memory type.

    Args:
        memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow", "jax", "pyclesperanto", "numpy")
        device_id: Optional GPU device ID
    """
    # Handle exact OpenHCS memory type values
    if memory_type == "torch":
        cleanup_pytorch_gpu(device_id)
    elif memory_type == "cupy":
        cleanup_cupy_gpu(device_id)
    elif memory_type == "tensorflow":
        cleanup_tensorflow_gpu(device_id)
    elif memory_type == "jax":
        cleanup_jax_gpu(device_id)
    elif memory_type == "pyclesperanto":
        cleanup_pyclesperanto_gpu(device_id)
    elif memory_type == "numpy":
        # CPU memory type - no GPU cleanup needed
        logger.debug(f"No GPU cleanup needed for CPU memory type: {memory_type}")
    else:
        # Fallback for unknown types - try pattern matching
        memory_type_lower = memory_type.lower()
        if "torch" in memory_type_lower or "pytorch" in memory_type_lower:
            cleanup_pytorch_gpu(device_id)
        elif "cupy" in memory_type_lower:
            cleanup_cupy_gpu(device_id)
        elif "tensorflow" in memory_type_lower or "tf" in memory_type_lower:
            cleanup_tensorflow_gpu(device_id)
        elif "jax" in memory_type_lower:
            cleanup_jax_gpu(device_id)
        elif "pyclesperanto" in memory_type_lower or "clesperanto" in memory_type_lower:
            cleanup_pyclesperanto_gpu(device_id)
        else:
            logger.debug(f"Unknown memory type for GPU cleanup: {memory_type}")


def cleanup_numpy_noop(device_id: Optional[int] = None) -> None:
    """
    No-op cleanup for numpy (CPU memory type).

    Args:
        device_id: Optional GPU device ID (ignored for CPU)
    """
    logger.debug("🔥 GPU CLEANUP: No-op for numpy (CPU memory type)")


def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
    """
    Clean up GPU memory for all available frameworks.

    Args:
        device_id: Optional GPU device ID
    """
    cleanup_pytorch_gpu(device_id)
    cleanup_cupy_gpu(device_id)
    cleanup_tensorflow_gpu(device_id)
    cleanup_jax_gpu(device_id)
    cleanup_pyclesperanto_gpu(device_id)

    # Also trigger Python garbage collection
    import gc
    gc.collect()

    logger.debug("🔥 GPU CLEANUP: Performed comprehensive cleanup for all GPU frameworks")


# Registry mapping memory types to their cleanup functions
MEMORY_TYPE_CLEANUP_REGISTRY = {
    "torch": cleanup_pytorch_gpu,
    "cupy": cleanup_cupy_gpu,
    "tensorflow": cleanup_tensorflow_gpu,
    "jax": cleanup_jax_gpu,
    "pyclesperanto": cleanup_pyclesperanto_gpu,
    "numpy": cleanup_numpy_noop,
}


def cleanup_memory_by_type(memory_type: str, device_id: Optional[int] = None) -> None:
    """
    Clean up memory using the registered cleanup function for the memory type.

    Args:
        memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow", "jax", "pyclesperanto", "numpy")
        device_id: Optional GPU device ID
    """
    cleanup_func = MEMORY_TYPE_CLEANUP_REGISTRY.get(memory_type)

    if cleanup_func:
        cleanup_func(device_id)
    else:
        logger.warning(f"No cleanup function registered for memory type: {memory_type}")
        logger.debug(f"Available memory types: {list(MEMORY_TYPE_CLEANUP_REGISTRY.keys())}")
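

# Illustrative only: a minimal sketch of how the registry allows new backends to
# be supported without modifying cleanup_memory_by_type. The "my_backend" name
# and its cleanup function are hypothetical, not part of OpenHCS.
def _example_register_custom_cleanup() -> None:
    def _cleanup_my_backend(device_id: Optional[int] = None) -> None:
        logger.debug(f"Hypothetical cleanup for my_backend device {device_id}")

    MEMORY_TYPE_CLEANUP_REGISTRY["my_backend"] = _cleanup_my_backend
    cleanup_memory_by_type("my_backend", device_id=0)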


def check_gpu_memory_usage() -> None:
    """
    Check and log current GPU memory usage for all available frameworks.

    This is a utility function for debugging memory issues.
    """
    logger.debug("🔍 GPU Memory Usage Report:")

    # Check PyTorch
    if torch is not None:
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                reserved = torch.cuda.memory_reserved(i) / 1024**3
                logger.debug(f"  PyTorch GPU {i}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
        else:
            logger.debug("  PyTorch: No CUDA available")
    else:
        logger.debug("  PyTorch: Not installed")

    # Check CuPy
    if cupy is not None:
        mempool = cupy.get_default_memory_pool()
        used_bytes = mempool.used_bytes()
        total_bytes = mempool.total_bytes()
        logger.debug(f"  CuPy: {used_bytes / 1024**3:.2f}GB used, {total_bytes / 1024**3:.2f}GB total")
    else:
        logger.debug("  CuPy: Not installed")

    # Note: TensorFlow and JAX don't expose memory usage as easily; just report availability
    logger.debug("  TensorFlow/JAX: Memory usage not easily queryable")
    if tensorflow is None:
        logger.debug("  TensorFlow: Not installed")
    if jax is None:
        logger.debug("  JAX: Not installed")


def log_gpu_memory_usage(context: str = "") -> None:
    """
    Log GPU memory usage with a specific context for tracking.

    Args:
        context: Description of when/where this memory check is happening
    """
    context_str = f" ({context})" if context else ""

    if torch is not None:
        try:  # CUDA state can change at runtime, so guard the queries
            if torch.cuda.is_available():
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    free_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3 - reserved
                    logger.debug(f"🔍 VRAM{context_str} GPU {i}: {allocated:.2f}GB alloc, {reserved:.2f}GB reserved, {free_memory:.2f}GB free")
            else:
                logger.debug(f"🔍 VRAM{context_str}: No CUDA available")
        except Exception as e:
            logger.warning(f"🔍 VRAM{context_str}: Error checking PyTorch memory - {e}")
    else:
        logger.debug(f"🔍 VRAM{context_str}: PyTorch not available")


def get_gpu_memory_summary() -> dict:
    """
    Get GPU memory usage as a dictionary for programmatic use.

    Returns:
        Dictionary with memory usage information
    """
    memory_info = {
        "pytorch": {"available": False, "devices": []},
        "cupy": {"available": False, "used_gb": 0, "total_gb": 0}
    }

    # Check PyTorch
    if torch is not None:
        try:  # CUDA queries can fail at runtime even when torch is importable
            if torch.cuda.is_available():
                memory_info["pytorch"]["available"] = True
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    total = torch.cuda.get_device_properties(i).total_memory / 1024**3
                    memory_info["pytorch"]["devices"].append({
                        "device_id": i,
                        "allocated_gb": allocated,
                        "reserved_gb": reserved,
                        "total_gb": total,
                        "free_gb": total - reserved
                    })
        except Exception:
            pass  # Availability is reported via the "available" flag; error details aren't needed here

    # Check CuPy
    if cupy is not None:
        memory_info["cupy"]["available"] = True
        mempool = cupy.get_default_memory_pool()
        memory_info["cupy"]["used_gb"] = mempool.used_bytes() / 1024**3
        memory_info["cupy"]["total_gb"] = mempool.total_bytes() / 1024**3

    return memory_info
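

# Illustrative only: a minimal sketch of consuming the summary to decide when to
# force a cleanup. The 0.5 GB threshold is an arbitrary value for the example.
def _example_cleanup_when_low_on_vram() -> None:
    summary = get_gpu_memory_summary()
    for device in summary["pytorch"]["devices"]:
        if device["free_gb"] < 0.5:
            cleanup_pytorch_gpu(device["device_id"])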


def force_comprehensive_cleanup() -> None:
    """
    Force comprehensive GPU cleanup across all frameworks and trigger garbage collection.

    This is the nuclear option for clearing GPU memory when you suspect leaks.
    """
    logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Starting nuclear cleanup...")

    # Clean all GPU frameworks
    cleanup_all_gpu_frameworks()

    # Multiple rounds of garbage collection
    import gc
    for i in range(3):
        collected = gc.collect()
        logger.debug(f"🧹 Garbage collection round {i+1}: collected {collected} objects")

    # Check memory usage after cleanup
    check_gpu_memory_usage()

    logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Complete")
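

# Illustrative only: a minimal manual smoke check, assuming this module is run
# directly on a machine where DEBUG logging is acceptable.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    log_gpu_memory_usage("before cleanup")
    force_comprehensive_cleanup()
    log_gpu_memory_usage("after cleanup")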