Coverage for openhcs/core/memory/oom_recovery.py: 5.9%
63 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2GPU Out of Memory (OOM) recovery utilities.
4Provides comprehensive OOM detection and cache clearing for all supported
5GPU frameworks in OpenHCS.
6"""
8import gc
9from typing import Optional
11from openhcs.constants.constants import (
12 MEMORY_TYPE_TORCH,
13 MEMORY_TYPE_CUPY,
14 MEMORY_TYPE_TENSORFLOW,
15 MEMORY_TYPE_JAX,
16 MEMORY_TYPE_PYCLESPERANTO,
17)
def _is_oom_error(e: Exception, memory_type: str) -> bool:
    """
    Detect Out of Memory errors for all GPU frameworks.

    An exception counts as OOM when it is an instance of the selected
    framework's dedicated out-of-memory exception class, or when its message
    contains one of the known out-of-memory phrases (case-insensitive).

    Args:
        e: Exception to check
        memory_type: Memory type identifier, compared against the
            MEMORY_TYPE_* constants

    Returns:
        True if exception is an OOM error for the given framework
    """
    # Framework-specific exception types. Only dedicated OOM classes are
    # matched by type: generic classes such as cupy's CUDARuntimeError or
    # TensorFlow's InvalidArgumentError also cover failures unrelated to
    # memory, so those are classified by their message below instead.
    oom_class = None
    try:
        if memory_type == MEMORY_TYPE_TORCH:
            import torch
            oom_class = getattr(torch.cuda, 'OutOfMemoryError', None)

        elif memory_type == MEMORY_TYPE_CUPY:
            import cupy as cp
            oom_class = getattr(cp.cuda.memory, 'OutOfMemoryError', None)

        elif memory_type == MEMORY_TYPE_TENSORFLOW:
            import tensorflow as tf
            oom_class = getattr(tf.errors, 'ResourceExhaustedError', None)
    except ImportError:
        # This check runs while another exception is being handled; a missing
        # framework must not mask it, so fall back to message matching only.
        oom_class = None

    if oom_class is not None and isinstance(e, oom_class):
        return True

    # String-based detection for all frameworks
    error_str = str(e).lower()
    oom_patterns = (
        'out of memory', 'outofmemoryerror', 'resource_exhausted',
        'cuda_error_out_of_memory', 'cl_mem_object_allocation_failure',
        'cl_out_of_resources', 'oom when allocating', 'cannot allocate memory',
        'allocation failure', 'memory exhausted', 'resourceexhausted',
    )

    return any(pattern in error_str for pattern in oom_patterns)
def _clear_cache_for_memory_type(memory_type: str, device_id: Optional[int] = None):
    """
    Release cached GPU memory for the framework selected by ``memory_type``.

    Each recognized memory type gets its framework-specific cleanup; every
    call, recognized or not, ends with a Python garbage collection pass.

    Args:
        memory_type: Memory type identifier, compared against the
            MEMORY_TYPE_* constants
        device_id: GPU device ID to target (optional); when None no device
            is selected explicitly
    """
    if memory_type == MEMORY_TYPE_TORCH:
        import torch

        torch.cuda.empty_cache()
        if device_id is None:
            torch.cuda.synchronize()
        else:
            with torch.cuda.device(device_id):
                torch.cuda.synchronize()

    elif memory_type == MEMORY_TYPE_CUPY:
        import cupy as cp

        def _flush_cupy_pools():
            # Free both the device and the pinned-host pools, then wait for
            # the device to finish outstanding work.
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()

        if device_id is None:
            _flush_cupy_pools()
        else:
            with cp.cuda.Device(device_id):
                _flush_cupy_pools()

    elif memory_type == MEMORY_TYPE_TENSORFLOW:
        # No explicit cache API is used for TensorFlow here; only collect
        # unreachable Python objects.
        gc.collect()

    elif memory_type == MEMORY_TYPE_JAX:
        import jax

        jax.clear_caches()
        gc.collect()

    elif memory_type == MEMORY_TYPE_PYCLESPERANTO:
        import pyclesperanto as cle

        # Switch to the requested device only when the API offers selection
        # and the index refers to an available device.
        can_select = device_id is not None and hasattr(cle, 'select_device')
        if can_select and device_id < len(cle.list_available_devices()):
            cle.select_device(device_id)
        gc.collect()

    # Final garbage collection pass, run for every memory type
    gc.collect()
114def _execute_with_oom_recovery(func_callable, memory_type: str, max_retries: int = 2):
115 """
116 Execute function with automatic OOM recovery.
118 Args:
119 func_callable: Function to execute
120 memory_type: Memory type from MemoryType enum
121 max_retries: Maximum number of retry attempts
123 Returns:
124 Function result
126 Raises:
127 Original exception if not OOM or retries exhausted
128 """
129 for attempt in range(max_retries + 1):
130 try:
131 return func_callable()
132 except Exception as e:
133 if not _is_oom_error(e, memory_type) or attempt == max_retries:
134 raise
136 # Clear cache and retry
137 _clear_cache_for_memory_type(memory_type)