Coverage for src/arraybridge/converters_registry.py: 94%

90 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1""" 

2Registry-based converter infrastructure using metaclass-registry. 

3 

4This module provides the ConverterBase class using AutoRegisterMeta, 

5concrete converter implementations for each framework, and a helper 

6function for registry lookups. 

7""" 

8 

9import logging 

10from abc import abstractmethod 

11 

12from metaclass_registry import AutoRegisterMeta 

13 

14from arraybridge.framework_config import _FRAMEWORK_CONFIG 

15from arraybridge.types import MemoryType 

16 

17logger = logging.getLogger(__name__) 

18 

19 

class ConverterBase(metaclass=AutoRegisterMeta):
    """Abstract root of the per-framework memory converters.

    Every concrete converter sets ``memory_type`` and is thereby registered
    in the class registry under that value (for example ``"numpy"`` or
    ``"torch"``).

    Subclasses provide the four primitive operations declared below. The
    cross-framework ``to_<type>()`` helpers are attached to this class later
    in the module by ``_add_converter_methods``.
    """

    # Tells the metaclass which attribute holds the registry key.
    __registry_key__ = "memory_type"
    # Registry key; None on the base class, set by each concrete converter.
    memory_type: str = None

    @abstractmethod
    def to_numpy(self, data, gpu_id):
        """Extract to NumPy (type-specific implementation)."""

    @abstractmethod
    def from_numpy(self, data, gpu_id):
        """Create from NumPy (type-specific implementation)."""

    @abstractmethod
    def from_dlpack(self, data, gpu_id):
        """Create from DLPack capsule (type-specific implementation)."""

    @abstractmethod
    def move_to_device(self, data, gpu_id):
        """Move data to specified GPU device if needed (type-specific implementation)."""

49 

50 

def _ensure_module(memory_type: str):
    """Return the framework module that backs ``memory_type``.

    Delegates to ``arraybridge.utils._ensure_module``. Keeping a function with
    this exact name in the module namespace matters: the source text built by
    ``_make_lambda_with_name`` calls ``_ensure_module(...)`` by name.
    """
    # Resolved on each call instead of at module import time
    # (presumably to avoid an import cycle with arraybridge.utils -- confirm).
    from arraybridge.utils import _ensure_module as _load_framework_module

    framework_module = _load_framework_module(memory_type)
    return framework_module

56 

57 

def _make_lambda_with_name(expr_str, mem_type, method_name):
    """Build a converter method from a config expression and name it for debugging.

    ``expr_str`` is a Python expression template. Each ``{mod}`` placeholder is
    replaced by a call to ``_ensure_module(<memory type value>)`` and the
    result becomes the body of ``lambda self, data, gpu_id: ...``.

    Note: Uses eval() for dynamic code generation from trusted
    framework_config.py strings. This is safe because:
    1. Input strings come from _FRAMEWORK_CONFIG, not user input
    2. Strings are defined at module load time by package maintainers
    3. This pattern enables declarative framework configuration
    Never route externally supplied text through this function.

    Args:
        expr_str: Expression template; may use ``self``, ``data``, ``gpu_id``
            and the ``{mod}`` placeholder.
        mem_type: Object whose ``.value`` is the memory type string; used for
            the ``_ensure_module`` argument and for naming.
        method_name: Name given to the generated function.

    Returns:
        The generated function with ``__name__`` set to ``method_name`` and
        ``__qualname__`` set to ``"<Value>Converter.<method_name>"``.

    Raises:
        SyntaxError: If the expanded expression is not valid Python. The
            message names the memory type and method being generated.
    """
    # repr() produces a correctly quoted and escaped string literal whatever
    # the value contains; hard-coded double quotes break on an embedded quote.
    module_call = f"_ensure_module({mem_type.value!r})"
    lambda_source = f"lambda self, data, gpu_id: {expr_str.format(mod=module_call)}"

    try:
        generated = eval(lambda_source)
    except SyntaxError as err:
        # This runs while the module is being imported; say which config
        # entry is malformed instead of surfacing a bare "<string>" error.
        raise SyntaxError(
            f"Invalid conversion expression for {mem_type.value}.{method_name}: "
            f"{lambda_source!r}"
        ) from err

    generated.__name__ = method_name
    generated.__qualname__ = f"{mem_type.value.capitalize()}Converter.{method_name}"
    return generated

73 

74 

def _make_not_implemented(mem_type_value, method_name):
    """Create a stub method that always raises NotImplementedError.

    The stub has the converter-method signature ``(self, data, gpu_id)`` so it
    can stand in for any conversion operation a framework does not provide.

    Args:
        mem_type_value: Memory type string, used in the error message and in
            the stub's ``__qualname__``.
        method_name: Name assigned to the stub.

    Returns:
        A function named ``method_name`` whose ``__qualname__`` is
        ``"<Value>Converter.<method_name>"`` and which raises
        ``NotImplementedError`` when called.
    """
    if "dlpack" in method_name.lower():
        # Wording kept verbatim for DLPack stubs so existing callers and log
        # matching keep working.
        message = f"DLPack not supported for {mem_type_value}"
    else:
        # Any other stubbed operation is reported under its own name rather
        # than being mislabelled as a DLPack problem.
        message = f"{method_name} not supported for {mem_type_value}"

    def not_impl(self, data, gpu_id):
        raise NotImplementedError(message)

    not_impl.__name__ = method_name
    not_impl.__qualname__ = f"{mem_type_value.capitalize()}Converter.{method_name}"
    return not_impl

