Coverage for src/arraybridge/gpu_cleanup.py: 64%
56 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
"""
GPU memory cleanup utilities for different frameworks.

This module provides unified GPU memory cleanup functions for PyTorch, CuPy,
TensorFlow, JAX, and pyclesperanto. The cleanup functions are designed to be
called after processing steps to free GPU memory that is no longer needed.

Instead of one hand-written function per framework, a single factory
(``_create_cleanup_function``) generates a ``cleanup_<import_name>_gpu``
function for every ``MemoryType`` from the per-framework settings in
``_FRAMEWORK_CONFIG``. CPU memory types get a no-op function. The generated
functions are collected in ``MEMORY_TYPE_CLEANUP_REGISTRY`` (keyed by
``MemoryType.value``), and ``cleanup_all_gpu_frameworks`` runs every one that
has cleanup operations configured. This factory replaced six nearly-identical
hand-written cleanup functions.
"""
11import gc
12import logging
13from typing import Optional
15from arraybridge.framework_config import _FRAMEWORK_CONFIG
16from arraybridge.types import MemoryType
# Module-level logger: routine cleanup messages go to DEBUG, failed cleanups to WARNING.
logger = logging.getLogger(name=__name__)
def _create_cleanup_function(mem_type: MemoryType):
    """
    Factory function that creates a cleanup function for a specific memory type.

    This single factory replaces 6 nearly-identical cleanup functions. Its
    behaviour is driven entirely by ``_FRAMEWORK_CONFIG[mem_type]``:

    * ``import_name`` - module name of the framework; also used to build the
      generated function's name, ``cleanup_<import_name>_gpu``.
    * ``display_name`` - human-readable name for log messages and docstrings.
    * ``cleanup_ops`` - statement template run with ``exec`` to free memory,
      or ``None`` for CPU memory types (a no-op function is returned).
    * ``gpu_check`` - expression template run with ``eval``; cleanup is
      skipped when it evaluates falsy or raises.
    * ``device_context`` - expression template whose result is used as a
      context manager to target one device, or ``None`` when not supported.

    ``{mod}`` (and, for ``device_context``, ``{device_id}``) placeholders are
    filled with ``str.format`` before evaluation. The templates come from the
    package's own configuration module, not from user input.

    Args:
        mem_type: Memory type whose configuration drives the generated function.

    Returns:
        A ``cleanup(device_id=None) -> None`` callable named
        ``cleanup_<import_name>_gpu``. Failures during cleanup are logged as
        warnings rather than raised.
    """
    import sys  # function-scope import keeps the file's import block untouched

    config = _FRAMEWORK_CONFIG[mem_type]
    framework_name = config["import_name"]
    display_name = config["display_name"]
    function_name = f"cleanup_{framework_name}_gpu"

    # CPU memory type - no cleanup needed
    if config["cleanup_ops"] is None:

        def cleanup(device_id: Optional[int] = None) -> None:
            logger.debug(f"🔥 GPU CLEANUP: No-op for {display_name} (CPU memory type)")

        cleanup.__name__ = function_name
        cleanup.__doc__ = f"No-op cleanup for {display_name} (CPU memory type)."
        return cleanup

    # GPU memory type - generate cleanup function
    def cleanup(device_id: Optional[int] = None) -> None:
        # This module never imports the frameworks itself, so a globals()
        # lookup alone always misses and the cleanup would silently do
        # nothing. Keep honouring a module bound in this namespace, but
        # otherwise use the framework only if the process has already
        # imported it. Nothing is imported here, so an unused framework is
        # never loaded just to clean it up.
        framework = globals().get(framework_name)
        if framework is None:
            framework = sys.modules.get(framework_name)

        if framework is None:
            logger.debug(f"{display_name} not available, skipping cleanup")
            return

        try:
            # Check GPU availability; any error while probing counts as "no GPU".
            gpu_check_expr = config["gpu_check"].format(mod=framework_name)
            try:
                gpu_available = eval(gpu_check_expr, {framework_name: framework})
            except Exception:
                gpu_available = False

            if not gpu_available:
                return

            cleanup_expr = config["cleanup_ops"].format(mod=framework_name)

            if device_id is not None and config["device_context"] is not None:
                # Clean a specific device inside the framework's device context.
                device_ctx_expr = config["device_context"].format(
                    device_id=device_id, mod=framework_name
                )
                device_ctx = eval(device_ctx_expr, {framework_name: framework})

                with device_ctx:
                    exec(cleanup_expr, {framework_name: framework, "gc": gc})
                    logger.debug(
                        f"🔥 GPU CLEANUP: Cleared {display_name} for device {device_id}"
                    )
            else:
                # No device requested (or no per-device context configured):
                # run the cleanup operations without a device context.
                exec(cleanup_expr, {framework_name: framework, "gc": gc})
                logger.debug(f"🔥 GPU CLEANUP: Cleared {display_name} for all devices")

        except Exception as e:
            # Best-effort: failures are logged, not propagated to the caller.
            logger.warning(f"Failed to cleanup {display_name} GPU memory: {e}")

    cleanup.__name__ = function_name
    # Built explicitly rather than via ``cleanup.__doc__.format(...)``: under
    # ``python -OO`` docstrings are stripped, ``__doc__`` is None, and the
    # ``.format`` call would make importing this module fail.
    cleanup.__doc__ = (
        f"Clean up {display_name} GPU memory.\n"
        "\n"
        "Args:\n"
        "    device_id: Optional GPU device ID. If None, cleans all devices.\n"
    )
    return cleanup
