Coverage for src/arraybridge/converters_registry.py: 94%
90 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
1"""
2Registry-based converter infrastructure using metaclass-registry.
4This module provides the ConverterBase class using AutoRegisterMeta,
5concrete converter implementations for each framework, and a helper
6function for registry lookups.
7"""
import logging
from abc import abstractmethod
from enum import Enum
from typing import Optional

from metaclass_registry import AutoRegisterMeta

from arraybridge.framework_config import _FRAMEWORK_CONFIG
from arraybridge.types import MemoryType
# Module logger: emits the DLPack-fallback warnings from the generated to_X()
# methods and the debug message from registry validation.
logger = logging.getLogger(__name__)
class ConverterBase(metaclass=AutoRegisterMeta):
    """Base class for memory type converters using auto-registration.

    Each concrete converter sets ``memory_type`` to register itself in the
    registry. The registry key is the ``memory_type`` attribute (e.g. "numpy",
    "torch"), and lookups go through ``ConverterBase.__registry__``.

    Concrete subclasses are not written by hand: ``_create_converter_classes``
    generates one per ``MemoryType``. Afterwards, ``_add_converter_methods``
    attaches a ``to_<type>(data, gpu_id)`` method to this base class for every
    memory type.

    All converter methods share the signature ``(self, data, gpu_id)``.

    NOTE(review): ``@abstractmethod`` is only enforced if ``AutoRegisterMeta``
    derives from ``abc.ABCMeta``. Confirm against metaclass_registry.
    NOTE(review): the generated ``to_<type>`` method for the numpy member
    presumably replaces the abstract ``to_numpy`` on this base. Concrete
    converters are expected to override it, because the generated methods call
    ``self.to_numpy`` for their CPU fallback.
    """

    # Name of the class attribute used as the registry key.
    __registry_key__ = "memory_type"
    # Unset (None) on this base; each concrete converter sets its MemoryType value.
    memory_type: Optional[str] = None

    @abstractmethod
    def to_numpy(self, data, gpu_id):
        """Extract to NumPy (type-specific implementation)."""

    @abstractmethod
    def from_numpy(self, data, gpu_id):
        """Create from NumPy (type-specific implementation)."""

    @abstractmethod
    def from_dlpack(self, data, gpu_id):
        """Create from DLPack capsule (type-specific implementation)."""

    @abstractmethod
    def move_to_device(self, data, gpu_id):
        """Move data to specified GPU device if needed (type-specific implementation)."""
def _ensure_module(memory_type: str):
    """Return the framework module for *memory_type*, importing it if needed.

    Module-level proxy for ``arraybridge.utils._ensure_module``. The converter
    methods generated by ``_make_lambda_with_name`` call this function by name
    from this module's globals.

    The real helper is imported at call time rather than at module load.
    NOTE(review): presumably this avoids an import cycle with arraybridge.utils;
    confirm before hoisting it to the top of the file.
    """
    from arraybridge.utils import _ensure_module as load_framework_module

    module = load_framework_module(memory_type)
    return module
def _make_lambda_with_name(expr_str, mem_type, method_name):
    """Create a lambda from expression string and add proper __name__ for debugging.

    ``expr_str`` is a Python expression template that may contain ``{mod}``.
    That placeholder is replaced by a call to ``_ensure_module(<mem_type.value>)``,
    so the framework module is resolved when the method runs, not when it is
    built. The expression is wrapped as ``lambda self, data, gpu_id: <expr>``
    and evaluated in this module's globals. The resulting function gets a
    ``__name__`` and ``__qualname__`` matching the generated converter class.

    Note: Uses eval() for dynamic code generation from trusted framework_config.py strings.
    This is safe because:
    1. Input strings come from _FRAMEWORK_CONFIG, not user input
    2. Strings are defined at module load time by package maintainers
    3. This pattern enables declarative framework configuration

    Args:
        expr_str: Expression template from ``conversion_ops``.
        mem_type: The ``MemoryType`` member. Its ``.value`` names the module
            and prefixes the class name.
        method_name: Name to give the generated method (e.g. "to_numpy").

    Returns:
        A function with the converter-method signature ``(self, data, gpu_id)``.
    """
    qualname = f"{mem_type.value.capitalize()}Converter.{method_name}"
    # !r emits a correctly quoted and escaped string literal for any value.
    # Hand-wrapping the value in double quotes breaks on embedded quotes or backslashes.
    module_str = f"_ensure_module({mem_type.value!r})"
    lambda_expr = f"lambda self, data, gpu_id: {expr_str.format(mod=module_str)}"
    # A descriptive pseudo-filename makes tracebacks from the generated method
    # identify it, instead of the anonymous "<string>".
    code = compile(lambda_expr, f"<{qualname}>", "eval")
    lambda_func = eval(code)
    lambda_func.__name__ = method_name
    lambda_func.__qualname__ = qualname
    return lambda_func
def _make_not_implemented(mem_type_value, method_name):
    """Return a placeholder converter method that always raises NotImplementedError.

    Used for conversion ops configured as ``None``. The returned function has
    the standard converter-method signature ``(self, data, gpu_id)``. It is
    named ``<Value>Converter.<method_name>`` so reprs and tracebacks stay readable.

    NOTE(review): the error message always mentions DLPack, whatever
    ``method_name`` is.
    """
    message = f"DLPack not supported for {mem_type_value}"

    def not_impl(self, data, gpu_id):
        raise NotImplementedError(message)

    not_impl.__qualname__ = f"{mem_type_value.capitalize()}Converter.{method_name}"
    not_impl.__name__ = method_name
    return not_impl
