Coverage for src/arraybridge/converters_registry.py: 95%
91 statements
coverage.py v7.13.1, created at 2026-01-10 22:33 +0000
1"""
2Registry-based converter infrastructure using metaclass-registry.
4This module provides the ConverterBase class using AutoRegisterMeta,
5concrete converter implementations for each framework, and a helper
6function for registry lookups.
7"""
9import logging
10from abc import abstractmethod
12from metaclass_registry import AutoRegisterMeta
14from arraybridge.framework_config import _FRAMEWORK_CONFIG
15from arraybridge.types import MemoryType
# Module-level logger: emits the DLPack-fallback warning in the generated
# to_X() methods and the debug message from registry validation.
logger = logging.getLogger(__name__)
class ConverterBase(metaclass=AutoRegisterMeta):
    """Base class for memory type converters using auto-registration.

    Each concrete converter sets ``memory_type`` to register itself in the
    registry. The registry key is the ``memory_type`` attribute (e.g.,
    "numpy", "torch"); on this base class it is ``None``.

    Concrete converters supply the four type-specific hooks below. Generic
    ``to_<memory_type>()`` methods are attached to this class later in the
    module and are built on top of these hooks.

    NOTE(review): ``@abstractmethod`` is only enforced at instantiation if
    AutoRegisterMeta derives from ``abc.ABCMeta`` - not visible here, confirm.
    """

    # Class attribute that AutoRegisterMeta uses as each subclass's registry key.
    __registry_key__ = "memory_type"
    # Plain dict: no lazy discovery is needed because the concrete converter
    # classes are generated eagerly at import time further down this module.
    __registry__ = {}
    # None on the base class; concrete converters set a memory type string.
    # The annotation is a string so the PEP 604 union is valid on any Python 3.
    memory_type: "str | None" = None

    @abstractmethod
    def to_numpy(self, data, gpu_id):
        """Extract to NumPy (type-specific implementation)."""

    @abstractmethod
    def from_numpy(self, data, gpu_id):
        """Create from NumPy (type-specific implementation)."""

    @abstractmethod
    def from_dlpack(self, data, gpu_id):
        """Create from DLPack capsule (type-specific implementation)."""

    @abstractmethod
    def move_to_device(self, data, gpu_id):
        """Move data to specified GPU device if needed (type-specific implementation)."""
def _ensure_module(memory_type: str):
    """Return the framework module for ``memory_type``, importing it on demand.

    Thin delegate to ``arraybridge.utils._ensure_module``. The import happens
    at call time rather than at module top - presumably to avoid a circular
    import with ``arraybridge.utils`` (TODO confirm). Converter methods
    generated by ``_make_lambda_with_name`` call this function by name, so it
    must remain a module-level global of this module.
    """
    from arraybridge.utils import _ensure_module as delegate

    return delegate(memory_type)
def _make_lambda_with_name(expr_str, mem_type, method_name):
    """Create a lambda from expression string and add proper __name__ for debugging.

    The ``{mod}`` placeholder in ``expr_str`` is replaced by a call to
    ``_ensure_module(<mem_type.value>)``, so the framework module is resolved
    when the generated method runs. The generated function has the signature
    ``(self, data, gpu_id)`` and a ``__name__``/``__qualname__`` matching the
    generated ``<Value>Converter`` class.

    Args:
        expr_str: Expression template from ``_FRAMEWORK_CONFIG``; may use
            ``{mod}`` plus the names ``self``, ``data`` and ``gpu_id``.
        mem_type: Enum member whose ``.value`` is the memory type string.
        method_name: Name given to the generated function.

    Returns:
        The generated function.

    Note: Uses eval() for dynamic code generation from trusted framework_config.py strings.
    This is safe because:
    1. Input strings come from _FRAMEWORK_CONFIG, not user input
    2. Strings are defined at module load time by package maintainers
    3. This pattern enables declarative framework configuration
    """
    # repr() yields a correctly quoted string literal even if the value
    # contains quote characters (hand-written "..." quoting would not).
    module_str = f"_ensure_module({mem_type.value!r})"
    lambda_expr = f"lambda self, data, gpu_id: {expr_str.format(mod=module_str)}"
    # Evaluate against this module's globals only: the generated body resolves
    # _ensure_module here and cannot see this helper's local variables.
    lambda_func = eval(lambda_expr, globals())
    lambda_func.__name__ = method_name
    lambda_func.__qualname__ = f"{mem_type.value.capitalize()}Converter.{method_name}"
    return lambda_func
def _make_not_implemented(mem_type_value, method_name):
    """Create a method stub that raises NotImplementedError with the correct signature.

    Used for conversion operations whose configured expression is ``None``.
    The stub takes ``(self, data, gpu_id)`` like the other generated converter
    methods and gets a matching ``__name__``/``__qualname__``.

    Args:
        mem_type_value: Memory type string (e.g. "numpy").
        method_name: Name of the unsupported operation.

    Returns:
        A function that always raises NotImplementedError.
    """
    # DLPack operations keep their established message; any other unsupported
    # operation is named explicitly instead of being misreported as DLPack.
    if "dlpack" in method_name.lower():
        message = f"DLPack not supported for {mem_type_value}"
    else:
        message = f"{method_name} not supported for {mem_type_value}"

    def not_impl(self, data, gpu_id):
        raise NotImplementedError(message)

    not_impl.__name__ = method_name
    not_impl.__qualname__ = f"{mem_type_value.capitalize()}Converter.{method_name}"
    return not_impl
