Coverage for src/arraybridge/gpu_cleanup.py: 64%

56 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1""" 

2GPU memory cleanup utilities for different frameworks. 

3 

4This module provides unified GPU memory cleanup functions for PyTorch, CuPy, 

5TensorFlow, JAX, and pyclesperanto. The cleanup functions are designed to be called 

6after processing steps to free up GPU memory that's no longer needed. 

7 

8REFACTORED: Uses enum-driven metaprogramming to eliminate 67% of code duplication. 

9""" 

10 

11import gc 

12import logging 

13from typing import Optional 

14 

15from arraybridge.framework_config import _FRAMEWORK_CONFIG 

16from arraybridge.types import MemoryType 

17 

18logger = logging.getLogger(__name__) 

19 

20 

def _create_cleanup_function(mem_type: MemoryType):
    """
    Build the GPU-memory cleanup callable for one memory type.

    The generated function is driven entirely by ``_FRAMEWORK_CONFIG[mem_type]``:

    * ``import_name`` / ``display_name``: module name of the framework and the
      label used in log messages.
    * ``cleanup_ops``: statement template (``{mod}`` placeholder) executed with
      ``exec``; ``None`` marks a memory type that needs no cleanup and yields
      a no-op function.
    * ``gpu_check``: expression template (``{mod}``) evaluated to decide
      whether there is a GPU to clean.
    * ``device_context``: optional expression template (``{mod}``,
      ``{device_id}``) that evaluates to a context manager selecting one device.

    Args:
        mem_type: Key into ``_FRAMEWORK_CONFIG``.

    Returns:
        A function ``cleanup(device_id=None) -> None`` whose ``__name__`` is
        ``cleanup_<import_name>_gpu``. The GPU variant never raises for
        framework errors; it logs a warning instead.
    """
    config = _FRAMEWORK_CONFIG[mem_type]
    framework_name = config["import_name"]
    display_name = config["display_name"]

    # CPU memory type - no cleanup needed
    if config["cleanup_ops"] is None:

        def cleanup(device_id: Optional[int] = None) -> None:
            """No-op cleanup for CPU memory type."""
            logger.debug(f"🔥 GPU CLEANUP: No-op for {display_name} (CPU memory type)")

        cleanup.__name__ = f"cleanup_{framework_name}_gpu"
        cleanup.__doc__ = f"No-op cleanup for {display_name} (CPU memory type)."
        return cleanup

    # GPU memory type - generate cleanup function
    def cleanup(device_id: Optional[int] = None) -> None:
        # The docstring is assigned below so it can embed display_name.
        import sys  # function-scope import: the module header does not import sys

        # Prefer a module object bound in this module's globals (keeps any
        # externally injected binding working). Nothing in this file binds the
        # framework modules itself, so fall back to frameworks that some other
        # code has already imported. sys.modules is only read: a framework
        # that was never loaded has no GPU memory to release and must not be
        # imported just for cleanup.
        framework = globals().get(framework_name)
        if framework is None:
            framework = sys.modules.get(framework_name)

        if framework is None:
            logger.debug(f"{display_name} not available, skipping cleanup")
            return

        try:
            # NOTE: eval/exec below run templates taken from _FRAMEWORK_CONFIG.
            # That configuration must stay trusted, in-repo data.
            eval_namespace = {framework_name: framework}

            # Check GPU availability; any failure is treated as "no GPU".
            gpu_check_expr = config["gpu_check"].format(mod=framework_name)
            try:
                gpu_available = eval(gpu_check_expr, eval_namespace)
            except Exception as check_error:
                logger.debug(f"{display_name} GPU availability check failed: {check_error}")
                gpu_available = False

            if not gpu_available:
                return

            # The cleanup statement is the same for both branches below.
            cleanup_expr = config["cleanup_ops"].format(mod=framework_name)
            exec_namespace = {framework_name: framework, "gc": gc}

            if device_id is not None and config["device_context"] is not None:
                # Clean a specific device inside the framework's device context
                device_ctx_expr = config["device_context"].format(
                    device_id=device_id, mod=framework_name
                )
                device_ctx = eval(device_ctx_expr, eval_namespace)

                with device_ctx:
                    exec(cleanup_expr, exec_namespace)

                logger.debug(f"🔥 GPU CLEANUP: Cleared {display_name} for device {device_id}")
            else:
                # Clean all devices (no device context)
                exec(cleanup_expr, exec_namespace)
                logger.debug(f"🔥 GPU CLEANUP: Cleared {display_name} for all devices")

        except Exception as e:
            # Cleanup is best-effort: report the problem but never propagate it.
            logger.warning(f"Failed to cleanup {display_name} GPU memory: {e}")

    # Set proper function name and docstring. The docstring is assigned
    # explicitly (not via __doc__.format) because __doc__ is None under -OO.
    cleanup.__name__ = f"cleanup_{framework_name}_gpu"
    cleanup.__doc__ = (
        f"Clean up {display_name} GPU memory.\n\n"
        "Args:\n"
        "    device_id: Optional GPU device ID. If None, cleans all devices.\n"
    )

    return cleanup

