Coverage for openhcs/core/memory/oom_recovery.py: 10.2%
68 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2GPU Out of Memory (OOM) recovery utilities.
4Provides comprehensive OOM detection and cache clearing for all supported
5GPU frameworks in OpenHCS.
7REFACTORED: Uses enum-driven metaprogramming to eliminate 71% of code duplication.
8All OOM patterns and cache clearing operations are defined in framework_ops.py.
9"""
11import gc
12import logging
13from typing import Optional
15from openhcs.constants.constants import MemoryType
16from openhcs.core.memory.framework_ops import _FRAMEWORK_OPS
17from openhcs.core.utils import optional_import
19logger = logging.getLogger(__name__)
def _is_oom_error(e: Exception, memory_type: str) -> bool:
    """
    Detect Out of Memory errors for all GPU frameworks.

    Detection is driven by the ``_FRAMEWORK_OPS`` entry for the memory type:
    first the exception is tested against the exception classes named in
    ``oom_exception_types``, then its lower-cased message is searched for
    the substrings in ``oom_string_patterns``.

    Args:
        e: Exception to check
        memory_type: Memory type string (e.g., 'torch', 'cupy'). A MemoryType
            member is accepted as well.

    Returns:
        True if exception is an OOM error for the given framework; False
        otherwise, including when the memory type is not recognised.
    """
    # Resolve the MemoryType member; accept either the member itself or its value.
    mem_type_enum = next(
        (mt for mt in MemoryType if mt is memory_type or mt.value == memory_type),
        None,
    )
    if mem_type_enum is None:
        return False

    ops = _FRAMEWORK_OPS[mem_type_enum]

    # Check framework-specific exception types. The module import does not
    # depend on the loop variable, so it is done once, and only when there is
    # at least one exception type to resolve.
    exc_type_exprs = ops['oom_exception_types']
    mod = optional_import(ops['import_name']) if exc_type_exprs else None
    if mod is not None:
        for exc_type_expr in exc_type_exprs:
            try:
                # 'mod.cuda.OutOfMemoryError' -> ['cuda', 'OutOfMemoryError']
                attr_path = exc_type_expr.format(mod='mod').split('.')[1:]
                exc_type = mod
                for attr in attr_path:
                    exc_type = getattr(exc_type, attr, None)
                    if exc_type is None:
                        # Attribute missing in the installed framework version.
                        break

                if exc_type is not None and isinstance(e, exc_type):
                    return True
            except Exception:
                # Best effort: a malformed expression or a non-class attribute
                # must not hide the error being classified. Fall through to
                # the remaining expressions and the string patterns.
                logger.debug(
                    "Could not evaluate OOM exception type %r for %s",
                    exc_type_expr, memory_type, exc_info=True,
                )
                continue

    # String-based detection using framework-specific patterns
    error_str = str(e).lower()
    return any(pattern in error_str for pattern in ops['oom_string_patterns'])
def _clear_cache_for_memory_type(memory_type: str, device_id: Optional[int] = None):
    """
    Clear GPU cache for specific memory type, then run Python garbage collection.

    The framework module name and the cache-clearing statement are read from
    the ``_FRAMEWORK_OPS`` entry for the memory type. ``gc.collect()`` runs on
    every exit path, including the unknown-type and module-missing paths.

    Args:
        memory_type: Memory type string (e.g., 'torch', 'cupy'). A MemoryType
            member is accepted as well.
        device_id: GPU device ID (optional, currently unused but kept for API compatibility)
    """
    try:
        # Resolve the MemoryType member; accept either the member itself or its value.
        mem_type_enum = next(
            (mt for mt in MemoryType if mt is memory_type or mt.value == memory_type),
            None,
        )
        if mem_type_enum is None:
            logger.warning("Unknown memory type for cache clearing: %s", memory_type)
            return

        ops = _FRAMEWORK_OPS[mem_type_enum]

        # Get the module
        mod_name = ops['import_name']
        mod = optional_import(mod_name)
        if mod is None:
            logger.warning("Module %s not available for cache clearing", mod_name)
            return

        # Execute cache clearing operations
        cache_clear_expr = ops['oom_clear_cache']
        if cache_clear_expr:
            try:
                # NOTE(security): exec() runs source text taken from the
                # _FRAMEWORK_OPS table. That text must stay internal
                # configuration and never be built from external input.
                # The statement only sees the framework module (bound under
                # its import name) and ``gc``; device_id is not applied here.
                exec(cache_clear_expr.format(mod=mod_name), {mod_name: mod, 'gc': gc})
            except Exception as e:
                # Cache clearing is best effort; report and carry on to GC.
                logger.warning("Failed to clear cache for %s: %s", memory_type, e)
    finally:
        # Always trigger Python garbage collection
        gc.collect()
124def _execute_with_oom_recovery(func_callable, memory_type: str, max_retries: int = 2):
125 """
126 Execute function with automatic OOM recovery.
128 Args:
129 func_callable: Function to execute
130 memory_type: Memory type from MemoryType enum
131 max_retries: Maximum number of retry attempts
133 Returns:
134 Function result
136 Raises:
137 Original exception if not OOM or retries exhausted
138 """
139 for attempt in range(max_retries + 1):
140 try:
141 return func_callable()
142 except Exception as e:
143 if not _is_oom_error(e, memory_type) or attempt == max_retries:
144 raise
146 # Clear cache and retry
147 _clear_cache_for_memory_type(memory_type)