Coverage for src/arraybridge/types.py: 97%

35 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

"""
Memory type definitions for arraybridge.

This module defines the MemoryType enum and related constants for managing
different array/tensor frameworks.
"""

7 

from enum import Enum
from typing import Any, Callable, TypeVar

# Generic type variable exported for use in annotations.
T = TypeVar("T")

# Shape of a conversion callable: takes one value of any type and returns
# a value of any type.
ConversionFunc = Callable[[Any], Any]

13 

14 

class MemoryType(Enum):
    """Enum representing different array/tensor framework types.

    Each member's value is the lowercase framework name as a string, so a
    member can be recovered from its name with ``MemoryType("numpy")``.
    """

    NUMPY = "numpy"
    CUPY = "cupy"
    TORCH = "torch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
    PYCLESPERANTO = "pyclesperanto"

    @property
    def converter(self) -> Any:
        """Return the converter looked up for this memory type.

        The lookup is delegated to ``get_converter`` from
        ``arraybridge.converters_registry``, keyed by this member's string
        value.
        """
        # Imported on access rather than at module import time.
        # NOTE(review): presumably to avoid a circular import with the
        # registry module -- confirm before moving this to the top of file.
        from arraybridge.converters_registry import get_converter

        return get_converter(self.value)

31 

32 

# Auto-generate to_X() methods on the MemoryType enum

def _add_conversion_methods():
    """Add to_X() conversion methods to MemoryType enum.

    For every member ``target`` of ``MemoryType`` a method named
    ``to_<target.value>`` is set on the enum class. Calling
    ``source.to_<name>(data, gpu_id)`` looks up the attribute of the same
    name on ``source.converter`` and calls it with ``data`` and ``gpu_id``
    passed through unchanged; its return value is returned as-is.
    """

    def make_method(target, name):
        # A factory is used so each generated method closes over its own
        # ``name`` instead of sharing the loop variable.
        def method(self, data, gpu_id):
            return getattr(self.converter, name)(data, gpu_id)

        # Give the generated function its real identity; otherwise every
        # to_X() shows up as "method" in tracebacks and help().
        method.__name__ = name
        method.__qualname__ = f"{MemoryType.__qualname__}.{name}"
        method.__doc__ = (
            f"Convert ``data`` to {target.value} by calling ``{name}`` "
            "on this memory type's converter."
        )
        return method

    for target_type in MemoryType:
        method_name = f"to_{target_type.value}"
        setattr(MemoryType, method_name, make_method(target_type, method_name))


# Run once at import time so the to_X() methods exist as soon as the
# module is loaded.
_add_conversion_methods()

49 

50 

# Memory type sets
# NumPy is the only member grouped as CPU; every other member is grouped
# as GPU by this module.
CPU_MEMORY_TYPES: set[MemoryType] = {MemoryType.NUMPY}
GPU_MEMORY_TYPES: set[MemoryType] = {
    MemoryType.CUPY,
    MemoryType.TORCH,
    MemoryType.TENSORFLOW,
    MemoryType.JAX,
    MemoryType.PYCLESPERANTO,
}
SUPPORTED_MEMORY_TYPES: set[MemoryType] = CPU_MEMORY_TYPES | GPU_MEMORY_TYPES

# String value sets for validation
VALID_MEMORY_TYPES: set[str] = {mt.value for mt in MemoryType}
VALID_GPU_MEMORY_TYPES: set[str] = {mt.value for mt in GPU_MEMORY_TYPES}

# Memory type constants for direct access
MEMORY_TYPE_NUMPY: str = MemoryType.NUMPY.value
MEMORY_TYPE_CUPY: str = MemoryType.CUPY.value
MEMORY_TYPE_TORCH: str = MemoryType.TORCH.value
MEMORY_TYPE_TENSORFLOW: str = MemoryType.TENSORFLOW.value
MEMORY_TYPE_JAX: str = MemoryType.JAX.value
MEMORY_TYPE_PYCLESPERANTO: str = MemoryType.PYCLESPERANTO.value

72MEMORY_TYPE_PYCLESPERANTO = MemoryType.PYCLESPERANTO.value