Coverage for openhcs/core/memory/gpu_utils.py: 36.4%

70 statements  

Report generated by coverage.py v7.10.3 at 2025-08-14 05:57 +0000 (navigation: « prev · ^ index · » next)

1""" 

2GPU utility functions for OpenHCS. 

3 

4This module provides utility functions for checking GPU availability 

5across different frameworks (cupy, torch, tensorflow, jax). 

6 

7Doctrinal Clauses: 

8- Clause 88 — No Inferred Capabilities 

9- Clause 293 — GPU Pre-Declaration Enforcement 

10""" 

11 

12import logging 

13from typing import Optional 

14 

15from openhcs.core.utils import optional_import 

16 

17logger = logging.getLogger(__name__) 

18 

19 

def check_cupy_gpu_available() -> Optional[int]:
    """
    Probe whether CuPy is importable and can reach a CUDA device.

    Any exception raised while probing is logged at DEBUG level and
    treated as "no GPU".

    Returns:
        The ID of CuPy's current CUDA device, or None when CuPy is not
        installed, CUDA is reported unavailable, or the probe raises.
    """
    cupy_module = optional_import("cupy")
    # coverage.py: branch 28 -> 29 never taken (condition on line 28 was never true)
    if cupy_module is None:
        logger.debug("Cupy not installed")
        return None

    try:
        # coverage.py: line 34 always raised an exception in the recorded run
        if not cupy_module.cuda.is_available():
            logger.debug("Cupy CUDA not available")
            return None
        current_id = cupy_module.cuda.get_device_id()
        logger.debug("Cupy GPU available: device_id=%s", current_id)
        return current_id
    except Exception as err:
        # Best-effort probe: any failure simply means "no usable GPU".
        logger.debug("Error checking cupy GPU availability: %s", err)
        return None

45 

46 

def check_torch_gpu_available() -> Optional[int]:
    """
    Probe whether PyTorch is importable and can reach a CUDA device.

    Any exception raised while probing is logged at DEBUG level and
    treated as "no GPU".

    Returns:
        The index returned by ``torch.cuda.current_device()``, or None when
        torch is not installed, CUDA is reported unavailable, or the probe
        raises.
    """
    torch_module = optional_import("torch")
    # coverage.py: branch 55 -> 56 never taken (condition on line 55 was never true)
    if torch_module is None:
        logger.debug("Torch not installed")
        return None

    try:
        # coverage.py: line 61 always raised an exception in the recorded run
        if not torch_module.cuda.is_available():
            logger.debug("Torch CUDA not available")
            return None
        active_index = torch_module.cuda.current_device()
        logger.debug("Torch GPU available: device_id=%s", active_index)
        return active_index
    except Exception as err:
        # Best-effort probe: any failure simply means "no usable GPU".
        logger.debug("Error checking torch GPU availability: %s", err)
        return None

72 

73 

def check_tf_gpu_available() -> Optional[int]:
    """
    Probe whether TensorFlow is importable and lists at least one GPU.

    Any exception raised while probing is logged at DEBUG level and
    treated as "no GPU".

    Returns:
        0 (the position of the first listed GPU) when TensorFlow reports any
        physical GPU; None when TensorFlow is not installed, lists no GPUs,
        or the probe raises.
    """
    tensorflow = optional_import("tensorflow")
    # coverage.py: branch 82 -> 83 never taken (condition on line 82 was never true)
    if tensorflow is None:
        logger.debug("TensorFlow not installed")
        return None

    try:
        physical_gpus = tensorflow.config.list_physical_devices('GPU')
        if not physical_gpus:
            logger.debug("TensorFlow GPU not available")
            return None
        # This call gives no CUDA ordinal, so the first listed GPU's
        # position in the list (always 0) is reported instead.
        first_position = 0
        logger.debug("TensorFlow GPU available: device_id=%s", first_position)
        return first_position
    except Exception as err:
        # Best-effort probe: any failure simply means "no usable GPU".
        logger.debug("Error checking TensorFlow GPU availability: %s", err)
        return None

102 

103 

def check_jax_gpu_available() -> Optional[int]:
    """
    Check if JAX is available and can access a GPU.

    Any exception raised while probing is logged at DEBUG level and
    treated as "no GPU".

    Returns:
        The ID of the first JAX device whose ``platform`` is ``'gpu'``, or
        None when JAX is not installed, no GPU device is listed, or the
        probe raises.
    """
    jax = optional_import("jax")
    if jax is None:
        logger.debug("JAX not installed")
        return None

    try:
        gpu_devices = [d for d in jax.devices() if d.platform == 'gpu']
        if not gpu_devices:
            logger.debug("JAX GPU not available")
            return None
        device_id = _jax_device_id(gpu_devices[0])
        logger.debug("JAX GPU available: device_id=%s", device_id)
        return device_id
    except Exception as e:
        # Best-effort probe: any failure simply means "no usable GPU".
        logger.debug("Error checking JAX GPU availability: %s", e)
        return None


def _jax_device_id(device) -> int:
    """Return the integer ID of a JAX device, or 0 if it cannot be determined."""
    # Prefer the device's ``id`` attribute over string parsing. The textual
    # form (e.g. 'gpu:0' vs 'cuda:0') is not guaranteed to be stable.
    device_id = getattr(device, "id", None)
    if isinstance(device_id, int):
        return device_id

    # Fallback: parse a trailing ':<n>' suffix. A non-numeric suffix must not
    # raise, because that would wrongly report "no GPU" to the caller.
    _, sep, suffix = str(device).rpartition(':')
    if sep:
        try:
            return int(suffix)
        except ValueError:
            pass
    return 0