def _register_cleanup_functions() -> dict:
    """Create, publish, and index one cleanup function per ``MemoryType``.

    Each generated function is bound as a module-level name (its ``__name__``,
    i.e. ``cleanup_<import_name>_gpu``) so it can be imported directly. It is
    also recorded under its ``MemoryType.value``.

    Running this inside a function keeps the loop variables out of the module
    namespace. It also ties each registry entry to the exact function created
    for that memory type, instead of re-deriving the name and looking it up in
    ``globals()``.

    Returns:
        Mapping of ``MemoryType.value`` to its generated cleanup function.
    """
    registry = {}
    for member in MemoryType:
        func = _create_cleanup_function(member)
        globals()[func.__name__] = func
        registry[member.value] = func
    return registry


# Auto-generate all cleanup functions and the registry that indexes them.
MEMORY_TYPE_CLEANUP_REGISTRY = _register_cleanup_functions()
def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
    """
    Free GPU memory in every framework that has cleanup operations configured.

    Memory types whose configuration has no ``cleanup_ops`` (CPU types) are
    skipped. Each per-framework cleanup function itself handles a framework
    that is not available, so this is safe to call even when only some
    frameworks are installed.

    Args:
        device_id: Optional GPU device ID, forwarded unchanged to every
            per-framework cleanup function. If None, cleans all devices.
    """
    logger.debug(f"🔥 GPU CLEANUP: Starting cleanup for all GPU frameworks (device_id={device_id})")

    for member, framework_cfg in _FRAMEWORK_CONFIG.items():
        if framework_cfg["cleanup_ops"] is None:
            continue  # CPU memory type: nothing to free
        MEMORY_TYPE_CLEANUP_REGISTRY[member.value](device_id)

    logger.debug("🔥 GPU CLEANUP: Completed cleanup for all GPU frameworks")
# Public API: the hand-written helpers first...
__all__ = [
    "cleanup_all_gpu_frameworks",
    "MEMORY_TYPE_CLEANUP_REGISTRY",
]
# ...then the per-framework functions, which are created dynamically at import
# time. They are listed explicitly so static tools can still see them, which is
# why the undefined-name-in-__all__ warning (F822) is suppressed.
__all__ += [
    "cleanup_numpy_gpu",  # noqa: F822
    "cleanup_cupy_gpu",  # noqa: F822
    "cleanup_torch_gpu",  # noqa: F822
    "cleanup_tensorflow_gpu",  # noqa: F822
    "cleanup_jax_gpu",  # noqa: F822
    "cleanup_pyclesperanto_gpu",  # noqa: F822
]