Coverage for openhcs/core/memory/oom_recovery.py: 5.9%

63 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2GPU Out of Memory (OOM) recovery utilities. 

3 

4Provides comprehensive OOM detection and cache clearing for all supported 

5GPU frameworks in OpenHCS. 

6""" 

7 

8import gc 

9from typing import Optional 

10 

11from openhcs.constants.constants import ( 

12 MEMORY_TYPE_TORCH, 

13 MEMORY_TYPE_CUPY, 

14 MEMORY_TYPE_TENSORFLOW, 

15 MEMORY_TYPE_JAX, 

16 MEMORY_TYPE_PYCLESPERANTO, 

17) 

18 

19 

def _is_oom_error(e: Exception, memory_type: str) -> bool:
    """
    Detect Out of Memory errors for all GPU frameworks.

    Framework-specific OOM exception types are checked first (only for the
    framework named by ``memory_type``). Every exception is then matched
    against a list of lower-cased OOM message substrings, so frameworks
    without a dedicated type check here are still recognised by message.

    Generic error types that are not inherently memory-related (CuPy's
    ``CUDARuntimeError``, TensorFlow's ``InvalidArgumentError``) are only
    treated as OOM when their message says so. Classifying them
    unconditionally would make callers clear caches and retry deterministic
    failures such as invalid arguments or a bad device.

    Args:
        e: Exception to check
        memory_type: Memory type from MemoryType enum

    Returns:
        True if exception is an OOM error for the given framework
    """
    # Framework-specific exception types. If the framework cannot be
    # imported, OOM detection must not itself raise a new ImportError;
    # fall through to the message-based check instead.
    try:
        if memory_type == MEMORY_TYPE_TORCH:
            import torch
            oom_type = getattr(torch.cuda, 'OutOfMemoryError', None)
            if oom_type is not None and isinstance(e, oom_type):
                return True

        elif memory_type == MEMORY_TYPE_CUPY:
            import cupy as cp
            oom_type = getattr(cp.cuda.memory, 'OutOfMemoryError', None)
            if oom_type is not None and isinstance(e, oom_type):
                return True
            # cp.cuda.runtime.CUDARuntimeError is raised for *any* CUDA
            # runtime failure; an allocation failure is recognised by its
            # message (e.g. "cudaErrorMemoryAllocation: out of memory") below.

        elif memory_type == MEMORY_TYPE_TENSORFLOW:
            import tensorflow as tf
            oom_type = getattr(tf.errors, 'ResourceExhaustedError', None)
            if oom_type is not None and isinstance(e, oom_type):
                return True
            # tf.errors.InvalidArgumentError is deliberately not treated as
            # OOM: it reports bad inputs, and retrying cannot fix those.
    except ImportError:
        pass

    # String-based detection for all frameworks
    error_str = str(e).lower()
    oom_patterns = (
        'out of memory', 'outofmemoryerror', 'resource_exhausted',
        'cuda_error_out_of_memory', 'cudaerrormemoryallocation',
        'cl_mem_object_allocation_failure', 'cl_out_of_resources',
        'oom when allocating', 'cannot allocate memory',
        'allocation failure', 'memory exhausted', 'resourceexhausted',
    )

    return any(pattern in error_str for pattern in oom_patterns)

62 

63 

def _clear_cache_for_memory_type(memory_type: str, device_id: Optional[int] = None) -> None:
    """
    Clear GPU cache for specific memory type.

    Python garbage collection runs *before* the framework-specific pool is
    emptied. Arrays/tensors that are only reachable through reference cycles
    release their device memory back to the framework's pool when collected,
    so collecting first lets the subsequent pool release return that memory
    as well. A second collection at the end reclaims anything freed by the
    framework-specific step (e.g. objects dropped by ``jax.clear_caches()``).

    Args:
        memory_type: Memory type from MemoryType enum
        device_id: GPU device ID (optional). When None, the framework's
            current device is used.
    """
    gc.collect()

    if memory_type == MEMORY_TYPE_TORCH:
        import torch
        torch.cuda.empty_cache()
        # synchronize(None) synchronizes the current device, which matches
        # the original "no device_id" behaviour.
        torch.cuda.synchronize(device_id)

    elif memory_type == MEMORY_TYPE_CUPY:
        import cupy as cp
        from contextlib import nullcontext
        # free_all_blocks acts on the current device's pool, so switch to the
        # requested device first (or stay on the current one).
        device_ctx = cp.cuda.Device(device_id) if device_id is not None else nullcontext()
        with device_ctx:
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp.cuda.runtime.deviceSynchronize()

    elif memory_type == MEMORY_TYPE_TENSORFLOW:
        # TensorFlow uses automatic memory management; the garbage
        # collections before and after this chain are all that is done.
        pass

    elif memory_type == MEMORY_TYPE_JAX:
        import jax
        jax.clear_caches()

    elif memory_type == MEMORY_TYPE_PYCLESPERANTO:
        import pyclesperanto as cle
        # Side effect: makes ``device_id`` the selected pyclesperanto device
        # when it is a valid index into the available-device list.
        if (
            device_id is not None
            and hasattr(cle, 'select_device')
            and hasattr(cle, 'list_available_devices')
        ):
            devices = cle.list_available_devices()
            if 0 <= device_id < len(devices):
                cle.select_device(device_id)

    # Always trigger Python garbage collection once more
    gc.collect()

112 

113 

def _execute_with_oom_recovery(
    func_callable,
    memory_type: str,
    max_retries: int = 2,
    *,
    device_id: Optional[int] = None,
):
    """
    Execute function with automatic OOM recovery.

    ``func_callable`` is called up to ``max_retries + 1`` times. After an
    attempt fails with an error that ``_is_oom_error`` classifies as OOM, the
    framework caches are cleared and the call is retried. Any other error, or
    an OOM on the final attempt, is re-raised unchanged.

    Args:
        func_callable: Zero-argument callable to execute
        memory_type: Memory type from MemoryType enum
        max_retries: Maximum number of retry attempts (must be >= 0)
        device_id: GPU device ID forwarded to the cache clear (optional,
            keyword-only; None means the framework's current device)

    Returns:
        Function result

    Raises:
        ValueError: If ``max_retries`` is negative.
        Original exception if not OOM or retries exhausted
    """
    if max_retries < 0:
        # range(max_retries + 1) would be empty and the function would
        # silently return None without ever calling func_callable.
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    for attempt in range(max_retries + 1):
        try:
            return func_callable()
        except Exception as e:
            if not _is_oom_error(e, memory_type) or attempt == max_retries:
                raise

        # Clear cache and retry. This deliberately runs *outside* the except
        # clause: while ``e`` is alive, its traceback keeps the failed call's
        # frames (and the GPU buffers they reference) alive, so clearing
        # inside the clause could not free the memory that caused the OOM.
        # Python unbinds ``e`` when the except clause ends.
        _clear_cache_for_memory_type(memory_type, device_id)