Coverage for src / arraybridge / converters_registry.py: 95%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 22:33 +0000

1""" 

2Registry-based converter infrastructure using metaclass-registry. 

3 

4This module provides the ConverterBase class using AutoRegisterMeta, 

5concrete converter implementations for each framework, and a helper 

6function for registry lookups. 

7""" 

8 

9import logging 

10from abc import abstractmethod 

11 

12from metaclass_registry import AutoRegisterMeta 

13 

14from arraybridge.framework_config import _FRAMEWORK_CONFIG 

15from arraybridge.types import MemoryType 

16 

17logger = logging.getLogger(__name__) 

18 

19 

class ConverterBase(metaclass=AutoRegisterMeta):
    """Abstract parent of every per-framework memory converter.

    Concrete converters register themselves through ``AutoRegisterMeta`` by
    setting ``memory_type``; that value (e.g., "numpy", "torch") is the key
    under which the class is stored in ``__registry__``.
    """

    # Name of the class attribute whose value is used as the registry key.
    __registry_key__ = "memory_type"
    # Plain dict: converters are generated in this module, so no lazy discovery is needed.
    __registry__ = {}
    # Overridden by each concrete converter; stays None on the base class.
    memory_type: str = None

    @abstractmethod
    def to_numpy(self, data, gpu_id):
        """Convert ``data`` to a NumPy array (implemented per memory type)."""

    @abstractmethod
    def from_numpy(self, data, gpu_id):
        """Build this converter's array type from a NumPy array (implemented per memory type)."""

    @abstractmethod
    def from_dlpack(self, data, gpu_id):
        """Build this converter's array type from a DLPack capsule (implemented per memory type)."""

    @abstractmethod
    def move_to_device(self, data, gpu_id):
        """Move ``data`` to the GPU identified by ``gpu_id`` if needed (implemented per memory type)."""

50 

51 

def _ensure_module(memory_type: str):
    """Return the framework module that backs ``memory_type``.

    Forwards to ``arraybridge.utils._ensure_module``. The import happens at
    call time rather than at module import. Generated converter methods look
    this function up by name in this module's globals.
    """
    from arraybridge.utils import _ensure_module as _load_framework_module

    framework_module = _load_framework_module(memory_type)
    return framework_module

57 

58 

59def _make_lambda_with_name(expr_str, mem_type, method_name): 

60 """Create a lambda from expression string and add proper __name__ for debugging. 

61 

62 Note: Uses eval() for dynamic code generation from trusted framework_config.py strings. 

63 This is safe because: 

64 1. Input strings come from _FRAMEWORK_CONFIG, not user input 

65 2. Strings are defined at module load time by package maintainers 

66 3. This pattern enables declarative framework configuration 

67 """ 

68 module_str = f'_ensure_module("{mem_type.value}")' 

69 lambda_expr = f"lambda self, data, gpu_id: {expr_str.format(mod=module_str)}" 

70 lambda_func = eval(lambda_expr) 

71 lambda_func.__name__ = method_name 

72 lambda_func.__qualname__ = f"{mem_type.value.capitalize()}Converter.{method_name}" 

73 return lambda_func 

74 

75 

76def _make_not_implemented(mem_type_value, method_name): 

77 """Create a lambda that raises NotImplementedError with the correct signature.""" 

78 

79 def not_impl(self, data, gpu_id): 

80 raise NotImplementedError(f"DLPack not supported for {mem_type_value}") 

81 

82 not_impl.__name__ = method_name 

83 not_impl.__qualname__ = f"{mem_type_value.capitalize()}Converter.{method_name}" 

84 return not_impl 

85 

86 

87# Auto-generate converter classes for each memory type 

def _create_converter_classes():
    """Generate one concrete ConverterBase subclass per MemoryType member.

    For each memory type, the "conversion_ops" mapping from _FRAMEWORK_CONFIG
    supplies method names and expression strings. An expression of None
    becomes a stub that raises NotImplementedError; anything else is compiled
    into a method via _make_lambda_with_name.

    Returns:
        Dict mapping each MemoryType member to its generated converter class.
    """

    def _build_method(mem_type, name, expr):
        # None marks an operation the framework does not support.
        if expr is None:
            return _make_not_implemented(mem_type.value, name)
        return _make_lambda_with_name(expr, mem_type, name)

    generated = {}
    for mem_type in MemoryType:
        ops = _FRAMEWORK_CONFIG[mem_type]["conversion_ops"]

        # memory_type goes in first; it is the attribute the registry keys on.
        namespace = {"memory_type": mem_type.value}
        for name, expr in ops.items():
            namespace[name] = _build_method(mem_type, name, expr)

        generated[mem_type] = type(
            f"{mem_type.value.capitalize()}Converter",
            (ConverterBase,),
            namespace,
        )

    return generated

