Coverage for openhcs/core/memory/gpu_utils.py: 36.4% (70 statements)
coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2GPU utility functions for OpenHCS.
4This module provides utility functions for checking GPU availability
5across different frameworks (cupy, torch, tensorflow, jax).
7Doctrinal Clauses:
8- Clause 88 — No Inferred Capabilities
9- Clause 293 — GPU Pre-Declaration Enforcement
10"""
12import logging
13from typing import Optional
15from openhcs.core.utils import optional_import
17logger = logging.getLogger(__name__)
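
# `optional_import` comes from openhcs.core.utils and its implementation is not
# part of this report. Judging from how it is used below (it returns the module
# object, or None when the package is missing), a minimal sketch could look like
# the following -- a hypothetical illustration, not the actual OpenHCS code:
#
#     import importlib
#
#     def optional_import(name: str):
#         try:
#             return importlib.import_module(name)
#         except ImportError:
#             return None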

def check_cupy_gpu_available() -> Optional[int]:
    """
    Check if cupy is available and can access a GPU.

    Returns:
        GPU device ID if available, None otherwise
    """
    cp = optional_import("cupy")
    if cp is None:  # coverage: condition was never true in the recorded test run
        logger.debug("Cupy not installed")
        return None

    try:
        # Check if cupy is available and can access a GPU
        if cp.cuda.is_available():  # coverage: this call always raised an exception in the recorded test run
            # Get the current device ID
            device_id = cp.cuda.get_device_id()
            logger.debug("Cupy GPU available: device_id=%s", device_id)
            return device_id
        else:
            logger.debug("Cupy CUDA not available")
            return None
    except Exception as e:
        logger.debug("Error checking cupy GPU availability: %s", e)
        return None
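
# Example (not part of the original module): a caller could pin the detected
# device before allocating CuPy arrays. `cupy.cuda.Device` used as a context
# manager is standard CuPy API; the helper name below is illustrative only.
def _example_allocate_on_detected_cupy_gpu() -> None:
    device_id = check_cupy_gpu_available()
    cp = optional_import("cupy")
    if cp is None or device_id is None:
        return
    with cp.cuda.Device(device_id):
        buf = cp.zeros((1024, 1024), dtype=cp.float32)
        logger.debug("Allocated %s array on GPU %s", buf.shape, device_id)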

def check_torch_gpu_available() -> Optional[int]:
    """
    Check if torch is available and can access a GPU.

    Returns:
        GPU device ID if available, None otherwise
    """
    torch = optional_import("torch")
    if torch is None:  # coverage: condition was never true in the recorded test run
        logger.debug("Torch not installed")
        return None

    try:
        # Check if torch is available and can access a GPU
        if torch.cuda.is_available():  # coverage: this call always raised an exception in the recorded test run
            # Get the current device ID
            device_id = torch.cuda.current_device()
            logger.debug("Torch GPU available: device_id=%s", device_id)
            return device_id
        else:
            logger.debug("Torch CUDA not available")
            return None
    except Exception as e:
        logger.debug("Error checking torch GPU availability: %s", e)
        return None
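
# Example (not part of the original module): the returned ID can be passed
# straight to torch.device to place tensors. torch.device, torch.randn and
# Tensor.to are standard PyTorch APIs; the helper name below is illustrative.
def _example_move_tensor_to_detected_torch_gpu() -> None:
    device_id = check_torch_gpu_available()
    torch = optional_import("torch")
    if torch is None or device_id is None:
        return
    device = torch.device(f"cuda:{device_id}")
    x = torch.randn(8, 8).to(device)
    logger.debug("Tensor placed on %s", x.device)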

def check_tf_gpu_available() -> Optional[int]:
    """
    Check if tensorflow is available and can access a GPU.

    Returns:
        GPU device ID if available, None otherwise
    """
    tf = optional_import("tensorflow")
    if tf is None:  # coverage: condition was never true in the recorded test run
        logger.debug("TensorFlow not installed")
        return None

    try:
        # Check if tensorflow is available and can access a GPU
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            # Get the first GPU device ID
            # TensorFlow doesn't have a direct way to get the CUDA device ID,
            # so we'll just use the index in the list
            device_id = 0
            logger.debug("TensorFlow GPU available: device_id=%s", device_id)
            return device_id
        else:
            logger.debug("TensorFlow GPU not available")
            return None
    except Exception as e:
        logger.debug("Error checking TensorFlow GPU availability: %s", e)
        return None
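
# Example (not part of the original module): the index returned above lines up
# with tf.config.list_physical_devices('GPU'), so a caller could restrict
# TensorFlow to that device. tf.config.set_visible_devices is standard TF API
# and must run before the GPUs are initialized; the helper name is illustrative.
def _example_restrict_tf_to_detected_gpu() -> None:
    device_id = check_tf_gpu_available()
    tf = optional_import("tensorflow")
    if tf is None or device_id is None:
        return
    gpus = tf.config.list_physical_devices('GPU')
    tf.config.set_visible_devices(gpus[device_id], 'GPU')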

def check_jax_gpu_available() -> Optional[int]:
    """
    Check if JAX is available and can access a GPU.

    Returns:
        GPU device ID if available, None otherwise
    """
    jax = optional_import("jax")
    if jax is None:
        logger.debug("JAX not installed")
        return None

    try:
        # Check if JAX is available and can access a GPU
        devices = jax.devices()
        gpu_devices = [d for d in devices if d.platform == 'gpu']

        if gpu_devices:
            # Get the first GPU device ID
            # JAX device IDs are typically in the form 'gpu:0'
            device_str = str(gpu_devices[0])
            if ':' in device_str:
                device_id = int(device_str.split(':')[-1])
            else:
                # Default to 0 if we can't parse the device ID
                device_id = 0
            logger.debug("JAX GPU available: device_id=%s", device_id)
            return device_id
        else:
            logger.debug("JAX GPU not available")
            return None
    except Exception as e:
        logger.debug("Error checking JAX GPU availability: %s", e)
        return None
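
# Example (not part of the original module): a caller that only needs "any GPU"
# could probe the frameworks in a fixed order and take the first hit. The
# function name and ordering are illustrative, not an OpenHCS API.
def _example_first_available_gpu() -> Optional[int]:
    for check in (
        check_cupy_gpu_available,
        check_torch_gpu_available,
        check_tf_gpu_available,
        check_jax_gpu_available,
    ):
        device_id = check()
        if device_id is not None:
            return device_id
    return None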