Coverage for src/arraybridge/oom_recovery.py: 90%

68 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1""" 

2GPU Out of Memory (OOM) recovery utilities. 

3 

4Provides comprehensive OOM detection and cache clearing for all supported 

5GPU frameworks in OpenHCS. 

6 

7REFACTORED: Uses enum-driven metaprogramming to eliminate 71% of code duplication. 

8All OOM patterns and cache clearing operations are defined in framework_ops.py. 

9""" 

10 

11import gc 

12import logging 

13from typing import Optional 

14 

15from arraybridge.framework_ops import _FRAMEWORK_OPS 

16from arraybridge.types import MemoryType 

17from arraybridge.utils import optional_import 

18 

# Logger named after this module, so it sits under the package's logger hierarchy.
logger = logging.getLogger(name=__name__)

20 

21 

def _resolve_exception_type(mod, exc_type_expr: str):
    """
    Resolve an exception-type template against an imported framework module.

    The template uses a ``{mod}`` placeholder for the module, e.g.
    ``'{mod}.cuda.OutOfMemoryError'``; the remaining dotted path is walked with
    ``getattr`` starting from *mod*.

    Args:
        mod: The imported framework module.
        exc_type_expr: Template string from framework_ops.py.

    Returns:
        The attribute the path names, or None if any step is missing or the
        template cannot be processed.
    """
    try:
        # '{mod}.cuda.OutOfMemoryError' -> 'mod.cuda.OutOfMemoryError'
        # -> ['cuda', 'OutOfMemoryError'] (leading 'mod' dropped)
        parts = exc_type_expr.format(mod="mod").split(".")[1:]
        target = mod
        for part in parts:
            target = getattr(target, part, None)
            if target is None:
                return None
        return target
    except Exception:
        # Best-effort: a malformed template just means "no class to check".
        return None


def _is_oom_error(e: Exception, memory_type: str) -> bool:
    """
    Detect Out of Memory errors for all GPU frameworks.

    Driven by the per-framework OOM metadata in framework_ops.py. The exception
    is first checked with ``isinstance`` against the framework's own exception
    classes (``oom_exception_types``). If none match, the lower-cased exception
    message is searched for the framework's OOM substrings
    (``oom_string_patterns``), case-insensitively.

    This is called from inside ``except`` handlers, so it is best-effort. An
    unknown memory type, missing metadata, or a failed import yields ``False``
    for that check instead of raising and masking the original exception.

    Args:
        e: Exception to check
        memory_type: Memory type string (e.g., 'torch', 'cupy'); a
            ``MemoryType`` member is accepted as well.

    Returns:
        True if exception is an OOM error for the given framework
    """
    # Map the memory-type string (or enum member) to its MemoryType enum.
    if isinstance(memory_type, MemoryType):
        mem_type_enum = memory_type
    else:
        mem_type_enum = next(
            (mt for mt in MemoryType if mt.value == memory_type), None
        )
    if mem_type_enum is None:
        return False

    try:
        ops = _FRAMEWORK_OPS[mem_type_enum]
    except KeyError:
        # No OOM metadata registered for this memory type.
        return False

    # 1) Framework-specific exception classes. The framework module is imported
    #    once, and only when there is at least one class path to resolve.
    exc_type_exprs = ops["oom_exception_types"]
    if exc_type_exprs:
        try:
            mod = optional_import(ops["import_name"])
        except Exception:
            mod = None
        if mod is not None:
            for exc_type_expr in exc_type_exprs:
                exc_type = _resolve_exception_type(mod, exc_type_expr)
                if exc_type is None:
                    continue
                try:
                    if isinstance(e, exc_type):
                        return True
                except Exception:
                    # Resolved attribute is not usable with isinstance
                    # (not a class or tuple of classes); skip it.
                    continue

    # 2) Message-based fallback. The message is lower-cased, so each pattern is
    #    lower-cased too; otherwise a mixed-case pattern could never match.
    error_str = str(e).lower()
    return any(
        pattern.lower() in error_str for pattern in ops["oom_string_patterns"]
    )

77 

78 

def _clear_cache_for_memory_type(memory_type: str, device_id: Optional[int] = None):
    """
    Clear GPU cache for specific memory type.

    Python garbage collection runs first. Device buffers owned by unreachable
    Python objects only go back to the framework's allocator once those objects
    are collected. Running gc before the framework cache release lets that
    release also cover this memory.

    The cache-release code itself is the ``oom_clear_cache`` template from
    framework_ops.py. Failures while running it are logged as warnings rather
    than raised.

    Args:
        memory_type: Memory type string (e.g., 'torch', 'cupy'); a
            ``MemoryType`` member is accepted as well.
        device_id: GPU device ID (optional, currently unused but kept for API compatibility)
    """
    # Collect first so the framework cache release below can see the freed blocks.
    gc.collect()

    # Map the memory-type string (or enum member) to its MemoryType enum.
    if isinstance(memory_type, MemoryType):
        mem_type_enum = memory_type
    else:
        mem_type_enum = next(
            (mt for mt in MemoryType if mt.value == memory_type), None
        )
    if mem_type_enum is None:
        logger.warning("Unknown memory type for cache clearing: %s", memory_type)
        return

    try:
        ops = _FRAMEWORK_OPS[mem_type_enum]
    except KeyError:
        logger.warning(
            "No cache clearing operations registered for memory type: %s",
            memory_type,
        )
        return

    mod_name = ops["import_name"]
    mod = optional_import(mod_name)
    if mod is None:
        logger.warning("Module %s not available for cache clearing", mod_name)
        return

    cache_clear_expr = ops["oom_clear_cache"]
    if not cache_clear_expr:
        return

    try:
        # The template comes from the framework_ops table (module configuration).
        # It is formatted with the module's import name and executed with only
        # that module and `gc` in its global namespace.
        exec(cache_clear_expr.format(mod=mod_name), {mod_name: mod, "gc": gc})
    except Exception as e:
        logger.warning("Failed to clear cache for %s: %s", memory_type, e)

123 

124 

125def _execute_with_oom_recovery(func_callable, memory_type: str, max_retries: int = 2): 

126 """ 

127 Execute function with automatic OOM recovery. 

128 

129 Args: 

130 func_callable: Function to execute 

131 memory_type: Memory type from MemoryType enum 

132 max_retries: Maximum number of retry attempts 

133 

134 Returns: 

135 Function result 

136 

137 Raises: 

138 Original exception if not OOM or retries exhausted 

139 """ 

140 for attempt in range(max_retries + 1): 

141 try: 

142 return func_callable() 

143 except Exception as e: 

144 if not _is_oom_error(e, memory_type) or attempt == max_retries: 

145 raise 

146 

147 # Clear cache and retry 

148 _clear_cache_for_memory_type(memory_type)