Coverage for openhcs/core/lazy_gpu_imports.py: 65.5%
65 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2Lazy GPU import system.
4Defers GPU library imports until first use to eliminate startup delay.
5Supports fast installation checking without imports.
6"""
8import importlib
9import importlib.util
10import logging
11import threading
12from typing import Any, Dict, Optional, Tuple, Callable
14logger = logging.getLogger(__name__)
17# GPU check functions - explicit, fail-loud implementations
def _check_cuda_available(lib) -> bool:
    """Ask a torch/cupy-style library whether CUDA is usable.

    Args:
        lib: Module-like object exposing ``cuda.is_available()``.

    Returns:
        Whatever ``lib.cuda.is_available()`` returns, passed through as-is.

    Raises:
        AttributeError: If ``lib`` has no ``cuda.is_available`` (not caught).
    """
    cuda_namespace = lib.cuda
    return cuda_namespace.is_available()
def _check_jax_gpu(lib) -> bool:
    """Report JAX as GPU-capable without probing any device.

    This check is deliberately optimistic: it never touches ``lib`` and
    always answers True, leaving the real device query to runtime code.

    Args:
        lib: The JAX module (or proxy). Accepted for signature parity with
            the other check functions; it is not inspected.

    Returns:
        Always True.
    """
    # NOTE: per the original author, calling jax.devices() here spawns a
    # large number of threads during startup, so it is intentionally skipped.
    return True
def _check_tf_gpu(lib) -> bool:
    """Tell whether TensorFlow lists at least one physical GPU device.

    Args:
        lib: Module-like object exposing
            ``config.list_physical_devices(kind)``.

    Returns:
        True when the 'GPU' device list is non-empty, False otherwise.

    Raises:
        AttributeError: If ``lib`` lacks the expected ``config`` API.
    """
    return len(lib.config.list_physical_devices('GPU')) > 0
# Registry of supported GPU libraries, keyed by the name the lazy proxy is
# published under.
# Each value is a 4-tuple:
#   (module_name, submodule, gpu_check_func, get_device_id_func)
# - module_name: top-level module handed to importlib.
# - submodule: dotted attribute path below that module, or None.
# - gpu_check_func / get_device_id_func: callables taking the library object;
#   both are None for entries that have no GPU check.
GPU_LIBRARY_REGISTRY: Dict[
    str, Tuple[str, Optional[str], Optional[Callable], Optional[Callable]]
] = {
    'torch': (
        'torch',
        None,
        _check_cuda_available,
        lambda lib: lib.cuda.current_device(),
    ),
    'cupy': (
        'cupy',
        None,
        _check_cuda_available,
        lambda lib: lib.cuda.get_device_id(),
    ),
    # jax and tensorflow always report device 0.
    'jax': (
        'jax',
        None,
        _check_jax_gpu,
        lambda lib: 0,
    ),
    'tensorflow': (
        'tensorflow',
        None,
        _check_tf_gpu,
        lambda lib: 0,
    ),
    # Import-only entries: no GPU check, no device id lookup.
    'jnp': (
        'jax',
        'numpy',
        None,
        None,
    ),
    'pyclesperanto': (
        'pyclesperanto',
        None,
        None,
        None,
    ),
}
class _LazyGPUModule:
    """Lazy proxy for a GPU library: the real import happens on first use.

    Constructing the proxy only runs ``importlib.util.find_spec`` on the
    registry's module name. The module itself is imported the first time an
    attribute is read through the proxy or the proxy is used in a truthiness
    check.
    """

    # Names of the proxy's own instance attributes. After ``__init__`` they
    # live in the instance ``__dict__`` and never reach ``__getattr__``; they
    # only arrive there on an instance that was created without ``__init__``
    # (e.g. via ``__new__`` during copying). Answering with AttributeError in
    # that case prevents ``__getattr__`` -> ``_ensure_imported`` ->
    # ``self._imported`` -> ``__getattr__`` from recursing without bound.
    _OWN_ATTRS = frozenset({
        '_name',
        '_module_name',
        '_submodule',
        '_module',
        '_lock',
        '_imported',
        '_installed',
    })

    def __init__(self, name: str):
        """
        Set up the proxy for one registry entry.

        Args:
            name: Key in GPU_LIBRARY_REGISTRY.

        Raises:
            KeyError: If ``name`` is not a registry key.
        """
        module_name, submodule, _, _ = GPU_LIBRARY_REGISTRY[name]
        self._name = name
        self._module_name = module_name
        self._submodule = submodule
        self._module = None
        self._lock = threading.Lock()
        self._imported = False

        # Fast installation check (no import)
        self._installed = importlib.util.find_spec(module_name) is not None

    def is_installed(self) -> bool:
        """Check if installed without importing."""
        return self._installed

    def _ensure_imported(self) -> Any:
        """
        Import the module if that has not been attempted yet.

        The import runs under ``self._lock`` and the ``_imported`` flag is
        re-checked after acquiring it, so concurrent callers import once.

        FAIL LOUD: No try-except. Import errors and missing submodule
        attributes propagate; ``_imported`` then stays False so a later call
        retries.

        Returns:
            The imported module (or the object reached by following the
            submodule path), or None when the library is not installed.
        """
        if not self._imported:
            with self._lock:
                if not self._imported:
                    if self._installed:
                        # Resolve into a local first so self._module is only
                        # ever None or the fully navigated target.
                        module = importlib.import_module(self._module_name)
                        logger.debug("Lazy-imported %s", self._module_name)

                        # Navigate to submodule if specified
                        if self._submodule:
                            for attr in self._submodule.split('.'):
                                module = getattr(module, attr)

                        self._module = module
                    # Not installed is an expected case: remember the attempt
                    # and leave self._module as None.
                    self._imported = True

        return self._module

    def __getattr__(self, name: str) -> Any:
        """
        Lazy import on attribute access.

        Only called when normal attribute lookup on the proxy fails.

        FAIL LOUD: Raises ImportError if not installed, AttributeError if
        the attribute is missing on the real module.

        Args:
            name: Attribute to fetch from the underlying module.

        Returns:
            ``getattr(module, name)`` on the lazily imported module.
        """
        if name in self._OWN_ATTRS:
            # Proxy state missing from the instance: the object was never
            # initialised. Do not try to import (that would recurse).
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        module = self._ensure_imported()
        if module is None:
            raise ImportError(
                f"Module '{self._module_name}' is not installed. "
                f"Install it to use {self._name}.{name}"
            )
        # FAIL LOUD: getattr raises AttributeError if name doesn't exist
        return getattr(module, name)

    def __bool__(self) -> bool:
        """
        Allow truthiness checks.

        Returns False if not installed, True if installed and the import
        succeeds. FAIL LOUD: Propagates import errors.
        """
        return self._ensure_imported() is not None
# Publish one lazy proxy per registry entry as a module-level name
# (torch, cupy, jax, ...). Constructing a proxy only performs a find_spec
# lookup; the library itself is not imported here.
for _name in GPU_LIBRARY_REGISTRY:
    globals()[_name] = _LazyGPUModule(_name)

# Compatibility alias: `tf` is the very same proxy object as `tensorflow`.
tf = globals()['tensorflow']
def check_installed_gpu_libraries() -> Dict[str, bool]:
    """
    Report, for every registry entry, whether its module can be located.

    Uses ``importlib.util.find_spec`` on each entry's module name, so the
    libraries themselves are not imported.

    Returns:
        Mapping of registry name -> True if a module spec was found.
    """
    installed: Dict[str, bool] = {}
    for name, entry in GPU_LIBRARY_REGISTRY.items():
        module_name, _, _, _ = entry
        spec = importlib.util.find_spec(module_name)
        installed[name] = spec is not None
    return installed
def check_gpu_capability(library_name: str) -> Optional[int]:
    """
    Look up the GPU device id for a registered library (lazy import).

    FAIL LOUD: import errors and errors raised by the registry's check or
    device-id functions are not caught.

    Args:
        library_name: Name from GPU_LIBRARY_REGISTRY

    Returns:
        The value of the entry's device-id function when its GPU check
        passes. None when the entry defines no GPU check, the library is
        not installed, or the check reports no GPU.

    Raises:
        ValueError: If ``library_name`` is not a registry key.
    """
    entry = GPU_LIBRARY_REGISTRY.get(library_name)
    if entry is None:
        raise ValueError(f"Unknown GPU library: {library_name}")

    _module_name, _submodule, has_gpu, device_id_of = entry

    # Entries without a check function are import-only; answer before
    # touching the proxy so nothing gets imported for them.
    if has_gpu is None:
        return None

    # Module-level lazy proxy; its truthiness test triggers the import and
    # is False when the library is not installed (expected case).
    proxy = globals()[library_name]
    if not proxy:
        return None

    return device_id_of(proxy) if has_gpu(proxy) else None