Coverage for openhcs/core/memory/oom_recovery.py: 10.2%

68 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2GPU Out of Memory (OOM) recovery utilities. 

3 

4Provides comprehensive OOM detection and cache clearing for all supported 

5GPU frameworks in OpenHCS. 

6 

7REFACTORED: Uses enum-driven metaprogramming to eliminate 71% of code duplication. 

8All OOM patterns and cache clearing operations are defined in framework_ops.py. 

9""" 

10 

11import gc 

12import logging 

13from typing import Optional 

14 

15from openhcs.constants.constants import MemoryType 

16from openhcs.core.memory.framework_ops import _FRAMEWORK_OPS 

17from openhcs.core.utils import optional_import 

18 

# Module-level logger, named after this module; the cache-clearing helpers
# report their warnings through it.
logger = logging.getLogger(name=__name__)

20 

21 

def _is_oom_error(e: Exception, memory_type: str) -> bool:
    """
    Detect Out of Memory errors for all GPU frameworks.

    Auto-generated from framework_ops.py OOM patterns. Detection runs in two
    stages:

    1. ``isinstance`` checks against the framework's OOM exception types,
       resolved lazily as attribute paths on the framework module.
    2. A case-insensitive substring match of ``str(e)`` against the
       framework's OOM string patterns.

    Args:
        e: Exception to check
        memory_type: Memory type string (e.g., 'torch', 'cupy'); a MemoryType
            enum member is also accepted.

    Returns:
        True if exception is an OOM error for the given framework. False for
        unknown memory types or memory types with no registered operations.
    """
    # Resolve the MemoryType enum member (accept the member itself or its value).
    if isinstance(memory_type, MemoryType):
        mem_type_enum = memory_type
    else:
        mem_type_enum = next((mt for mt in MemoryType if mt.value == memory_type), None)
    if mem_type_enum is None:
        return False

    ops = _FRAMEWORK_OPS.get(mem_type_enum)
    if ops is None:
        return False

    # Stage 1: framework-specific exception types. The module import does not
    # depend on the individual expression, so resolve it once.
    exc_type_exprs = ops['oom_exception_types']
    mod = None
    if exc_type_exprs:
        try:
            mod = optional_import(ops['import_name'])
        except Exception:
            # Best-effort: a framework that fails to import cannot have raised e.
            mod = None

    if mod is not None:
        for exc_type_expr in exc_type_exprs:
            try:
                # '{mod}.cuda.OutOfMemoryError' -> 'mod.cuda.OutOfMemoryError'
                # -> ['cuda', 'OutOfMemoryError'] (skip the 'mod' root).
                parts = exc_type_expr.format(mod='mod').split('.')[1:]
                exc_type = mod
                for part in parts:
                    exc_type = getattr(exc_type, part, None)
                    if exc_type is None:
                        break

                if exc_type is not None and isinstance(e, exc_type):
                    return True
            except Exception:
                # Malformed expression, or a resolved attribute that is not a
                # class (isinstance raises TypeError): try the next one.
                continue

    # Stage 2: string-based detection. Lowercase both sides so pattern casing
    # in framework_ops.py cannot silently disable a match.
    error_str = str(e).lower()
    return any(pattern.lower() in error_str for pattern in ops['oom_string_patterns'])

76 

77 

def _clear_cache_for_memory_type(memory_type: str, device_id: Optional[int] = None):
    """
    Clear GPU cache for specific memory type.

    Auto-generated from framework_ops.py cache clearing operations.

    Python garbage collection always runs. The framework-specific
    cache-clearing expression runs only when the memory type is known, has
    registered operations, and its module imports. Failures inside that
    expression are logged as warnings and never raised.

    Args:
        memory_type: Memory type string (e.g., 'torch', 'cupy'); a MemoryType
            enum member is also accepted.
        device_id: GPU device ID (optional, currently unused but kept for API
            compatibility)
    """
    # Collect first. Objects kept alive only by reference cycles (which may
    # own GPU buffers) are freed only by the cyclic collector. Running it
    # before the framework cache release lets that release see their memory
    # as unused.
    gc.collect()

    # Resolve the MemoryType enum member (accept the member itself or its value).
    if isinstance(memory_type, MemoryType):
        mem_type_enum = memory_type
    else:
        mem_type_enum = next((mt for mt in MemoryType if mt.value == memory_type), None)
    if mem_type_enum is None:
        logger.warning("Unknown memory type for cache clearing: %s", memory_type)
        return

    ops = _FRAMEWORK_OPS.get(mem_type_enum)
    if ops is None:
        logger.warning("No framework operations registered for cache clearing: %s", memory_type)
        return

    mod_name = ops['import_name']
    mod = optional_import(mod_name)
    if mod is None:
        logger.warning("Module %s not available for cache clearing", mod_name)
        return

    cache_clear_expr = ops['oom_clear_cache']
    if not cache_clear_expr:
        return

    # Bind the module under the neutral name 'mod' (same convention as
    # _is_oom_error) so dotted import names still format into valid code.
    # Also bind it under its import name for expressions that spell the
    # module out literally.
    namespace = {'gc': gc, 'mod': mod}
    if mod_name.isidentifier():
        namespace[mod_name] = mod
    try:
        # The expression comes from the internal _FRAMEWORK_OPS table, not from
        # user input. Device context is handled by the expression itself.
        exec(cache_clear_expr.format(mod='mod'), namespace)
    except Exception as e:
        logger.warning("Failed to clear cache for %s: %s", memory_type, e)

122 

123 

124def _execute_with_oom_recovery(func_callable, memory_type: str, max_retries: int = 2): 

125 """ 

126 Execute function with automatic OOM recovery. 

127  

128 Args: 

129 func_callable: Function to execute 

130 memory_type: Memory type from MemoryType enum 

131 max_retries: Maximum number of retry attempts 

132  

133 Returns: 

134 Function result 

135  

136 Raises: 

137 Original exception if not OOM or retries exhausted 

138 """ 

139 for attempt in range(max_retries + 1): 

140 try: 

141 return func_callable() 

142 except Exception as e: 

143 if not _is_oom_error(e, memory_type) or attempt == max_retries: 

144 raise 

145 

146 # Clear cache and retry 

147 _clear_cache_for_memory_type(memory_type)