Coverage for src/arraybridge/types.py: 100%
22 statements
coverage.py v7.13.1, created at 2026-01-10 22:33 +0000
"""
Memory type definitions for arraybridge.

This module defines the MemoryType enum and related constants for managing
different array/tensor frameworks.
"""
8from enum import Enum
9from typing import Any, Callable, TypeVar
# Generic type variable for generic signatures. It is not referenced in the
# rest of this module; presumably it is imported by other arraybridge modules
# (verify before removing).
T = TypeVar("T")

# Signature shared by conversion callables: a single positional argument of
# any type, returning a value of any type.
ConversionFunc = Callable[[Any], Any]
class MemoryType(Enum):
    """Identifier for each array/tensor framework known to arraybridge.

    Each member's ``value`` is the lowercase framework name as a string, so a
    member can be looked up from its string form, e.g.
    ``MemoryType("torch") is MemoryType.TORCH``.
    """

    NUMPY = "numpy"
    CUPY = "cupy"
    TORCH = "torch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
    PYCLESPERANTO = "pyclesperanto"
# Memory type sets.
# NumPy is the only member classified as CPU memory; every other framework is
# grouped under GPU_MEMORY_TYPES.
CPU_MEMORY_TYPES: set[MemoryType] = {MemoryType.NUMPY}
GPU_MEMORY_TYPES: set[MemoryType] = {
    MemoryType.CUPY,
    MemoryType.TORCH,
    MemoryType.TENSORFLOW,
    MemoryType.JAX,
    MemoryType.PYCLESPERANTO,
}
# Union of the two sets above. `|` builds a new set, so mutating this one does
# not affect CPU_MEMORY_TYPES or GPU_MEMORY_TYPES.
SUPPORTED_MEMORY_TYPES: set[MemoryType] = CPU_MEMORY_TYPES | GPU_MEMORY_TYPES

# String value sets for validation: membership tests against plain strings
# rather than MemoryType members.
# VALID_MEMORY_TYPES is derived from every MemoryType member, so it stays in
# sync automatically when a member is added to the enum.
VALID_MEMORY_TYPES: set[str] = {mt.value for mt in MemoryType}
VALID_GPU_MEMORY_TYPES: set[str] = {mt.value for mt in GPU_MEMORY_TYPES}

# Memory type constants for direct access as plain strings (each equals the
# corresponding member's ``value``).
MEMORY_TYPE_NUMPY = MemoryType.NUMPY.value
MEMORY_TYPE_CUPY = MemoryType.CUPY.value
MEMORY_TYPE_TORCH = MemoryType.TORCH.value
MEMORY_TYPE_TENSORFLOW = MemoryType.TENSORFLOW.value
MEMORY_TYPE_JAX = MemoryType.JAX.value
MEMORY_TYPE_PYCLESPERANTO = MemoryType.PYCLESPERANTO.value