Coverage for src/arraybridge/oom_recovery.py: 90%
68 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
1"""
2GPU Out of Memory (OOM) recovery utilities.
4Provides comprehensive OOM detection and cache clearing for all supported
5GPU frameworks in OpenHCS.
7REFACTORED: Uses enum-driven metaprogramming to eliminate 71% of code duplication.
8All OOM patterns and cache clearing operations are defined in framework_ops.py.
9"""
11import gc
12import logging
13from typing import Optional
15from arraybridge.framework_ops import _FRAMEWORK_OPS
16from arraybridge.types import MemoryType
17from arraybridge.utils import optional_import
19logger = logging.getLogger(__name__)
def _is_oom_error(e: Exception, memory_type: str) -> bool:
    """
    Report whether an exception is an out-of-memory error for a framework.

    Two checks are applied in order: the exception is first compared against
    the framework's OOM exception classes, then its lowercased message is
    scanned for the framework's OOM substrings. Both lists are read from the
    framework's entry in _FRAMEWORK_OPS.

    Args:
        e: Exception to classify
        memory_type: Memory type string (e.g., 'torch', 'cupy')

    Returns:
        True if the exception matches an OOM class or message pattern;
        False otherwise, including when memory_type matches no MemoryType value
    """
    # Map the string onto its MemoryType member by comparing enum values
    matched = next(
        (member for member in MemoryType if member.value == memory_type),
        None,
    )
    if matched is None:
        return False

    framework = _FRAMEWORK_OPS[matched]
    message = str(e).lower()
    missing = object()  # marks an attribute that does not exist on the module path

    # Class-based detection: resolve each dotted path against the imported module
    for type_template in framework["oom_exception_types"]:
        try:
            module = optional_import(framework["import_name"])
            if module is None:
                continue

            # After formatting, the path reads like 'mod.cuda.OutOfMemoryError';
            # the leading 'mod' placeholder is dropped and the rest is walked
            attr_names = type_template.format(mod="mod").split(".")[1:]
            candidate = module
            for attr_name in attr_names:
                candidate = getattr(candidate, attr_name, missing)
                if candidate is missing:
                    candidate = None
                    break

            if candidate is not None and isinstance(e, candidate):
                return True
        except Exception:
            # Best effort: a path that cannot be resolved or checked is skipped
            continue

    # Message-based detection using the framework's substring patterns
    for pattern in framework["oom_string_patterns"]:
        if pattern in message:
            return True
    return False
def _clear_cache_for_memory_type(memory_type: str, device_id: Optional[int] = None):
    """
    Run a framework's cache-clearing code, then Python garbage collection.

    The cache-clearing code is the 'oom_clear_cache' entry for the framework
    in _FRAMEWORK_OPS. When the memory type is unknown or the framework module
    cannot be imported, a warning is logged and only garbage collection runs.

    Args:
        memory_type: Memory type string (e.g., 'torch', 'cupy')
        device_id: GPU device ID (optional; not read by this function, kept
            for API compatibility)
    """
    # Map the string onto its MemoryType member by comparing enum values
    matched = next(
        (member for member in MemoryType if member.value == memory_type),
        None,
    )
    if matched is None:
        logger.warning(f"Unknown memory type for cache clearing: {memory_type}")
        gc.collect()
        return

    framework = _FRAMEWORK_OPS[matched]
    mod_name = framework["import_name"]
    mod = optional_import(mod_name)
    if mod is None:
        logger.warning(f"Module {mod_name} not available for cache clearing")
        gc.collect()
        return

    clear_template = framework["oom_clear_cache"]
    if clear_template:
        try:
            # The code comes from _FRAMEWORK_OPS and is executed with only the
            # framework module (under its import name) and gc as globals
            clear_code = clear_template.format(mod=mod_name)
            exec_globals = {mod_name: mod, "gc": gc}
            exec(clear_code, exec_globals)
        except Exception as e:
            # A failed clear is logged, not raised; garbage collection still runs
            logger.warning(f"Failed to clear cache for {memory_type}: {e}")

    # Garbage collection runs whether or not a cache clear was attempted
    gc.collect()
125def _execute_with_oom_recovery(func_callable, memory_type: str, max_retries: int = 2):
126 """
127 Execute function with automatic OOM recovery.
129 Args:
130 func_callable: Function to execute
131 memory_type: Memory type from MemoryType enum
132 max_retries: Maximum number of retry attempts
134 Returns:
135 Function result
137 Raises:
138 Original exception if not OOM or retries exhausted
139 """
140 for attempt in range(max_retries + 1):
141 try:
142 return func_callable()
143 except Exception as e:
144 if not _is_oom_error(e, memory_type) or attempt == max_retries:
145 raise
147 # Clear cache and retry
148 _clear_cache_for_memory_type(memory_type)