Coverage for src/arraybridge/utils.py: 78%

120 statements  

coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1""" 

2Memory conversion utility functions for arraybridge. 

3 

4This module provides utility functions for memory conversion operations, 

5supporting Clause 251 (Declarative Memory Conversion Interface) and 

6Clause 65 (Fail Loudly). 

7""" 

8 

9import importlib 

10import logging 

11from typing import Any, Optional 

12 

13from arraybridge.types import MemoryType 

14 

15from .exceptions import MemoryConversionError 

16from .framework_config import _FRAMEWORK_CONFIG 

17 

18logger = logging.getLogger(__name__) 



class _ModulePlaceholder:
    """
    Placeholder for missing optional modules that allows attribute access
    for type annotations while still being falsy and failing on actual use.
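
    Example (a minimal illustrative sketch; assumes CuPy is not installed):
        ```python
        cp = _ModulePlaceholder("cupy")
        bool(cp)         # False, so `if cp:` guards work
        cp.ndarray       # chained access yields another placeholder
        cp.asarray([1])  # any call raises ImportError (fail loudly)
        ```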

    """

    def __init__(self, module_name: str):
        self._module_name = module_name

    def __bool__(self):
        return False

    def __getattr__(self, name):
        # Return another placeholder for chained attribute access.
        # This allows things like cp.ndarray in type annotations to work.
        return _ModulePlaceholder(f"{self._module_name}.{name}")

    def __call__(self, *args, **kwargs):
        # If someone tries to actually call a function, fail loudly
        raise ImportError(
            f"Module '{self._module_name}' is not available. "
            f"Please install the required dependency."
        )

    def __repr__(self):
        return f"<ModulePlaceholder for '{self._module_name}'>"


def optional_import(module_name: str) -> Optional[Any]:
    """
    Import a module if available; otherwise return a placeholder that handles
    attribute access gracefully for type annotations but fails on actual use.

    This function allows for graceful handling of optional dependencies.
    It can be used to import libraries that may not be installed,
    particularly GPU-related libraries such as torch, tensorflow, and cupy.

    Args:
        module_name: Name of the module to import

    Returns:
        The imported module if available, a placeholder otherwise

    Example:
        ```python
        # Import torch if available
        torch = optional_import("torch")

        # Check if torch is available before using it
        if torch:
            # Use torch
            tensor = torch.tensor([1, 2, 3])
        else:
            # Handle the case where torch is not available
            raise ImportError("PyTorch is required for this function")
        ```
    """
    try:
        # Use importlib.import_module, which handles dotted names properly
        return importlib.import_module(module_name)
    except (ImportError, AttributeError):
        # ModuleNotFoundError is a subclass of ImportError, so it is covered too.
        # Return a placeholder that handles attribute access gracefully.
        return _ModulePlaceholder(module_name)


def _ensure_module(module_name: str) -> Any:
    """
    Ensure a module is imported and meets version requirements.

    Args:
        module_name: The name of the module to import

    Returns:
        The imported module

    Raises:
        ImportError: If the module cannot be imported
        RuntimeError: If the module version is known to be problematic
            (currently TensorFlow < 2.12, which lacks stable DLPack support)
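
    Example (illustrative; assumes NumPy is installed, for which no version
    check applies):
        ```python
        np = _ensure_module("numpy")
        arr = np.zeros(3)
        ```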

    """
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"Module {module_name} is required for this operation but is not installed"
        ) from e

    # Check TensorFlow version for DLPack compatibility
    if module_name == "tensorflow":
        try:
            from packaging import version

            tf_version = version.parse(module.__version__)
            min_version = version.parse("2.12.0")

            if tf_version < min_version:
                raise RuntimeError(
                    f"TensorFlow version {module.__version__} is not supported "
                    f"for DLPack operations. "
                    f"Version 2.12.0 or higher is required for stable DLPack support."
                )
        except ImportError:
            # Fallback: compare numeric version parts if packaging is not available
            try:
                tf_parts = [int(x) for x in module.__version__.split(".")[:3]]
                if (tf_parts[0] < 2) or (tf_parts[0] == 2 and tf_parts[1] < 12):
                    raise RuntimeError(
                        f"TensorFlow version {module.__version__} is not supported "
                        f"for DLPack operations. "
                        f"Version 2.12.0 or higher is required for stable DLPack support."
                    )
            except (ValueError, IndexError):
                # If version parsing fails, assume the version is acceptable
                pass

    return module


def _supports_cuda_array_interface(obj: Any) -> bool:
    """
    Check if an object supports the CUDA Array Interface.

    Args:
        obj: The object to check

    Returns:
        True if the object supports the CUDA Array Interface, False otherwise
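
    Example (illustrative; assumes CuPy is installed):
        ```python
        import cupy as cp
        _supports_cuda_array_interface(cp.zeros(3))  # True
        _supports_cuda_array_interface([1, 2, 3])    # False
        ```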

    """
    return hasattr(obj, "__cuda_array_interface__")


def _supports_dlpack(obj: Any) -> bool:
    """
    Check if an object supports DLPack.

    Args:
        obj: The object to check

    Returns:
        True if the object supports DLPack, False otherwise

    Note:
        For TensorFlow tensors, this function enforces Clause 88 (No Inferred
        Capabilities) by explicitly checking that:

        1. The TensorFlow version is 2.12+ (required for stable DLPack support)
        2. The tensor is on GPU (CPU tensors might succeed even without proper
           DLPack support)
        3. The tf.experimental.dlpack module exists

        For non-TensorFlow objects, the check is purely attribute-based; see
        the example below.
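
    Example (illustrative non-TensorFlow fast path; assumes a recent PyTorch
    is installed):
        ```python
        import torch
        _supports_dlpack(torch.zeros(3))  # True: tensors define __dlpack__
        _supports_dlpack([1, 2, 3])       # False: no DLPack methods
        ```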

    """
    # Check for PyTorch, CuPy, or JAX DLPack support.
    # PyTorch: __dlpack__ method, CuPy: toDlpack method, JAX: __dlpack__ method
    if hasattr(obj, "toDlpack") or hasattr(obj, "to_dlpack") or hasattr(obj, "__dlpack__"):
        # Special handling for TensorFlow to enforce Clause 88
        if "tensorflow" in str(type(obj)):
            try:
                import tensorflow as tf

                # Check TensorFlow version - DLPack is only stable in TF 2.12+
                tf_version = tf.__version__
                major, minor = map(int, tf_version.split(".")[:2])

                if major < 2 or (major == 2 and minor < 12):
                    # Explicitly fail for TF < 2.12 to prevent silent fallbacks
                    raise RuntimeError(
                        f"TensorFlow version {tf_version} does not support "
                        f"stable DLPack operations. "
                        f"Version 2.12.0 or higher is required. "
                        f"Clause 88 violation: Cannot infer DLPack capability."
                    )

                # Check if the tensor is on GPU - CPU tensors might succeed
                # even without proper DLPack support
                device_str = obj.device.lower()
                if "gpu" not in device_str:
                    # Explicitly fail for CPU tensors to prevent deceptive behavior
                    raise RuntimeError(
                        "TensorFlow tensor on CPU cannot use DLPack operations reliably. "
                        "Only GPU tensors are supported for DLPack operations. "
                        "Clause 88 violation: Cannot infer GPU capability."
                    )

                # Check that the experimental.dlpack module exists
                if not hasattr(tf.experimental, "dlpack"):
                    raise RuntimeError(
                        "TensorFlow installation missing experimental.dlpack module. "
                        "Clause 88 violation: Cannot infer DLPack capability."
                    )

                return True
            except (ImportError, AttributeError) as e:
                # Re-raise with a more specific error message
                raise RuntimeError(
                    f"TensorFlow DLPack support check failed: {str(e)}. "
                    f"Clause 88 violation: Cannot infer DLPack capability."
                ) from e

        # For non-TensorFlow types, return True if they have DLPack methods
        return True

    return False


# NOTE: Device operations are now defined in framework_config.py.
# This eliminates the scattered _DEVICE_OPS dict.
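#
# For reference, a minimal sketch of the shape each _FRAMEWORK_CONFIG entry is
# assumed to have, inferred from how the helpers below consume it (the real
# entries live in framework_config.py; the expression strings here are
# illustrative, not the actual configuration):
#
#     _FRAMEWORK_CONFIG[MemoryType.CUPY] = {
#         # callable, eval expression over `data`/`mod`, or None (CPU)
#         "get_device_id": "data.device.id",
#         # optional eval expression tried when the one above fails
#         "get_device_id_fallback": "getattr(data, 'device_id', None)",
#         # callable, eval template receiving {mod}, or None
#         "set_device": "{mod}.cuda.Device(device_id).use()",
#         # eval template returning the moved array, or None (CPU)
#         "move_to_device": "{mod}.asarray(data)",
#         # optional context-manager template used while moving
#         "move_context": "{mod}.cuda.Device(device_id)",
#     }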


def _get_device_id(data: Any, memory_type: str) -> Optional[int]:
    """
    Get the GPU device ID from a data object using the framework config.

    Args:
        data: The data object
        memory_type: The memory type

    Returns:
        The GPU device ID, or None if not applicable. If the ID cannot be
        determined, a warning is logged and the configured fallback expression
        (if any) is evaluated; otherwise None is returned.
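
    Example (illustrative; assumes CuPy is installed and that "cupy" is a
    valid MemoryType value):
        ```python
        import cupy as cp
        _get_device_id(cp.zeros(3), "cupy")  # e.g. 0 on a single-GPU machine
        ```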

    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    get_id_handler = config["get_device_id"]

    # Check if it's a callable handler (pyclesperanto)
    if callable(get_id_handler):
        mod = _ensure_module(mem_type.value)
        return get_id_handler(data, mod)

    # Check if it's None (CPU)
    if get_id_handler is None:
        return None

    # It's an eval expression
    try:
        mod = _ensure_module(mem_type.value)  # noqa: F841 (used in eval)
        return eval(get_id_handler)
    except Exception as e:
        logger.warning(f"Failed to get device ID for {mem_type.value} array: {e}")
        # Evaluate the fallback expression if one is configured
        if "get_device_id_fallback" in config:
            return eval(config["get_device_id_fallback"])
        return None


def _set_device(memory_type: str, device_id: int) -> None:
    """
    Set the current device for a specific memory type using the framework config.

    Args:
        memory_type: The memory type
        device_id: The GPU device ID

    Raises:
        MemoryConversionError: If the device cannot be set
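
    Example (illustrative; assumes CuPy is installed and that "cupy" is a
    valid MemoryType value):
        ```python
        _set_device("cupy", 0)  # make GPU 0 current for CuPy
        ```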

    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    set_device_handler = config["set_device"]

    # Check if it's a callable handler (pyclesperanto)
    if callable(set_device_handler):
        try:
            mod = _ensure_module(mem_type.value)
            set_device_handler(device_id, mod)
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_selection",
                reason=f"Failed to set {mem_type.value} device to {device_id}: {e}",
            ) from e
        return

    # Check if it's None (frameworks that don't need global device setting)
    if set_device_handler is None:
        return

    # It's an eval expression
    try:
        mod = _ensure_module(mem_type.value)  # noqa: F841 (used in eval)
        eval(set_device_handler.format(mod="mod"))
    except Exception as e:
        raise MemoryConversionError(
            source_type=memory_type,
            target_type=memory_type,
            method="device_selection",
            reason=f"Failed to set {mem_type.value} device to {device_id}: {e}",
        ) from e


def _move_to_device(data: Any, memory_type: str, device_id: int) -> Any:
    """
    Move data to a specific GPU device using the framework config.

    Args:
        data: The data to move
        memory_type: The memory type
        device_id: The target GPU device ID

    Returns:
        The data on the target device

    Raises:
        MemoryConversionError: If the data cannot be moved to the specified device
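
    Example (illustrative; assumes PyTorch with CUDA is installed and that
    "torch" is a valid MemoryType value):
        ```python
        import torch
        moved = _move_to_device(torch.zeros(3), "torch", 0)
        # `moved` is expected to live on cuda:0
        ```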

    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    move_handler = config["move_to_device"]

    # Check if it's a callable handler (pyclesperanto)
    if callable(move_handler):
        try:
            mod = _ensure_module(mem_type.value)
            return move_handler(data, device_id, mod, memory_type)
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_movement",
                reason=f"Failed to move {mem_type.value} array to device {device_id}: {e}",
            ) from e

    # Check if it's None (CPU memory types)
    if move_handler is None:
        return data

    # It's an eval expression
    try:
        mod = _ensure_module(mem_type.value)  # noqa: F841 (used in eval)

        # Handle context managers (CuPy, TensorFlow)
        if config.get("move_context"):
            context_expr = config["move_context"].format(mod="mod")
            context = eval(context_expr)
            with context:
                return eval(move_handler.format(mod="mod"))
        else:
            return eval(move_handler.format(mod="mod"))
    except Exception as e:
        raise MemoryConversionError(
            source_type=memory_type,
            target_type=memory_type,
            method="device_movement",
            reason=f"Failed to move {mem_type.value} array to device {device_id}: {e}",
        ) from e