Coverage for openhcs/core/memory/converters.py: 56.0%
19 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""Memory conversion public API for OpenHCS."""
3from typing import Any
4import numpy as np
5from openhcs.constants.constants import MemoryType
6from openhcs.core.memory.conversion_helpers import _CONVERTERS
7from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG
def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -> Any:
    """
    Convert data between memory types using the unified converter infrastructure.

    The converter registered in ``_CONVERTERS`` for the source memory type is
    looked up, and its ``to_<target>`` method is called as
    ``method(data, gpu_id)``.

    Args:
        data: The data to convert
        source_type: The source memory type (e.g., "numpy", "torch"). A
            ``MemoryType`` member is also accepted.
        target_type: The target memory type (e.g., "cupy", "jax"). A
            ``MemoryType`` member is also accepted.
        gpu_id: The target GPU device ID, passed through to the converter method

    Returns:
        The converted data in the target memory type

    Raises:
        ValueError: If source_type or target_type is not a valid MemoryType, or
            the source converter has no method for the requested target type
        MemoryConversionError: If conversion fails (presumably raised by the
            converter method itself; not raised directly here — confirm)
    """
    # MemoryType(...) raises ValueError for unknown values and returns an
    # existing MemoryType member unchanged, so both strings and members work.
    source_enum = MemoryType(source_type)
    # Validate the target up front: previously an invalid target only failed
    # inside getattr() below, surfacing as an undocumented AttributeError.
    target_enum = MemoryType(target_type)

    converter = _CONVERTERS[source_enum]

    # Build the method name from the canonical enum value rather than the raw
    # argument, so a MemoryType member maps to "to_<value>" as well.
    method = getattr(converter, f"to_{target_enum.value}", None)
    if method is None:
        raise ValueError(
            f"No conversion available from {source_enum.value!r} "
            f"to {target_enum.value!r}"
        )
    return method(data, gpu_id)
def detect_memory_type(data: Any) -> str:
    """
    Return the memory type string that owns ``data``.

    NumPy arrays (including ndarray subclasses, via ``isinstance``) are
    recognised directly. Any other object is classified by the module that
    defines its type: each entry of ``_FRAMEWORK_CONFIG`` is tried in the
    mapping's iteration order, and the first whose ``import_name`` is a prefix
    of, or appears anywhere in, that module path wins.

    Args:
        data: The object to classify

    Returns:
        The ``.value`` of the matching memory type (e.g., "numpy", "torch")

    Raises:
        ValueError: If no configured framework matches the type's module
    """
    # Fast path for the most common input; skips the config scan entirely.
    if isinstance(data, np.ndarray):
        return MemoryType.NUMPY.value

    owner_module = type(data).__module__

    for framework, settings in _FRAMEWORK_CONFIG.items():
        candidate = settings['import_name']
        # Prefix test first, then a looser substring test; first hit wins.
        is_match = owner_module.startswith(candidate) or candidate in owner_module
        if is_match:
            return framework.value

    raise ValueError(f"Unknown memory type for {type(data)} (module: {owner_module})")