Coverage for src/arraybridge/types.py: 97%
35 statements
coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
"""
Memory type definitions for arraybridge.

This module defines the MemoryType enum and related constants that identify
the array/tensor frameworks arraybridge supports.
"""
8from enum import Enum
9from typing import Any, Callable, TypeVar
# Alias for any one-argument callable: Callable[[Any], Any].
ConversionFunc = Callable[[Any], Any]

# Generic type variable.
# NOTE(review): T is not referenced in this module as shown; presumably it is
# used or re-exported elsewhere — confirm before removing.
T = TypeVar("T")
class MemoryType(Enum):
    """Identifier for each array/tensor framework handled by arraybridge.

    Every member's value is the lowercase framework name (for example
    ``"numpy"``); that same string is the key passed to the converter
    registry by :attr:`converter`.
    """

    NUMPY = "numpy"
    CUPY = "cupy"
    TORCH = "torch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
    PYCLESPERANTO = "pyclesperanto"

    @property
    def converter(self):
        """Look up and return the converter registered under this member's value."""
        # Imported at access time instead of at module load.
        # NOTE(review): presumably this sidesteps a circular import between this
        # module and the converter registry — confirm.
        from arraybridge.converters_registry import get_converter as lookup_converter

        return lookup_converter(self.value)
# Auto-generate to_X() methods on enum
def _add_conversion_methods():
    """Attach a ``to_<value>()`` method to MemoryType for every enum member.

    For each member ``target``, ``MemoryType.to_<target.value>(data, gpu_id)``
    is defined to delegate to the same-named method on ``self.converter``;
    e.g. ``MemoryType.NUMPY.to_torch(x, 0)`` calls
    ``MemoryType.NUMPY.converter.to_torch(x, 0)``.
    """

    def make_method(target):
        # A factory binds ``target`` per call, avoiding the late-binding closure
        # pitfall a function defined directly in the loop body would have.
        method_name = f"to_{target.value}"

        def method(self, data, gpu_id):
            return getattr(self.converter, method_name)(data, gpu_id)

        # Give each generated function a real identity so tracebacks, repr()
        # and help(MemoryType) show e.g. "MemoryType.to_numpy", not "method".
        method.__name__ = method_name
        method.__qualname__ = f"{MemoryType.__qualname__}.{method_name}"
        method.__doc__ = (
            f"Delegate to ``self.converter.{method_name}(data, gpu_id)``."
        )
        return method

    for target_type in MemoryType:
        setattr(MemoryType, f"to_{target_type.value}", make_method(target_type))


_add_conversion_methods()
# Memory type sets: frameworks grouped by the CPU/GPU classification declared here.
CPU_MEMORY_TYPES: set[MemoryType] = {MemoryType.NUMPY}
GPU_MEMORY_TYPES: set[MemoryType] = {
    MemoryType.CUPY,
    MemoryType.TORCH,
    MemoryType.TENSORFLOW,
    MemoryType.JAX,
    MemoryType.PYCLESPERANTO,
}
# Union of both groups above.
SUPPORTED_MEMORY_TYPES: set[MemoryType] = CPU_MEMORY_TYPES.union(GPU_MEMORY_TYPES)

# String value sets for validation (plain strings rather than enum members).
VALID_MEMORY_TYPES = {member.value for member in MemoryType}
VALID_GPU_MEMORY_TYPES = {gpu_member.value for gpu_member in GPU_MEMORY_TYPES}
# Memory type constants for direct access: the string value of each
# MemoryType member, e.g. MEMORY_TYPE_NUMPY == MemoryType.NUMPY.value.
(
    MEMORY_TYPE_NUMPY,
    MEMORY_TYPE_CUPY,
    MEMORY_TYPE_TORCH,
    MEMORY_TYPE_TENSORFLOW,
    MEMORY_TYPE_JAX,
    MEMORY_TYPE_PYCLESPERANTO,
) = (
    MemoryType.NUMPY.value,
    MemoryType.CUPY.value,
    MemoryType.TORCH.value,
    MemoryType.TENSORFLOW.value,
    MemoryType.JAX.value,
    MemoryType.PYCLESPERANTO.value,
)