Coverage for openhcs/core/lazy_gpu_imports.py: 65.5%

65 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Lazy GPU import system. 

3 

4Defers GPU library imports until first use to eliminate startup delay. 

5Supports fast installation checking without imports. 

6""" 

7 

8import importlib 

9import importlib.util 

10import logging 

11import threading 

12from typing import Any, Dict, Optional, Tuple, Callable 

13 

logger = logging.getLogger(name=__name__)  # module-scoped logger

15 

16 

17# GPU check functions - explicit, fail-loud implementations 

def _check_cuda_available(lib) -> bool:
    """Return what ``lib.cuda.is_available()`` reports.

    Registered as the GPU probe for both torch and cupy, which share this
    ``cuda.is_available()`` call shape.
    """
    cuda_namespace = lib.cuda
    return cuda_namespace.is_available()

21 

22 

def _check_jax_gpu(lib) -> bool:
    """Treat an importable JAX as GPU-capable without probing devices.

    The module argument is deliberately ignored. Calling ``jax.devices()``
    here can create 54+ threads at startup, so the real GPU check is
    deferred to runtime and this probe unconditionally answers True.
    """
    del lib  # intentionally unused; see docstring
    return True

33 

34 

def _check_tf_gpu(lib) -> bool:
    """Return True when TensorFlow lists at least one physical 'GPU' device."""
    return len(lib.config.list_physical_devices('GPU')) != 0

39 

40 

# GPU library registry.
# Each value is a tuple (module_name, submodule, gpu_check_func, get_device_id_func):
#   module_name        - top-level package imported lazily by the proxies
#   submodule          - dotted attribute path read from that package after
#                        import, or None to use the package itself
#   gpu_check_func     - callable(lib) -> bool probing GPU availability, or None
#   get_device_id_func - callable(lib) returning the device id, or None
GPU_LIBRARY_REGISTRY: Dict[
    str, Tuple[str, Optional[str], Optional[Callable], Optional[Callable]]
] = dict(
    torch=('torch', None, _check_cuda_available,
           lambda lib: lib.cuda.current_device()),
    cupy=('cupy', None, _check_cuda_available,
          lambda lib: lib.cuda.get_device_id()),
    # No device query for jax/tensorflow: device id 0 is always reported.
    jax=('jax', None, _check_jax_gpu, lambda lib: 0),
    tensorflow=('tensorflow', None, _check_tf_gpu, lambda lib: 0),
    # Entries without a GPU probe (check_gpu_capability returns None for them).
    jnp=('jax', 'numpy', None, None),
    pyclesperanto=('pyclesperanto', None, None, None),
)

51 

52 

class _LazyGPUModule:
    """Lazy proxy for a GPU library: the real import happens on first use.

    Construction only runs ``importlib.util.find_spec`` (no import). The
    module is imported the first time an attribute is read or the proxy is
    tested for truthiness, and the result is cached on the proxy.
    """

    # Attributes owned by the proxy itself. ``__getattr__`` only runs when
    # normal lookup fails, so reaching it for one of these names means
    # ``__init__`` never ran on this instance (e.g. ``copy.copy`` / ``pickle``
    # build instances via ``__new__``). Forwarding such a lookup would call
    # ``_ensure_imported``, which reads ``self._imported`` and re-enters
    # ``__getattr__`` forever.
    _OWN_ATTRS = frozenset({
        '_name', '_module_name', '_submodule', '_module',
        '_lock', '_imported', '_installed',
    })

    def __init__(self, name: str):
        """Create the proxy for registry entry *name* (KeyError if unknown)."""
        self._name = name
        module_name, submodule, _, _ = GPU_LIBRARY_REGISTRY[name]
        self._module_name = module_name
        self._submodule = submodule
        self._module = None
        self._lock = threading.Lock()
        self._imported = False

        # Fast installation check (no import)
        self._installed = importlib.util.find_spec(module_name) is not None

    def is_installed(self) -> bool:
        """Check if installed without importing."""
        return self._installed

    def _ensure_imported(self) -> Any:
        """
        Import the wrapped module on first call and return it (thread-safe).

        Returns None when the library is not installed. The check / lock /
        re-check sequence ensures concurrent first accesses import only once.

        FAIL LOUD: No try-except. Import errors (and AttributeError from a
        missing submodule attribute) propagate; ``_imported`` stays False in
        that case, so the next call retries.
        """
        if not self._imported:
            with self._lock:
                if not self._imported:
                    if not self._installed:
                        # Not installed - return None (expected case)
                        self._imported = True
                        return None

                    # Import the module - FAIL LOUD if import fails
                    module = importlib.import_module(self._module_name)
                    logger.debug("Lazy-imported %s", self._module_name)

                    # Navigate to submodule if specified.
                    # FAIL LOUD: getattr raises AttributeError if missing.
                    if self._submodule:
                        for attr in self._submodule.split('.'):
                            module = getattr(module, attr)

                    # Publish only the fully resolved object, so a failed
                    # submodule lookup never leaves a half-resolved module
                    # cached in self._module.
                    self._module = module
                    self._imported = True

        return self._module

    def __getattr__(self, name: str) -> Any:
        """
        Lazy import on attribute access.

        FAIL LOUD: Raises ImportError if not installed, AttributeError if
        the attribute is missing on the imported module (or if this proxy's
        own internal state is missing because ``__init__`` did not run).
        """
        if name in self._OWN_ATTRS:
            # Internal state absent: fail like a normal missing attribute
            # instead of recursing through _ensure_imported.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        module = self._ensure_imported()
        if module is None:
            raise ImportError(
                f"Module '{self._module_name}' is not installed. "
                f"Install it to use {self._name}.{name}"
            )
        # FAIL LOUD: getattr raises AttributeError if name doesn't exist
        return getattr(module, name)

    def __bool__(self) -> bool:
        """
        Allow truthiness checks.

        Returns False if not installed, True if installed and imports successfully.
        FAIL LOUD: Propagates import errors.
        """
        module = self._ensure_imported()
        return module is not None

124 

125 

# Auto-generate one lazy proxy per registry entry as a module-level name
# (torch, cupy, jax, tensorflow, jnp, pyclesperanto). A comprehension is
# used so no loop variable is left behind in the module namespace.
globals().update({
    lib_name: _LazyGPUModule(lib_name) for lib_name in GPU_LIBRARY_REGISTRY
})

# Alias tf -> tensorflow for compatibility (the same proxy object)
tf = globals()['tensorflow']

132 

133 

def check_installed_gpu_libraries() -> Dict[str, bool]:
    """
    Map each registry name to whether its top-level module is installed.

    Nothing is imported: each entry's module name is only looked up with
    ``importlib.util.find_spec``. Entries that share a module (``jax`` and
    ``jnp`` both use 'jax') therefore always report the same value.
    """
    installed: Dict[str, bool] = {}
    for lib_name, registry_entry in GPU_LIBRARY_REGISTRY.items():
        top_level_module = registry_entry[0]
        spec = importlib.util.find_spec(top_level_module)
        installed[lib_name] = spec is not None
    return installed

144 

145 

def check_gpu_capability(library_name: str) -> Optional[int]:
    """
    Return the GPU device id a registered library would use, if any.

    The library's lazy proxy is imported as a side effect (via its
    truthiness check). FAIL LOUD: import errors and errors raised by the
    registry's probe/device-id callables propagate unchanged.

    Args:
        library_name: Name from GPU_LIBRARY_REGISTRY

    Returns:
        The value of the entry's device-id callable when its GPU probe
        succeeds; None when the entry has no probe, the library is not
        installed, or the probe reports no GPU.

    Raises:
        ValueError: If *library_name* is not a registry key.
    """
    registry_entry = GPU_LIBRARY_REGISTRY.get(library_name)
    if registry_entry is None:
        raise ValueError(f"Unknown GPU library: {library_name}")

    gpu_check, get_device_id = registry_entry[2], registry_entry[3]
    if gpu_check is None:
        # Entry defines no GPU probe (e.g. jnp, pyclesperanto).
        return None

    lib = globals()[library_name]
    # Truthiness triggers the lazy import; it is False only when the
    # library is not installed (expected, not an error).
    if not lib:
        return None

    return get_device_id(lib) if gpu_check(lib) else None

180