Coverage for src/arraybridge/converters.py: 100%
18 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
1"""Memory conversion public API for OpenHCS."""
3from typing import Any
5import numpy as np
7from arraybridge.converters_registry import get_converter
8from arraybridge.framework_config import _FRAMEWORK_CONFIG
9from arraybridge.types import MemoryType
def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -> Any:
    """
    Convert data between memory types using the unified converter infrastructure.

    The converter registered for ``source_type`` is obtained via ``get_converter``;
    its ``to_<target_type>`` method is then called with ``data`` and ``gpu_id``.

    Args:
        data: The data to convert
        source_type: The source memory type (e.g., "numpy", "torch")
        target_type: The target memory type (e.g., "cupy", "jax")
        gpu_id: The target GPU device ID

    Returns:
        The converted data in the target memory type

    Raises:
        ValueError: If source_type or target_type is invalid
        MemoryConversionError: If conversion fails
    """
    converter = get_converter(source_type)  # Will raise ValueError if invalid

    # Look the method up with a default so that an unknown target_type surfaces as
    # the documented ValueError instead of a bare AttributeError from getattr().
    method_name = f"to_{target_type}"
    method = getattr(converter, method_name, None)
    if not callable(method):
        raise ValueError(
            f"Invalid target memory type {target_type!r}: converter for "
            f"{source_type!r} has no method {method_name!r}"
        )
    return method(data, gpu_id)
def detect_memory_type(data: Any) -> str:
    """
    Identify which memory type (framework) ``data`` belongs to.

    NumPy arrays are recognised directly with ``isinstance``. Any other object is
    matched by searching the module path of ``type(data)`` for each framework's
    configured ``import_name``; the first entry (in ``_FRAMEWORK_CONFIG`` order)
    whose import name occurs in that path wins.

    Args:
        data: The object whose memory type should be identified

    Returns:
        The memory type string (e.g., "numpy", "torch")

    Raises:
        ValueError: If no configured framework matches the type of ``data``
    """
    # Fast path: plain NumPy arrays are the most common input.
    if isinstance(data, np.ndarray):
        return MemoryType.NUMPY.value

    data_type = type(data)
    owning_module = data_type.__module__

    # A prefix match always implies a substring match, so one containment test
    # covers both; substring matching also catches implementation modules whose
    # path merely contains the framework's import name.
    no_match = object()
    detected = next(
        (
            mem_type.value
            for mem_type, framework in _FRAMEWORK_CONFIG.items()
            if framework["import_name"] in owning_module
        ),
        no_match,
    )

    if detected is no_match:
        raise ValueError(f"Unknown memory type for {data_type} (module: {owning_module})")
    return detected