84 

85 

86# Auto-generate converter classes for each memory type 

def _create_converter_classes():
    """Generate one concrete ``ConverterBase`` subclass per ``MemoryType`` member.

    For every member, the ``"conversion_ops"`` mapping of its
    ``_FRAMEWORK_CONFIG`` entry is turned into methods: an expression string
    becomes a generated method, ``None`` becomes a stub that raises
    ``NotImplementedError``. Each class is named ``<Value>Converter`` and
    carries the member's value as ``memory_type``.

    Returns:
        dict mapping each ``MemoryType`` member to its generated class.
    """

    def build_method(mem_type, name, expr):
        # None in the config marks an operation the framework does not offer.
        if expr is None:
            return _make_not_implemented(mem_type.value, name)
        return _make_lambda_with_name(expr, mem_type, name)

    def build_class(mem_type):
        ops = _FRAMEWORK_CONFIG[mem_type]["conversion_ops"]
        namespace = {"memory_type": mem_type.value}
        namespace.update(
            (name, build_method(mem_type, name, expr)) for name, expr in ops.items()
        )
        class_name = f"{mem_type.value.capitalize()}Converter"
        return type(class_name, (ConverterBase,), namespace)

    return {mem_type: build_class(mem_type) for mem_type in MemoryType}

114 

115 

# Build the concrete converter classes once, while this module is imported.
_CONVERTER_CLASSES = _create_converter_classes()

118 

119 

def get_converter(memory_type: str):
    """Instantiate the converter registered for ``memory_type``.

    Args:
        memory_type: Registry key of the wanted converter (e.g., "numpy", "torch").

    Returns:
        A newly constructed instance of the registered converter class.

    Raises:
        ValueError: If nothing is registered under ``memory_type``. The
            message lists the registered keys in sorted order.
    """
    registry = ConverterBase.__registry__
    converter_cls = registry.get(memory_type)
    if converter_cls is not None:
        return converter_cls()

    available = sorted(registry.keys())
    raise ValueError(
        f"No converter registered for memory type '{memory_type}'. "
        f"Available types: {available}"
    )

139 

140 

def _add_converter_methods():
    """Attach a ``to_<type>()`` method to ``ConverterBase`` for every ``MemoryType``.

    Each generated method (``to_cupy()``, ``to_torch()``, ...) takes
    ``(self, data, gpu_id)`` and converts ``data`` to the target memory type.
    It first tries a direct hand-off through DLPack; if ``data`` does not
    support DLPack, or the hand-off raises, it falls back to a round trip
    through NumPy using ``self.to_numpy`` and the target's ``from_numpy``.

    Generated methods get a ``__name__`` / ``__qualname__`` matching the
    attribute they are installed under, like the other generated converter
    methods in this module.
    """
    from arraybridge.utils import _supports_dlpack

    def make_method(tgt):
        # A factory call per target binds ``tgt`` for that method only,
        # avoiding the late-binding closure pitfall of defining it in the loop.
        target_key = tgt.value
        name = f"to_{target_key}"

        def method(self, data, gpu_id):
            # Try GPU-to-GPU first (DLPack)
            if _supports_dlpack(data):
                try:
                    target_converter = get_converter(target_key)
                    result = target_converter.from_dlpack(data, gpu_id)
                    return target_converter.move_to_device(result, gpu_id)
                except Exception as e:
                    # Deliberately broad: the DLPack path is best effort and
                    # any failure is answered with the CPU fallback below.
                    logger.warning(
                        "DLPack conversion failed: %s. Using CPU roundtrip.", e
                    )

            # CPU roundtrip using polymorphism
            numpy_data = self.to_numpy(data, gpu_id)
            target_converter = get_converter(target_key)
            return target_converter.from_numpy(numpy_data, gpu_id)

        method.__name__ = name
        method.__qualname__ = f"ConverterBase.{name}"
        return method

    for target_type in MemoryType:
        # NOTE(review): if MemoryType has a "numpy" member this replaces the
        # abstract ConverterBase.to_numpy with a generated method that itself
        # calls self.to_numpy; concrete converters are expected to define
        # their own to_numpy -- confirm.
        setattr(ConverterBase, f"to_{target_type.value}", make_method(target_type))

171 

172 

def _validate_registry():
    """Check that the registry holds exactly the ``MemoryType`` values.

    Compares the set of ``MemoryType`` values with the keys of
    ``ConverterBase.__registry__`` and logs a debug message when they match.

    Raises:
        RuntimeError: If the two sets differ. The message lists the values
            missing from the registry and/or the extra registry keys.
    """
    expected = {member.value for member in MemoryType}
    registered = set(ConverterBase.__registry__.keys())

    if expected == registered:
        logger.debug(f"✅ Validated {len(registered)} memory type converters in registry")
        return

    # Report only the sides of the mismatch that are non-empty.
    differences = (
        ("Missing", expected - registered),
        ("Extra", registered - expected),
    )
    problems = [f"{label}: {names}" for label, names in differences if names]
    raise RuntimeError(f"Registry validation failed. {', '.join(problems)}")

189 

190 

# Attach the generated to_X() conversion methods to ConverterBase; this runs
# after the concrete converter classes above have been created.
_add_converter_methods()

# Validate the registry during import so that a mismatch with MemoryType
# raises RuntimeError immediately.
_validate_registry()