# Auto-generate converter classes for each memory type
def _create_converter_classes():
    """Generate one ConverterBase subclass per MemoryType member.

    Each generated class is named ``<Value>Converter`` (value capitalized),
    has ``memory_type`` set to the member's value, and gets one method per
    entry in ``_FRAMEWORK_CONFIG[member]["conversion_ops"]``. An entry of
    ``None`` becomes a NotImplementedError stub; any other entry is compiled
    from its expression string. Per ConverterBase's contract, setting
    ``memory_type`` is what registers the subclass in the registry.

    Returns:
        dict mapping each MemoryType member to its generated class.
    """

    def build_method(mem_type, name, expr):
        # Operations configured as None are unsupported for this memory type.
        if expr is None:
            return _make_not_implemented(mem_type.value, name)
        return _make_lambda_with_name(expr, mem_type, name)

    generated = {}
    for mem_type in MemoryType:
        ops = _FRAMEWORK_CONFIG[mem_type]["conversion_ops"]
        attrs = {"memory_type": mem_type.value}
        attrs.update(
            (name, build_method(mem_type, name, expr)) for name, expr in ops.items()
        )
        generated[mem_type] = type(
            f"{mem_type.value.capitalize()}Converter", (ConverterBase,), attrs
        )
    return generated
# Create all converter classes at module load time. Defining each subclass is
# what registers it in ConverterBase.__registry__; the returned
# MemoryType -> class mapping is kept at module level (it is not referenced
# elsewhere in the code shown here).
_CONVERTER_CLASSES = _create_converter_classes()
def get_converter(memory_type: str):
    """Instantiate the converter registered for ``memory_type``.

    A new converter instance is created on every call.

    Args:
        memory_type: The memory type string (e.g., "numpy", "torch")

    Returns:
        A converter instance for the memory type

    Raises:
        ValueError: If memory type is not registered
    """
    registry = ConverterBase.__registry__
    found = registry.get(memory_type)
    if found is not None:
        return found()
    raise ValueError(
        f"No converter registered for memory type '{memory_type}'. "
        f"Available types: {sorted(registry.keys())}"
    )
def _add_converter_methods():
    """Add to_X() methods to ConverterBase.

    For each target memory type, generates a method like to_cupy(), to_torch(),
    etc. with signature ``(self, data, gpu_id)``. The method first tries a
    direct DLPack transfer (GPU-to-GPU) and falls back to a CPU roundtrip
    through NumPy if ``data`` does not support DLPack or the transfer raises.
    Each generated method gets a ``__name__``/``__qualname__`` matching its
    attribute name, like the other generated converter methods in this module.

    NOTE(review): if MemoryType includes "numpy", this also replaces the
    abstract ``ConverterBase.to_numpy``. Concrete converters must define their
    own ``to_numpy``, since the CPU path calls ``self.to_numpy``; otherwise
    this would recurse.
    """
    from arraybridge.utils import _supports_dlpack

    def make_method(tgt):
        # Factory binds ``tgt`` per iteration (avoids late-binding closures).
        method_name = f"to_{tgt.value}"

        def method(self, data, gpu_id):
            # Resolve the target converter once; both paths below use it.
            target_converter = get_converter(tgt.value)

            # Try GPU-to-GPU first (DLPack)
            if _supports_dlpack(data):
                try:
                    result = target_converter.from_dlpack(data, gpu_id)
                    return target_converter.move_to_device(result, gpu_id)
                except Exception as e:
                    # Deliberate best-effort: any DLPack failure falls back to CPU.
                    logger.warning("DLPack conversion failed: %s. Using CPU roundtrip.", e)

            # CPU roundtrip using polymorphism: source -> NumPy -> target
            numpy_data = self.to_numpy(data, gpu_id)
            return target_converter.from_numpy(numpy_data, gpu_id)

        method.__name__ = method_name
        method.__qualname__ = f"ConverterBase.{method_name}"
        return method

    for target_type in MemoryType:
        setattr(ConverterBase, f"to_{target_type.value}", make_method(target_type))
def _validate_registry():
    """Ensure the converter registry keys are exactly the MemoryType values.

    Raises:
        RuntimeError: If a MemoryType value has no registered converter, or the
            registry holds keys that are not MemoryType values.
    """
    expected = {member.value for member in MemoryType}
    actual = set(ConverterBase.__registry__.keys())

    if expected == actual:
        logger.debug(f"✅ Validated {len(actual)} memory type converters in registry")
        return

    problems = []
    missing = expected - actual
    if missing:
        problems.append(f"Missing: {missing}")
    extra = actual - expected
    if extra:
        problems.append(f"Extra: {extra}")
    raise RuntimeError(f"Registry validation failed. {', '.join(problems)}")
# Add to_X() conversion methods after converter classes are created.
# The generated methods look up their target converter via get_converter()
# each time they are called.
_add_converter_methods()

# Run validation at module load time: importing this module raises
# RuntimeError if the registry keys do not match the MemoryType values exactly.
_validate_registry()