Coverage for openhcs/core/memory/gpu_cleanup.py: 42.3%
57 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
GPU memory cleanup utilities for different frameworks.

This module provides unified GPU memory cleanup functions for PyTorch, CuPy,
TensorFlow, JAX, and pyclesperanto. The cleanup functions are designed to be called
after processing steps to free GPU memory that is no longer needed.

REFACTORED: A single enum-driven factory (keyed by MemoryType) generates the
per-framework cleanup functions, eliminating about 67% of the previous code duplication.
9"""
11import gc
12import logging
13from typing import Optional
14from openhcs.core.utils import optional_import
15from openhcs.constants.constants import MemoryType
16from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
def _create_cleanup_function(mem_type: MemoryType):
    """
    Factory function that creates a cleanup function for a specific memory type.

    This single factory replaces 6 nearly-identical cleanup functions. The
    generated function is driven by the ``_FRAMEWORK_CONFIG[mem_type]`` entry,
    whose keys are used as follows:

    - ``import_name``: name under which the framework module is looked up in
      this module's globals at call time; also forms the generated function's
      name, ``cleanup_<import_name>_gpu``.
    - ``display_name``: framework name used in log messages and docstrings.
    - ``cleanup_ops``: statement template (``{mod}`` placeholder) run with
      ``exec`` to free memory, or ``None`` for CPU memory types, which get a
      no-op cleanup function.
    - ``gpu_check``: expression template (``{mod}`` placeholder) evaluated with
      ``eval``; cleanup is skipped unless it evaluates truthy.
    - ``device_context``: expression template (``{mod}`` and ``{device_id}``
      placeholders) evaluated to a context manager that is entered around the
      cleanup when a ``device_id`` is given; ``None`` means no device context.

    The templates are read from the project's ``framework_config`` module, not
    from user input, which is why ``eval``/``exec`` on them is tolerated here.

    Args:
        mem_type: Memory type whose framework configuration drives the cleanup.

    Returns:
        A function ``cleanup(device_id: Optional[int] = None) -> None`` whose
        ``__name__`` and ``__qualname__`` are ``cleanup_<import_name>_gpu``.

    Raises:
        KeyError: If ``mem_type`` has no entry in ``_FRAMEWORK_CONFIG``.
    """
    config = _FRAMEWORK_CONFIG[mem_type]
    framework_name = config['import_name']
    display_name = config['display_name']
    function_name = f"cleanup_{framework_name}_gpu"

    # CPU memory type - no cleanup needed
    if config['cleanup_ops'] is None:
        def cleanup(device_id: Optional[int] = None) -> None:
            """No-op cleanup for CPU memory type."""
            logger.debug(f"🔥 GPU CLEANUP: No-op for {display_name} (CPU memory type)")

        cleanup.__name__ = function_name
        # Match the module-level name so pickle (which resolves functions by
        # module + __qualname__) does not see an unreachable "<locals>" path.
        cleanup.__qualname__ = function_name
        cleanup.__doc__ = f"No-op cleanup for {display_name} (CPU memory type)."
        return cleanup

    # GPU memory type - generate cleanup function
    def cleanup(device_id: Optional[int] = None) -> None:
        """
        Clean up {display_name} GPU memory.

        Args:
            device_id: Optional GPU device ID. If None, cleans all devices.
                If the framework has no device context configured, the
                device ID is ignored and cleanup runs without one.
        """
        # Looked up at call time in this module's globals.
        # NOTE(review): no binding of the framework module under `import_name`
        # is visible in this section of the file (optional_import is imported
        # but its use is not shown) -- confirm it happens elsewhere, otherwise
        # every GPU cleanup silently degrades to the "not available" path.
        framework = globals().get(framework_name)

        if framework is None:
            logger.debug(f"{display_name} not available, skipping cleanup")
            return

        try:
            # Check GPU availability. Any ordinary failure of the check counts
            # as "no usable GPU"; KeyboardInterrupt/SystemExit still propagate.
            gpu_check_expr = config['gpu_check'].format(mod=framework_name)
            try:
                gpu_available = eval(gpu_check_expr, {framework_name: framework})
            except Exception as check_error:
                logger.debug(f"{display_name} GPU availability check failed: {check_error}")
                gpu_available = False

            if not gpu_available:
                return

            # Same cleanup statements are used with or without a device context.
            cleanup_expr = config['cleanup_ops'].format(mod=framework_name)
            exec_globals = {framework_name: framework, 'gc': gc}

            if device_id is not None and config['device_context'] is not None:
                # Clean specific device with context
                device_ctx_expr = config['device_context'].format(device_id=device_id, mod=framework_name)
                device_ctx = eval(device_ctx_expr, {framework_name: framework})

                with device_ctx:
                    exec(cleanup_expr, exec_globals)

                logger.debug(f"🔥 GPU CLEANUP: Cleared {display_name} for device {device_id}")
            else:
                # Clean all devices (no device context)
                exec(cleanup_expr, exec_globals)
                logger.debug(f"🔥 GPU CLEANUP: Cleared {display_name} for all devices")

        except Exception as e:
            # Best-effort: a failed cleanup must not abort the calling pipeline.
            logger.warning(f"Failed to cleanup {display_name} GPU memory: {e}")

    # Set proper function name and docstring
    cleanup.__name__ = function_name
    # Match the module-level name so pickle (which resolves functions by
    # module + __qualname__) does not see an unreachable "<locals>" path.
    cleanup.__qualname__ = function_name
    # Docstrings are stripped under ``python -OO`` (``__doc__`` is None), so
    # only fill in the placeholder when there is a docstring to format.
    if cleanup.__doc__ is not None:
        cleanup.__doc__ = cleanup.__doc__.format(display_name=display_name)

    return cleanup
def _register_cleanup_functions() -> dict:
    """
    Generate one cleanup function per ``MemoryType`` and publish it.

    Each generated function is bound as a module attribute under its
    ``__name__`` (``cleanup_<import_name>_gpu``) so it can be imported
    directly, and is collected into a registry keyed by ``MemoryType.value``.

    The registry is built from the generated functions themselves, so the
    naming scheme lives only in ``_create_cleanup_function`` and cannot drift
    out of sync with a second copy here. Running the loop inside a function
    also keeps loop temporaries out of the module namespace.

    Returns:
        Mapping of ``MemoryType.value`` to that memory type's cleanup function.
    """
    registry = {}
    module_namespace = globals()
    for member in MemoryType:
        cleanup_func = _create_cleanup_function(member)
        module_namespace[cleanup_func.__name__] = cleanup_func
        registry[member.value] = cleanup_func
    return registry


# Auto-generate all cleanup functions and the cleanup registry
MEMORY_TYPE_CLEANUP_REGISTRY = _register_cleanup_functions()
def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
    """
    Free GPU memory in every framework that defines GPU cleanup operations.

    Walks ``_FRAMEWORK_CONFIG`` in order. For each memory type whose
    ``cleanup_ops`` is set, it calls the matching function from
    ``MEMORY_TYPE_CLEANUP_REGISTRY``. Memory types without cleanup
    operations (CPU types) are skipped.

    Safe to call when some frameworks are not installed, because each
    per-framework cleanup function skips a framework it cannot find.

    Args:
        device_id: GPU device to clean. ``None`` means all devices.
    """
    logger.debug(f"🔥 GPU CLEANUP: Starting cleanup for all GPU frameworks (device_id={device_id})")

    # Lazy generator: each cleanup runs as soon as its config entry is reached.
    gpu_memory_types = (
        memory_type
        for memory_type, framework_config in _FRAMEWORK_CONFIG.items()
        if framework_config['cleanup_ops'] is not None
    )
    for memory_type in gpu_memory_types:
        MEMORY_TYPE_CLEANUP_REGISTRY[memory_type.value](device_id)

    logger.debug("🔥 GPU CLEANUP: Completed cleanup for all GPU frameworks")
# Public API: the aggregate cleanup entry point, the registry, and the
# per-framework cleanup functions that are generated at import time.
__all__ = [
    'cleanup_all_gpu_frameworks',
    'MEMORY_TYPE_CLEANUP_REGISTRY',
    *(
        f'cleanup_{framework}_gpu'
        for framework in ('numpy', 'cupy', 'torch', 'tensorflow', 'jax', 'pyclesperanto')
    ),
]