Coverage for src/polystore/lazy_imports.py: 76%
71 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 06:58 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 06:58 +0000
1"""
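Lazy imports for optional GPU frameworks.

This module provides lazy access to GPU frameworks (PyTorch, JAX, TensorFlow,
CuPy) so that these heavy dependencies are not imported when this module itself
is imported. The first request for any framework makes a single attempt to
import all of them, and the ``get_*`` functions return None for a framework that
is not installed. Setting the environment variable POLYSTORE_NO_GPU=1 before that
first request skips the imports entirely, leaving every framework as None.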
6"""
8import os
9from typing import Optional
# Per-framework cache slots. Each holds the imported module once
# _attempt_imports() has loaded it, and stays None otherwise (not installed,
# failed to import, or skipped via POLYSTORE_NO_GPU).
_torch = _jax = _jnp = _cupy = _tf = None

# Flipped to True by the first _attempt_imports() call so the import attempt
# is made at most once.
_imports_attempted = False
def _try_import(module_name):
    """Import *module_name* and return it, or return None if it cannot be loaded.

    Optional GPU frameworks can fail to load for reasons other than simply
    being absent: a CUDA shared library that cannot be opened surfaces as
    ``OSError``, and some frameworks report incompatible binary/runtime
    versions as ``RuntimeError`` (e.g. JAX's jaxlib version check). All of
    these mean "framework unavailable" here. They are logged at DEBUG level
    instead of propagating, so one broken framework cannot break access to
    the others.

    Args:
        module_name: Dotted module path, e.g. ``"torch"`` or ``"jax.numpy"``.

    Returns:
        The imported module object, or None on failure.
    """
    import importlib
    import logging

    try:
        return importlib.import_module(module_name)
    except (ImportError, OSError, RuntimeError) as exc:
        logging.getLogger(__name__).debug(
            "Optional framework %r unavailable: %s", module_name, exc
        )
        return None


def _attempt_imports():
    """Try once to import every supported GPU framework into the module cache.

    Fills ``_torch``, ``_jax``/``_jnp``, ``_cupy`` and ``_tf``. A framework
    that cannot be imported leaves its slot as None. Only the first call does
    any work, because ``_imports_attempted`` is set before any import runs.
    If POLYSTORE_NO_GPU is ``'1'`` at that first call, nothing is imported
    and every slot stays None.

    NOTE(review): requesting a single framework imports all four. If import
    cost matters, consider switching to per-framework loading.
    """
    global _torch, _jax, _jnp, _cupy, _tf, _imports_attempted

    if _imports_attempted:
        return
    # Mark first so a failed or recursive call never triggers a second attempt.
    _imports_attempted = True

    if os.getenv('POLYSTORE_NO_GPU') == '1':
        return

    _torch = _try_import('torch')

    # jax and jax.numpy are cached together: both are set, or neither is.
    jax_module = _try_import('jax')
    jnp_module = _try_import('jax.numpy') if jax_module is not None else None
    if jnp_module is not None:
        _jax, _jnp = jax_module, jnp_module

    _cupy = _try_import('cupy')
    _tf = _try_import('tensorflow')
# Public attribute names served lazily by the module-level __getattr__ below,
# mapped to the cache slot that _attempt_imports() fills in.
_LAZY_ATTRIBUTES = {
    'torch': '_torch',
    'jax': '_jax',
    'jnp': '_jnp',
    'cupy': '_cupy',
    'tf': '_tf',
}


def __getattr__(name):
    """Resolve ``torch``, ``jax``, ``jnp``, ``cupy`` and ``tf`` on first access.

    This is a module-level ``__getattr__`` (PEP 562). ``@property`` only works
    on class attributes; at module level it merely binds the name to a
    ``property`` object. Accessing one of the names above triggers the one-time
    import attempt and returns the cached module, or None if it is unavailable.
    This also works with ``from <this module> import torch``.

    Args:
        name: Attribute name that was not found in the module namespace.

    Returns:
        The framework module, or None if it could not be imported.

    Raises:
        AttributeError: If *name* is not one of the lazily provided frameworks.
    """
    try:
        slot = _LAZY_ATTRIBUTES[name]
    except KeyError:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None
    _attempt_imports()
    return globals()[slot]
# --- Function-based accessors ----------------------------------------------
# Each accessor triggers the one-time import attempt, then returns whatever the
# cache holds for its framework.

def get_torch():
    """Return PyTorch if available.

    Returns:
        The ``torch`` module, or None if it is unavailable.
    """
    _attempt_imports()
    return _torch


def get_jax():
    """Return JAX if available.

    Returns:
        The ``jax`` module, or None if it is unavailable.
    """
    _attempt_imports()
    return _jax


def get_jnp():
    """Return JAX's NumPy API if available.

    Returns:
        The ``jax.numpy`` module, or None if it is unavailable.
    """
    _attempt_imports()
    return _jnp


def get_cupy():
    """Return CuPy if available.

    Returns:
        The ``cupy`` module, or None if it is unavailable.
    """
    _attempt_imports()
    return _cupy


def get_tf():
    """Return TensorFlow if available.

    Returns:
        The ``tensorflow`` module, or None if it is unavailable.
    """
    _attempt_imports()
    return _tf