# Auto-generate converter classes for each memory type
def _create_converter_classes():
    """Generate one concrete ConverterBase subclass per MemoryType.

    For each member, the ``conversion_ops`` mapping in ``_FRAMEWORK_CONFIG``
    supplies the methods. An expression string becomes a method via
    ``_make_lambda_with_name``. ``None`` becomes a stub from
    ``_make_not_implemented`` that raises NotImplementedError.

    Each class is named ``<Value>Converter`` and sets ``memory_type`` to the
    member's value, which is the attribute ConverterBase registers under.

    Returns:
        dict mapping each MemoryType member to its generated class.
    """
    generated = {}
    for member in MemoryType:
        ops = _FRAMEWORK_CONFIG[member]["conversion_ops"]

        attrs = {"memory_type": member.value}
        for op_name, template in ops.items():
            attrs[op_name] = (
                _make_not_implemented(member.value, op_name)
                if template is None
                else _make_lambda_with_name(template, member, op_name)
            )

        name = f"{member.value.capitalize()}Converter"
        generated[member] = type(name, (ConverterBase,), attrs)
    return generated
# Create all converter classes at module load time.
# Maps each MemoryType member to its generated converter class. Defining the
# classes (via AutoRegisterMeta) is also what registers them in
# ConverterBase.__registry__, which get_converter() reads.
_CONVERTER_CLASSES = _create_converter_classes()
def get_converter(memory_type: str):
    """Get a converter instance for the given memory type.

    Args:
        memory_type: The memory type string (e.g., "numpy", "torch"). An Enum
            member such as a ``MemoryType`` is also accepted; it is looked up
            by its ``.value``.

    Returns:
        A new instance of the converter class registered for the memory type.

    Raises:
        ValueError: If memory type is not registered
    """
    if isinstance(memory_type, Enum):
        # Registry keys are the string values, so normalize enum members.
        memory_type = memory_type.value
    converter_class = ConverterBase.__registry__.get(memory_type)
    if converter_class is None:
        raise ValueError(
            f"No converter registered for memory type '{memory_type}'. "
            f"Available types: {sorted(ConverterBase.__registry__.keys())}"
        )
    return converter_class()
def _add_converter_methods():
    """Add to_X() methods to ConverterBase.

    For each target memory type, this generates a method such as to_cupy() or
    to_torch() and attaches it to ConverterBase. Subclasses that already exist
    see it through normal attribute inheritance.

    Each generated method works as follows:
    1. If ``_supports_dlpack(data)``, try the target converter's
       ``from_dlpack`` followed by its ``move_to_device``.
    2. If that is not applicable, or fails, fall back to a CPU roundtrip:
       ``self.to_numpy`` followed by the target's ``from_numpy``.
    """
    from arraybridge.utils import _supports_dlpack

    def make_method(tgt):
        # Factory so each generated method binds its own target type
        # (avoids the late-binding-closure pitfall in the loop below).
        method_name = f"to_{tgt.value}"

        def method(self, data, gpu_id):
            target_converter = get_converter(tgt.value)

            # Try GPU-to-GPU first (DLPack)
            if _supports_dlpack(data):
                try:
                    result = target_converter.from_dlpack(data, gpu_id)
                    return target_converter.move_to_device(result, gpu_id)
                except NotImplementedError as e:
                    # Targets configured without DLPack support raise this on
                    # purpose (see _make_not_implemented). It is an expected
                    # path, not a failure, so don't warn on every conversion.
                    logger.debug(
                        "DLPack conversion to %s unavailable: %s. Using CPU roundtrip.",
                        tgt.value,
                        e,
                    )
                except Exception as e:
                    logger.warning("DLPack conversion failed: %s. Using CPU roundtrip.", e)

            # CPU roundtrip using polymorphism
            numpy_data = self.to_numpy(data, gpu_id)
            return target_converter.from_numpy(numpy_data, gpu_id)

        # Same naming treatment as the other generated converter methods.
        method.__name__ = method_name
        method.__qualname__ = f"ConverterBase.{method_name}"
        method.__doc__ = (
            f"Convert data to {tgt.value}: DLPack first, CPU roundtrip via NumPy as fallback."
        )
        return method

    for target_type in MemoryType:
        setattr(ConverterBase, f"to_{target_type.value}", make_method(target_type))
def _validate_registry():
    """Check that the converter registry matches MemoryType exactly.

    Compares the set of ``MemoryType`` values with the keys of
    ``ConverterBase.__registry__`` and logs a debug message on success.

    Raises:
        RuntimeError: If a MemoryType value has no registered converter
            ("Missing"), or the registry holds keys that are not MemoryType
            values ("Extra").
    """
    expected = {member.value for member in MemoryType}
    actual = set(ConverterBase.__registry__.keys())

    missing = expected - actual
    extra = actual - expected
    if not (missing or extra):
        logger.debug(f"✅ Validated {len(actual)} memory type converters in registry")
        return

    problems = []
    if missing:
        problems.append(f"Missing: {missing}")
    if extra:
        problems.append(f"Extra: {extra}")
    raise RuntimeError(f"Registry validation failed. {', '.join(problems)}")
# Add to_X() conversion methods after converter classes are created.
# The methods are set on ConverterBase, so the subclasses created above inherit
# them. At call time they look up target converters through get_converter().
_add_converter_methods()

# Run validation at module load time. This raises RuntimeError if any
# MemoryType value lacks a registered converter, or the registry has extra keys.
_validate_registry()