Coverage for src/arraybridge/converters.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1"""Memory conversion public API for OpenHCS.""" 

2 

3from typing import Any 

4 

5import numpy as np 

6 

7from arraybridge.converters_registry import get_converter 

8from arraybridge.framework_config import _FRAMEWORK_CONFIG 

9from arraybridge.types import MemoryType 

10 

11 

def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -> Any:
    """
    Convert data between memory types using the unified converter infrastructure.

    The converter registered for ``source_type`` is fetched, and its
    ``to_<target_type>`` method is called with ``(data, gpu_id)``.

    Args:
        data: The data to convert
        source_type: The source memory type (e.g., "numpy", "torch")
        target_type: The target memory type (e.g., "cupy", "jax")
        gpu_id: The target GPU device ID

    Returns:
        The converted data in the target memory type

    Raises:
        ValueError: If source_type or target_type is invalid
        MemoryConversionError: If conversion fails
    """
    # get_converter is responsible for rejecting an unknown source_type.
    converter = get_converter(source_type)

    method_name = f"to_{target_type}"
    # Look the method up with a default so an unknown target_type produces the
    # documented ValueError instead of a bare AttributeError.
    method = getattr(converter, method_name, None)
    if not callable(method):
        raise ValueError(
            f"Invalid target memory type {target_type!r}: converter for "
            f"{source_type!r} has no {method_name}() method"
        )
    return method(data, gpu_id)

32 

33 

def detect_memory_type(data: Any) -> str:
    """
    Detect the memory type of data using framework config.

    ``np.ndarray`` (including subclasses) is checked first. Otherwise the
    module that defines ``type(data)`` is compared against each entry's
    ``"import_name"`` in ``_FRAMEWORK_CONFIG``, in dict iteration order. The
    first entry whose import name is a prefix of that module name wins.

    Args:
        data: The data to detect

    Returns:
        The detected memory type string (e.g., "numpy", "torch"), i.e. the
        ``.value`` of the matching ``_FRAMEWORK_CONFIG`` key.

    Raises:
        ValueError: If memory type cannot be detected
    """
    # NumPy special case (most common, check first).
    if isinstance(data, np.ndarray):
        return MemoryType.NUMPY.value

    module_name = type(data).__module__

    for mem_type, config in _FRAMEWORK_CONFIG.items():
        import_name = config["import_name"]
        # Prefix match only. A plain substring test would misclassify types
        # such as ``torch._numpy._ndarray.ndarray`` as "numpy", because
        # "numpy" appears mid-string. A prefix test still covers companion
        # packages whose names extend the import name, e.g.
        # "jaxlib.xla_extension" for "jax" or "cupyx.*" for "cupy".
        if module_name.startswith(import_name):
            return mem_type.value

    raise ValueError(f"Unknown memory type for {type(data)} (module: {module_name})")