Coverage for openhcs/core/memory/converters.py: 56.0%

19 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1"""Memory conversion public API for OpenHCS.""" 

2 

3from typing import Any 

4import numpy as np 

5from openhcs.constants.constants import MemoryType 

6from openhcs.core.memory.conversion_helpers import _CONVERTERS 

7from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG 

8 

9 

def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -> Any:
    """
    Convert data between memory types using the unified converter infrastructure.

    The converter registered for ``source_type`` in ``_CONVERTERS`` is looked up,
    and its ``to_<target>`` method is called with ``(data, gpu_id)``.

    Args:
        data: The data to convert
        source_type: The source memory type (e.g., "numpy", "torch")
        target_type: The target memory type (e.g., "cupy", "jax")
        gpu_id: The target GPU device ID

    Returns:
        The converted data in the target memory type

    Raises:
        ValueError: If source_type or target_type is not a valid MemoryType value
        MemoryConversionError: If conversion fails (raised by the underlying
            converter method and propagated unchanged)
    """
    source_enum = MemoryType(source_type)
    # Validate the target up front so an unknown target type raises ValueError
    # (as documented) instead of surfacing as an AttributeError from getattr.
    target_enum = MemoryType(target_type)
    converter = _CONVERTERS[source_enum]
    # Dispatch on the canonical enum value. For valid string inputs this equals
    # target_type, so the method name looked up is unchanged.
    method = getattr(converter, f"to_{target_enum.value}")
    return method(data, gpu_id)

31 

32 

def detect_memory_type(data: Any) -> str:
    """
    Infer which memory framework owns ``data``.

    NumPy arrays, including subclasses, are recognised directly with ``isinstance``.
    Any other object is classified by the module that defines its type. The first
    entry of ``_FRAMEWORK_CONFIG``, in iteration order, whose ``'import_name'``
    either prefixes that module name or occurs anywhere inside it wins.

    Args:
        data: The object whose memory type should be identified.

    Returns:
        The ``MemoryType`` value string of the matching framework
        (e.g. "numpy", "torch").

    Raises:
        ValueError: If no configured framework matches the object's defining module.
    """
    # Cheapest and most frequent case first.
    if isinstance(data, np.ndarray):
        return MemoryType.NUMPY.value

    module_name = type(data).__module__

    def _owns(framework_cfg) -> bool:
        # Lenient match: either a prefix hit or a plain substring hit counts.
        pkg = framework_cfg['import_name']
        return module_name.startswith(pkg) or pkg in module_name

    no_match = object()
    owner = next(
        (mem_type for mem_type, cfg in _FRAMEWORK_CONFIG.items() if _owns(cfg)),
        no_match,
    )
    if owner is no_match:
        raise ValueError(f"Unknown memory type for {type(data)} (module: {module_name})")
    return owner.value