Coverage for src/arraybridge/types.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 22:33 +0000

1""" 

2Memory type definitions for arraybridge. 

3 

4This module defines the MemoryType enum and related constants for managing 

5different array/tensor frameworks. 

6""" 

7 

8from enum import Enum 

9from typing import Any, Callable, TypeVar 

10 

# Alias for a single-argument callable: any value in, any value out.
# Presumably the signature of the framework conversion helpers — confirm at
# the call sites, since nothing in this module applies it.
ConversionFunc = Callable[[Any], Any]

# Generic type variable. It is not referenced within this module; presumably
# it is imported by other arraybridge modules for generic annotations (verify).
T = TypeVar("T")

13 

14 

class MemoryType(Enum):
    """Identify which array/tensor framework a piece of memory belongs to.

    Every member's value is the framework's lowercase name, so a member can be
    recovered from its string form, e.g. ``MemoryType("torch")``.
    """

    # Grouped on its own in CPU_MEMORY_TYPES.
    NUMPY = "numpy"

    # Grouped together in GPU_MEMORY_TYPES.
    CUPY = "cupy"
    TORCH = "torch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
    PYCLESPERANTO = "pyclesperanto"

24 

25 

# --- Enum-level groupings ----------------------------------------------------
# NUMPY is the only member placed in the CPU group; every other member is
# placed in the GPU group.
# NOTE(review): these sets are mutable module-level objects shared by every
# importer; if no caller mutates them, frozenset would guard against
# accidental edits.
CPU_MEMORY_TYPES: set[MemoryType] = {MemoryType.NUMPY}
GPU_MEMORY_TYPES: set[MemoryType] = {
    MemoryType.CUPY,
    MemoryType.TORCH,
    MemoryType.TENSORFLOW,
    MemoryType.JAX,
    MemoryType.PYCLESPERANTO,
}
# A fresh set holding both groups (not an alias of either input set).
SUPPORTED_MEMORY_TYPES: set[MemoryType] = CPU_MEMORY_TYPES.union(GPU_MEMORY_TYPES)

# --- String views for validating string inputs ------------------------------
# Built from the enum itself rather than from SUPPORTED_MEMORY_TYPES, so it
# lists every member's value regardless of CPU/GPU grouping.
VALID_MEMORY_TYPES: set[str] = {member.value for member in MemoryType}
VALID_GPU_MEMORY_TYPES: set[str] = {member.value for member in GPU_MEMORY_TYPES}

# --- Plain-string aliases of each member's value -----------------------------
MEMORY_TYPE_NUMPY: str = MemoryType.NUMPY.value
MEMORY_TYPE_CUPY: str = MemoryType.CUPY.value
MEMORY_TYPE_TORCH: str = MemoryType.TORCH.value
MEMORY_TYPE_TENSORFLOW: str = MemoryType.TENSORFLOW.value
MEMORY_TYPE_JAX: str = MemoryType.JAX.value
MEMORY_TYPE_PYCLESPERANTO: str = MemoryType.PYCLESPERANTO.value