Coverage for openhcs/core/memory/gpu_cleanup.py: 42.3%

57 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2GPU memory cleanup utilities for different frameworks. 

3 

4This module provides unified GPU memory cleanup functions for PyTorch, CuPy, 

5TensorFlow, JAX, and pyclesperanto. The cleanup functions are designed to be called 

6after processing steps to free up GPU memory that's no longer needed. 

7 

8REFACTORED: Uses enum-driven metaprogramming to eliminate 67% of code duplication. 

9""" 

10 

11import gc 

12import logging 

13from typing import Optional 

14from openhcs.core.utils import optional_import 

15from openhcs.constants.constants import MemoryType 

16from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG 

17 

18logger = logging.getLogger(__name__) 

19 

20 

21 

22 

23 

24 

25 

def _create_cleanup_function(mem_type: MemoryType):
    """
    Build the GPU cleanup function for one memory type.

    The returned function is driven by the ``_FRAMEWORK_CONFIG`` entry for
    ``mem_type``. When the entry's ``cleanup_ops`` is ``None`` the function
    only emits a debug log line. Otherwise it evaluates the entry's
    ``gpu_check`` expression, optionally enters the ``device_context``
    expression for a specific device, and executes the ``cleanup_ops``
    statements.

    Args:
        mem_type: Memory type whose ``_FRAMEWORK_CONFIG`` entry describes the
            framework to clean up.

    Returns:
        A callable ``cleanup(device_id=None) -> None`` whose ``__name__`` is
        ``cleanup_<import_name>_gpu``.

    Raises:
        KeyError: If ``mem_type`` has no entry in ``_FRAMEWORK_CONFIG`` or the
            entry lacks ``import_name``, ``display_name`` or ``cleanup_ops``.
    """
    # Local import keeps this block self-contained; sys is only needed to
    # look up frameworks that the process has already imported.
    import sys

    config = _FRAMEWORK_CONFIG[mem_type]
    framework_name = config['import_name']
    display_name = config['display_name']
    function_name = f"cleanup_{framework_name}_gpu"

    # CPU memory type - no cleanup needed
    if config['cleanup_ops'] is None:
        def cleanup(device_id: Optional[int] = None) -> None:
            logger.debug(f"🔥 GPU CLEANUP: No-op for {display_name} (CPU memory type)")

        cleanup.__name__ = function_name
        cleanup.__doc__ = f"No-op cleanup for {display_name} (CPU memory type)."
        return cleanup

    # GPU memory type - generate cleanup function
    def cleanup(device_id: Optional[int] = None) -> None:
        # A module-level binding wins (keeps explicit injection working).
        # Otherwise fall back to sys.modules: this module itself never imports
        # the frameworks, and a framework that was never imported by the
        # process has nothing to clean up.
        framework = globals().get(framework_name)
        if framework is None:
            framework = sys.modules.get(framework_name)

        if framework is None:
            logger.debug(f"{display_name} not available, skipping cleanup")
            return

        def run_cleanup_ops() -> None:
            # NOTE(review): the expressions come from _FRAMEWORK_CONFIG and
            # are executed verbatim; they must never be built from untrusted
            # input.
            cleanup_expr = config['cleanup_ops'].format(mod=framework_name)
            exec(cleanup_expr, {framework_name: framework, 'gc': gc})

        try:
            # Check GPU availability; any failure of the probe itself is
            # treated as "no GPU". Only Exception is caught so that
            # KeyboardInterrupt / SystemExit still propagate.
            gpu_check_expr = config['gpu_check'].format(mod=framework_name)
            try:
                gpu_available = eval(gpu_check_expr, {framework_name: framework})
            except Exception:
                gpu_available = False

            if not gpu_available:
                return

            if device_id is not None and config['device_context'] is not None:
                # Clean specific device inside the framework's device context
                device_ctx_expr = config['device_context'].format(
                    device_id=device_id, mod=framework_name
                )
                device_ctx = eval(device_ctx_expr, {framework_name: framework})

                with device_ctx:
                    run_cleanup_ops()

                logger.debug(f"🔥 GPU CLEANUP: Cleared {display_name} for device {device_id}")
            else:
                # Clean all devices (no device context)
                run_cleanup_ops()
                logger.debug(f"🔥 GPU CLEANUP: Cleared {display_name} for all devices")

        except Exception as e:
            # Cleanup is best-effort: report the failure but never raise.
            logger.warning(f"Failed to cleanup {display_name} GPU memory: {e}")

    # Set proper function name and docstring. The docstring is assigned
    # explicitly instead of formatting a literal one, because literal
    # docstrings are stripped (``__doc__ is None``) under ``python -OO``.
    cleanup.__name__ = function_name
    cleanup.__doc__ = (
        f"\n"
        f"        Clean up {display_name} GPU memory.\n"
        f"\n"
        f"        Args:\n"
        f"            device_id: Optional GPU device ID. If None, cleans all devices.\n"
        f"        "
    )

    return cleanup

97 

98 

# Auto-generate all cleanup functions: one per MemoryType member, published
# at module level under the name the factory assigned (cleanup_<name>_gpu).
for mem_type in MemoryType:
    cleanup_func = _create_cleanup_function(mem_type)
    globals().update({cleanup_func.__name__: cleanup_func})

103 

104 

# Auto-generate cleanup registry: maps each MemoryType value to the
# module-level cleanup function generated for that memory type.
MEMORY_TYPE_CLEANUP_REGISTRY = {
    member.value: globals()[
        "cleanup_{}_gpu".format(_FRAMEWORK_CONFIG[member]['import_name'])
    ]
    for member in MemoryType
}

110 

111 

def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
    """
    Run the registered cleanup function of every GPU-backed memory type.

    Entries of ``_FRAMEWORK_CONFIG`` whose ``cleanup_ops`` is ``None`` are
    skipped. For every other entry, the function stored in
    ``MEMORY_TYPE_CLEANUP_REGISTRY`` under the memory type's value is called
    with ``device_id``; each of those functions decides on its own whether
    its framework is usable.

    Args:
        device_id: Optional GPU device ID, forwarded unchanged to every
            cleanup function. If None, cleans all devices.
    """
    logger.debug(f"🔥 GPU CLEANUP: Starting cleanup for all GPU frameworks (device_id={device_id})")

    for memory_type, framework_config in _FRAMEWORK_CONFIG.items():
        # Memory types without cleanup operations have nothing to release.
        if framework_config['cleanup_ops'] is None:
            continue
        run_cleanup = MEMORY_TYPE_CLEANUP_REGISTRY[memory_type.value]
        run_cleanup(device_id)

    logger.debug("🔥 GPU CLEANUP: Completed cleanup for all GPU frameworks")

131 

132 

133 

134 

135 

# Public API of this module.
__all__ = [
    # Aggregate entry point and lookup table
    "cleanup_all_gpu_frameworks",
    "MEMORY_TYPE_CLEANUP_REGISTRY",
    # Per-framework cleanup functions (generated at import time)
    "cleanup_numpy_gpu",
    "cleanup_cupy_gpu",
    "cleanup_torch_gpu",
    "cleanup_tensorflow_gpu",
    "cleanup_jax_gpu",
    "cleanup_pyclesperanto_gpu",
]

147