95 

96 

# Generate one cleanup function per MemoryType member and publish each in the
# module namespace under its own __name__ (cleanup_<import_name>_gpu).
for mem_type in MemoryType:
    cleanup_func = _create_cleanup_function(mem_type)
    globals().update({cleanup_func.__name__: cleanup_func})

101 

102 

# Map each MemoryType value to the generated cleanup function that the loop
# above stored in the module namespace under cleanup_<import_name>_gpu.
MEMORY_TYPE_CLEANUP_REGISTRY = {
    member.value: globals()["cleanup_" + _FRAMEWORK_CONFIG[member]["import_name"] + "_gpu"]
    for member in MemoryType
}

108 

109 

def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
    """
    Run the registered cleanup function of every GPU-backed memory type.

    Walks ``_FRAMEWORK_CONFIG`` and, for each entry that defines
    ``cleanup_ops``, invokes the matching function from
    ``MEMORY_TYPE_CLEANUP_REGISTRY``. Entries without cleanup operations are
    skipped.

    Args:
        device_id: Optional GPU device ID forwarded unchanged to every cleanup
            function. If None, cleans all devices.
    """
    logger.debug(f"🔥 GPU CLEANUP: Starting cleanup for all GPU frameworks (device_id={device_id})")

    for memory_kind, settings in _FRAMEWORK_CONFIG.items():
        # Memory types without cleanup operations have nothing to release.
        if settings["cleanup_ops"] is None:
            continue
        MEMORY_TYPE_CLEANUP_REGISTRY[memory_kind.value](device_id)

    logger.debug("🔥 GPU CLEANUP: Completed cleanup for all GPU frameworks")

129 

130 

# Export all cleanup functions and utilities.
#
# The cleanup_*_gpu entries have no `def` statement in this module: the
# module-level loop over MemoryType stores each generated function in
# globals() under its __name__. Static linters cannot see those names, hence
# the per-line `noqa: F822` markers.
# NOTE(review): this list presumes the `import_name` values in
# _FRAMEWORK_CONFIG are numpy, cupy, torch, tensorflow, jax and
# pyclesperanto - confirm against framework_config.
__all__ = [
    "cleanup_all_gpu_frameworks",
    "MEMORY_TYPE_CLEANUP_REGISTRY",
    "cleanup_numpy_gpu",  # noqa: F822
    "cleanup_cupy_gpu",  # noqa: F822
    "cleanup_torch_gpu",  # noqa: F822
    "cleanup_tensorflow_gpu",  # noqa: F822
    "cleanup_jax_gpu",  # noqa: F822
    "cleanup_pyclesperanto_gpu",  # noqa: F822
]