Coverage for src/polystore/lazy_imports.py: 76%

71 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 06:58 +0000

1""" 

2Lazy imports for optional GPU frameworks. 

3 

4This module provides lazy imports for GPU frameworks (PyTorch, JAX, TensorFlow, CuPy) 

5to avoid importing heavy dependencies unless they are actually used. 

6""" 

7 

8import os 

9from typing import Optional 

10 

# Cached handles to the optional GPU frameworks. Each one is None until
# _attempt_imports() has run and managed to import the matching package.
_torch = _jax = _jnp = _cupy = _tf = None

# Flipped to True by _attempt_imports() so the import pass runs only once.
_imports_attempted = False

18 

19 

def _attempt_imports():
    """Run the one-time import pass over the optional GPU frameworks.

    The first call tries to import PyTorch, JAX (together with
    ``jax.numpy``), CuPy and TensorFlow, storing each module that imports
    successfully in its module-level cache variable. Later calls return
    immediately. Note that a single pass covers all four frameworks, so the
    first request for any one of them attempts to import every one.

    When the ``POLYSTORE_NO_GPU`` environment variable equals ``'1'`` the
    pass is still recorded as done, but nothing is imported.

    A framework whose import raises ``ImportError`` is skipped silently and
    its cache variable keeps its previous value.
    """
    global _torch, _jax, _jnp, _cupy, _tf, _imports_attempted

    if _imports_attempted:
        return

    # Record the attempt before importing anything so that a pass which
    # fails part-way is never repeated.
    _imports_attempted = True

    # No-GPU mode: leave every cache variable untouched.
    gpu_disabled = os.getenv('POLYSTORE_NO_GPU') == '1'
    if gpu_disabled:
        return

    # PyTorch
    try:
        import torch as torch_mod
    except ImportError:
        pass
    else:
        _torch = torch_mod

    # JAX: both the package and jax.numpy must import for either to be cached.
    try:
        import jax as jax_mod
        import jax.numpy as jnp_mod
    except ImportError:
        pass
    else:
        _jax, _jnp = jax_mod, jnp_mod

    # CuPy
    try:
        import cupy as cupy_mod
    except ImportError:
        pass
    else:
        _cupy = cupy_mod

    # TensorFlow
    try:
        import tensorflow as tf_mod
    except ImportError:
        pass
    else:
        _tf = tf_mod

62 

63 

# Deliberately NOT decorated with @property: a property is a descriptor and
# only takes effect as a class attribute. On a module-level function the
# decorator merely rebinds the name to an inert ``property`` object, which is
# neither callable nor resolved on attribute access, so the body never runs.
def torch():
    """Return the PyTorch module, running the lazy import pass if needed.

    Must be called as ``torch()``.

    Returns:
        The imported ``torch`` module, or ``None`` when the import pass did
        not load it (package missing, or ``POLYSTORE_NO_GPU=1``).
    """
    _attempt_imports()
    return _torch

69 

70 

# Deliberately NOT decorated with @property: a property is a descriptor and
# only takes effect as a class attribute. On a module-level function the
# decorator merely rebinds the name to an inert ``property`` object, which is
# neither callable nor resolved on attribute access, so the body never runs.
def jax():
    """Return the JAX module, running the lazy import pass if needed.

    Must be called as ``jax()``.

    Returns:
        The imported ``jax`` module, or ``None`` when the import pass did
        not load it (package missing, or ``POLYSTORE_NO_GPU=1``).
    """
    _attempt_imports()
    return _jax

76 

77 

# Deliberately NOT decorated with @property: a property is a descriptor and
# only takes effect as a class attribute. On a module-level function the
# decorator merely rebinds the name to an inert ``property`` object, which is
# neither callable nor resolved on attribute access, so the body never runs.
def jnp():
    """Return the ``jax.numpy`` module, running the lazy import pass if needed.

    Must be called as ``jnp()``.

    Returns:
        The imported ``jax.numpy`` module, or ``None`` when the import pass
        did not load it (package missing, or ``POLYSTORE_NO_GPU=1``).
    """
    _attempt_imports()
    return _jnp

83 

84 

# Deliberately NOT decorated with @property: a property is a descriptor and
# only takes effect as a class attribute. On a module-level function the
# decorator merely rebinds the name to an inert ``property`` object, which is
# neither callable nor resolved on attribute access, so the body never runs.
def cupy():
    """Return the CuPy module, running the lazy import pass if needed.

    Must be called as ``cupy()``.

    Returns:
        The imported ``cupy`` module, or ``None`` when the import pass did
        not load it (package missing, or ``POLYSTORE_NO_GPU=1``).
    """
    _attempt_imports()
    return _cupy

90 

91 

# Deliberately NOT decorated with @property: a property is a descriptor and
# only takes effect as a class attribute. On a module-level function the
# decorator merely rebinds the name to an inert ``property`` object, which is
# neither callable nor resolved on attribute access, so the body never runs.
def tf():
    """Return the TensorFlow module, running the lazy import pass if needed.

    Must be called as ``tf()``.

    Returns:
        The imported ``tensorflow`` module, or ``None`` when the import pass
        did not load it (package missing, or ``POLYSTORE_NO_GPU=1``).
    """
    _attempt_imports()
    return _tf

97 

98 

99# Simple function-based API 

def get_torch():
    """Return the cached PyTorch module, or ``None`` if it was not loaded.

    Triggers the one-time import pass first; after the first call this is
    just a lookup of the module-level cache.
    """
    _attempt_imports()  # returns immediately once the pass has been attempted
    return _torch

104 

105 

def get_jax():
    """Return the cached JAX module, or ``None`` if it was not loaded.

    Triggers the one-time import pass first; after the first call this is
    just a lookup of the module-level cache.
    """
    _attempt_imports()  # returns immediately once the pass has been attempted
    return _jax

110 

111 

def get_jnp():
    """Return the cached ``jax.numpy`` module, or ``None`` if it was not loaded.

    Triggers the one-time import pass first; after the first call this is
    just a lookup of the module-level cache.
    """
    _attempt_imports()  # returns immediately once the pass has been attempted
    return _jnp

116 

117 

def get_cupy():
    """Return the cached CuPy module, or ``None`` if it was not loaded.

    Triggers the one-time import pass first; after the first call this is
    just a lookup of the module-level cache.
    """
    _attempt_imports()  # returns immediately once the pass has been attempted
    return _cupy

122 

123 

def get_tf():
    """Return the cached TensorFlow module, or ``None`` if it was not loaded.

    Triggers the one-time import pass first; after the first call this is
    just a lookup of the module-level cache.
    """
    _attempt_imports()  # returns immediately once the pass has been attempted
    return _tf