Coverage for openhcs/core/memory/conversion_helpers.py: 76.1%

53 statements  

coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Memory conversion helpers for OpenHCS. 

3 

4This module provides the ABC and metaprogramming infrastructure for memory type conversions. 

5Uses enum-driven polymorphism to eliminate 1,567 lines of duplication. 

6""" 

7 

8from abc import ABC, abstractmethod 

9from openhcs.constants.constants import MemoryType 

10from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG 

11from openhcs.core.memory.utils import _ensure_module, _supports_dlpack 

12import logging 

13import numpy as np 

14 

15logger = logging.getLogger(__name__) 

16 

17 

18class MemoryTypeConverter(ABC): 

19 """Abstract base class for memory type converters. 

20  

21 Each memory type (numpy, cupy, torch, etc.) has a concrete converter 

22 that implements these four core operations. All to_X() methods are 

23 auto-generated using polymorphism. 

24 """ 

25 

26 @abstractmethod 

27 def to_numpy(self, data, gpu_id): 

28 """Extract to NumPy (type-specific implementation).""" 

29 pass 

30 

31 @abstractmethod 

32 def from_numpy(self, data, gpu_id): 

33 """Create from NumPy (type-specific implementation).""" 

34 pass 

35 

36 @abstractmethod 

37 def from_dlpack(self, data, gpu_id): 

38 """Create from DLPack capsule (type-specific implementation).""" 

39 pass 

40 

41 @abstractmethod 

42 def move_to_device(self, data, gpu_id): 

43 """Move data to specified GPU device if needed (type-specific implementation).""" 

44 pass 

45 

46 
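To make the contract concrete, here is a hand-written sketch of what one concrete converter's four operations might look like. This is illustrative only and assumes CuPy is available; in the actual module the operation bodies are generated from the 'conversion_ops' expression strings in framework_config.py, which this report does not show.

# Illustrative sketch only -- not part of conversion_helpers.py. Assumes CuPy is installed.
class HandWrittenCupyConverter(MemoryTypeConverter):
    def to_numpy(self, data, gpu_id):
        # Copy the device array back to host memory as a NumPy array.
        return data.get()

    def from_numpy(self, data, gpu_id):
        # Allocate on the requested GPU and copy the host array over.
        import cupy
        with cupy.cuda.Device(gpu_id):
            return cupy.asarray(data)

    def from_dlpack(self, data, gpu_id):
        # Zero-copy import of a DLPack-capable tensor.
        import cupy
        return cupy.from_dlpack(data)

    def move_to_device(self, data, gpu_id):
        # Cross-device copy only if the array lives on a different GPU.
        import cupy
        if data.device.id != gpu_id:
            with cupy.cuda.Device(gpu_id):
                return cupy.asarray(data)
        return data

No such class appears literally in the source; the real converters are built dynamically further down.
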

def _add_converter_methods():
    """Add to_X() methods to MemoryTypeConverter ABC.

    NOTE: This must be called AFTER _CONVERTERS is defined (see below).

    For each target memory type, generates a method like to_cupy(), to_torch(), etc.
    that tries GPU-to-GPU conversion via DLPack first, then falls back to CPU roundtrip.
    """
    for target_type in MemoryType:
        method_name = f"to_{target_type.value}"

        def make_method(tgt):
            def method(self, data, gpu_id):
                # Try GPU-to-GPU first (DLPack)
                if _supports_dlpack(data):
                    try:
                        target_converter = _CONVERTERS[tgt]
                        result = target_converter.from_dlpack(data, gpu_id)
                        return target_converter.move_to_device(result, gpu_id)
                    except Exception as e:
                        logger.warning(f"DLPack conversion failed: {e}. Using CPU roundtrip.")

                # CPU roundtrip using polymorphism
                numpy_data = self.to_numpy(data, gpu_id)
                target_converter = _CONVERTERS[tgt]
                return target_converter.from_numpy(numpy_data, gpu_id)
            return method

        setattr(MemoryTypeConverter, method_name, make_method(target_type))

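As an aside, the make_method(tgt) factory above presumably exists to pin the current target type: closures created inside a loop capture the loop variable by reference, not by value. A minimal illustration of the pitfall it avoids (not part of the module):

# Without a factory, every closure sees the loop variable's final value.
funcs_buggy = [lambda: i for i in range(3)]
print([f() for f in funcs_buggy])   # [2, 2, 2]

# Binding through a factory (as make_method does) freezes the value per iteration.
def make(j):
    return lambda: j

funcs_fixed = [make(i) for i in range(3)]
print([f() for f in funcs_fixed])   # [0, 1, 2]
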

# NOTE: Conversion operations now defined in framework_config.py under 'conversion_ops'
# This eliminates the scattered _OPS dict
_OPS = {mem_type: config['conversion_ops'] for mem_type, config in _FRAMEWORK_CONFIG.items()}

# Auto-generate lambdas from strings
def _make_not_implemented(mem_type_value, method_name):
    """Create a lambda that raises NotImplementedError with the correct signature."""
    def not_impl(self, data, gpu_id):
        raise NotImplementedError(f"DLPack not supported for {mem_type_value}")
    # Add proper names for better debugging
    not_impl.__name__ = method_name
    not_impl.__qualname__ = f'{mem_type_value.capitalize()}Converter.{method_name}'
    return not_impl

def _make_lambda_with_name(expr_str, mem_type, method_name):
    """Create a lambda from expression string and add proper __name__ for debugging."""
    # Pre-compute the module string to avoid nested f-strings with backslashes (Python 3.11 limitation)
    module_str = f'_ensure_module("{mem_type.value}")'
    lambda_expr = f'lambda self, data, gpu_id: {expr_str.format(mod=module_str)}'
    lambda_func = eval(lambda_expr)
    lambda_func.__name__ = method_name
    lambda_func.__qualname__ = f'{mem_type.value.capitalize()}Converter.{method_name}'
    return lambda_func

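To ground the string-to-lambda step, the sketch below walks a hypothetical expression string through _make_lambda_with_name. The real expression strings live in framework_config.py under 'conversion_ops' and are not visible in this report; the "{mod}.asarray(data)" string and the MemoryType.CUPY member name are assumptions for illustration.

# Hypothetical expression string -- the real ones come from framework_config.py.
expr = "{mod}.asarray(data)"

fn = _make_lambda_with_name(expr, MemoryType.CUPY, "from_numpy")
# eval() built, in effect:
#   lambda self, data, gpu_id: _ensure_module("cupy").asarray(data)
print(fn.__name__)       # from_numpy
print(fn.__qualname__)   # CupyConverter.from_numpy
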

_TYPE_OPERATIONS = {
    mem_type: {
        method_name: (
            _make_lambda_with_name(expr, mem_type, method_name)
            if expr is not None
            else _make_not_implemented(mem_type.value, method_name)
        )
        for method_name, expr in ops.items()  # Iterate over dict items - self-documenting!
    }
    for mem_type, ops in _OPS.items()
}

# Auto-generate all 6 converter classes
_CONVERTERS = {
    mem_type: type(
        f"{mem_type.value.capitalize()}Converter",
        (MemoryTypeConverter,),
        _TYPE_OPERATIONS[mem_type]
    )()
    for mem_type in MemoryType
}

# NOW call _add_converter_methods() after _CONVERTERS exists
_add_converter_methods()

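For readers less familiar with the three-argument form of type(), each _CONVERTERS entry is roughly what the ordinary class statement below would produce. The MemoryType.NUMPY member name and the exact keys of the operations dict are assumptions here; the validation at the end of the module is what actually checks that the four core operations are present.

# Rough hand-written equivalent of one _CONVERTERS entry (illustrative only).
class NumpyConverter(MemoryTypeConverter):
    to_numpy = _TYPE_OPERATIONS[MemoryType.NUMPY]["to_numpy"]
    from_numpy = _TYPE_OPERATIONS[MemoryType.NUMPY]["from_numpy"]
    from_dlpack = _TYPE_OPERATIONS[MemoryType.NUMPY]["from_dlpack"]
    move_to_device = _TYPE_OPERATIONS[MemoryType.NUMPY]["move_to_device"]

numpy_converter = NumpyConverter()   # _CONVERTERS holds one such instance per MemoryType member
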

# Runtime validation: ensure all converters have required methods
def _validate_converters():
    """Validate that all generated converters have the required methods."""
    required_methods = ['to_numpy', 'from_numpy', 'from_dlpack', 'move_to_device']

    for mem_type, converter in _CONVERTERS.items():
        # Check ABC methods
        for method in required_methods:
            if not hasattr(converter, method):  # coverage: partial branch, condition was never true
                raise RuntimeError(f"{mem_type.value} converter missing method: {method}")

        # Check to_X() methods for all memory types
        for target_type in MemoryType:
            method_name = f'to_{target_type.value}'
            if not hasattr(converter, method_name):  # coverage: partial branch, condition was never true
                raise RuntimeError(f"{mem_type.value} converter missing method: {method_name}")

    logger.debug(f"✅ Validated {len(_CONVERTERS)} memory type converters")

# Run validation at module load time
_validate_converters()
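Finally, a caller-side usage sketch. The _CONVERTERS registry is module-private, so downstream OpenHCS code presumably reaches it through a public wrapper not shown in this report; the MemoryType.NUMPY and MemoryType.TORCH member names are assumed for illustration.

# Hypothetical caller-side sketch -- not part of conversion_helpers.py.
import numpy as np

image = np.zeros((512, 512), dtype=np.float32)

torch_tensor = _CONVERTERS[MemoryType.NUMPY].to_torch(image, gpu_id=0)        # auto-generated to_X() method
round_tripped = _CONVERTERS[MemoryType.TORCH].to_numpy(torch_tensor, gpu_id=0)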