115 

116 

# Create all converter classes at module load time.
# The returned mapping is keyed by MemoryType member; lookups by string go
# through ConverterBase.__registry__ (see get_converter).
_CONVERTER_CLASSES = _create_converter_classes()

119 

120 

def get_converter(memory_type: str):
    """Return a fresh converter instance for ``memory_type``.

    Args:
        memory_type: Registry key of the converter (e.g., "numpy", "torch").

    Returns:
        A new instance of the converter class registered under that key.

    Raises:
        ValueError: If nothing is registered for ``memory_type``; the message
            lists the available keys.
    """
    registry = ConverterBase.__registry__
    converter_cls = registry.get(memory_type)

    if converter_cls is not None:
        return converter_cls()

    known_types = sorted(registry.keys())
    raise ValueError(
        f"No converter registered for memory type '{memory_type}'. "
        f"Available types: {known_types}"
    )

140 

141 

def _add_converter_methods():
    """Add to_X() methods to ConverterBase.

    For each target memory type, generates a method like to_cupy(), to_torch(), etc.
    that tries GPU-to-GPU conversion via DLPack first, then falls back to a CPU
    roundtrip through NumPy (``self.to_numpy`` followed by the target's
    ``from_numpy``).

    A name that ConverterBase declares as an abstract method (``to_numpy``) is
    left alone. The generated fallback calls ``self.to_numpy`` itself, so
    installing it under that name would drop the abstract marker and make a
    subclass without its own ``to_numpy`` recurse forever.
    """
    from arraybridge.utils import _supports_dlpack

    def make_method(tgt, name):
        # Factory so each generated method captures its own target type.
        def method(self, data, gpu_id):
            # Try GPU-to-GPU first (DLPack)
            if _supports_dlpack(data):
                try:
                    target_converter = get_converter(tgt.value)
                    result = target_converter.from_dlpack(data, gpu_id)
                    return target_converter.move_to_device(result, gpu_id)
                except Exception as e:
                    # Deliberate best-effort: any DLPack failure falls through to the CPU path.
                    logger.warning(f"DLPack conversion failed: {e}. Using CPU roundtrip.")

            # CPU roundtrip using polymorphism
            numpy_data = self.to_numpy(data, gpu_id)
            target_converter = get_converter(tgt.value)
            return target_converter.from_numpy(numpy_data, gpu_id)

        # Give the generated method a real identity, as the other generated methods have.
        method.__name__ = name
        method.__qualname__ = f"ConverterBase.{name}"
        return method

    for target_type in MemoryType:
        method_name = f"to_{target_type.value}"

        # Do not clobber an abstract declaration on the base class.
        declared = ConverterBase.__dict__.get(method_name)
        if getattr(declared, "__isabstractmethod__", False):
            continue

        setattr(ConverterBase, method_name, make_method(target_type, method_name))

172 

173 

def _validate_registry():
    """Check that the registry keys match the MemoryType values exactly.

    Raises:
        RuntimeError: If a memory type has no registered converter, or the
            registry holds a key that is not a MemoryType value. The message
            lists the missing and/or extra keys.
    """
    expected = {member.value for member in MemoryType}
    actual = set(ConverterBase.__registry__.keys())

    if expected == actual:
        logger.debug(f"✅ Validated {len(actual)} memory type converters in registry")
        return

    missing = expected - actual
    extra = actual - expected

    # Report only the non-empty categories, missing first.
    problems = []
    if missing:
        problems.append(f"Missing: {missing}")
    if extra:
        problems.append(f"Extra: {extra}")
    raise RuntimeError(f"Registry validation failed. {', '.join(problems)}")

190 

191 

# Add to_X() conversion methods after converter classes are created.
# The methods are set on ConverterBase, so every generated converter inherits them.
_add_converter_methods()

# Run validation at module load time.
# Raises RuntimeError during import if the registry and MemoryType disagree.
_validate_registry()