Coverage for src/arraybridge/utils.py: 78% (120 statements)
1"""
2Memory conversion utility functions for arraybridge.
4This module provides utility functions for memory conversion operations,
5supporting Clause 251 (Declarative Memory Conversion Interface) and
6Clause 65 (Fail Loudly).
7"""
9import importlib
10import logging
11from typing import Any, Optional
13from arraybridge.types import MemoryType
15from .exceptions import MemoryConversionError
16from .framework_config import _FRAMEWORK_CONFIG
18logger = logging.getLogger(__name__)


class _ModulePlaceholder:
    """
    Placeholder for missing optional modules that allows attribute access
    for type annotations while still being falsy and failing on actual use.
    """

    def __init__(self, module_name: str):
        self._module_name = module_name

    def __bool__(self):
        return False

    def __getattr__(self, name):
        # Return another placeholder for chained attribute access.
        # This allows things like cp.ndarray in type annotations to work.
        return _ModulePlaceholder(f"{self._module_name}.{name}")

    def __call__(self, *args, **kwargs):
        # If someone tries to actually call a function, fail loudly.
        raise ImportError(
            f"Module '{self._module_name}' is not available. "
            f"Please install the required dependency."
        )

    def __repr__(self):
        return f"<ModulePlaceholder for '{self._module_name}'>"


def optional_import(module_name: str) -> Optional[Any]:
    """
    Import a module if available, otherwise return a placeholder that handles
    attribute access gracefully for type annotations but fails on actual use.

    This function allows for graceful handling of optional dependencies.
    It can be used to import libraries that may not be installed,
    particularly GPU-related libraries like torch, tensorflow, and cupy.

    Args:
        module_name: Name of the module to import

    Returns:
        The imported module if available, a placeholder otherwise

    Example:
        ```python
        # Import torch if available
        torch = optional_import("torch")

        # Check if torch is available before using it
        if torch:
            # Use torch
            tensor = torch.tensor([1, 2, 3])
        else:
            # Handle the case where torch is not available
            raise ImportError("PyTorch is required for this function")
        ```
    """
    try:
        # Use importlib.import_module, which handles dotted names properly
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError, AttributeError):
        # Return a placeholder that handles attribute access gracefully
        return _ModulePlaceholder(module_name)


def _ensure_module(module_name: str) -> Any:
    """
    Ensure a module is imported and meets version requirements.

    Args:
        module_name: The name of the module to import

    Returns:
        The imported module

    Raises:
        ImportError: If the module cannot be imported
        RuntimeError: If the module does not meet version requirements or has
            known issues with specific versions
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        raise ImportError(
            f"Module {module_name} is required for this operation but is not installed"
        )

    # Check TensorFlow version for DLPack compatibility
    if module_name == "tensorflow":
        try:
            from packaging import version

            tf_version = version.parse(module.__version__)
            min_version = version.parse("2.12.0")

            if tf_version < min_version:
                raise RuntimeError(
                    f"TensorFlow version {module.__version__} is not supported "
                    f"for DLPack operations. "
                    f"Version 2.12.0 or higher is required for stable DLPack support."
                )
        except ImportError:
            # Fallback: simple numeric comparison if packaging is not available
            try:
                tf_parts = [int(x) for x in module.__version__.split(".")[:3]]
                if (tf_parts[0] < 2) or (tf_parts[0] == 2 and tf_parts[1] < 12):
                    raise RuntimeError(
                        f"TensorFlow version {module.__version__} is not supported "
                        f"for DLPack operations. "
                        f"Version 2.12.0 or higher is required for stable DLPack support."
                    )
            except (ValueError, IndexError):
                # If version parsing fails, assume the version is acceptable
                pass

    return module
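

# Minimal usage sketch, assuming numpy is installed in the environment:
#
#     np = _ensure_module("numpy")        # returns the numpy module
#     _ensure_module("not_a_real_pkg")    # raises ImportError
#
# For "tensorflow", an extra gate applies: versions below 2.12.0 raise
# RuntimeError before the module is returned.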


def _supports_cuda_array_interface(obj: Any) -> bool:
    """
    Check if an object supports the CUDA Array Interface.

    Args:
        obj: The object to check

    Returns:
        True if the object supports the CUDA Array Interface, False otherwise
    """
    return hasattr(obj, "__cuda_array_interface__")
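

# Sketch, assuming cupy is installed with a working CUDA device:
#
#     import numpy as np
#     cp = optional_import("cupy")
#     _supports_cuda_array_interface(np.zeros(3))      # False (CPU array)
#     if cp:
#         _supports_cuda_array_interface(cp.zeros(3))  # True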


def _supports_dlpack(obj: Any) -> bool:
    """
    Check if an object supports DLPack.

    Args:
        obj: The object to check

    Returns:
        True if the object supports DLPack, False otherwise

    Note:
        For TensorFlow tensors, this function enforces Clause 88 (No Inferred
        Capabilities) by explicitly checking:
        1. TensorFlow version must be 2.12+ for stable DLPack support
        2. Tensor must be on GPU (CPU tensors might succeed even without
           proper DLPack support)
        3. tf.experimental.dlpack module must exist
    """
    # Check for PyTorch, CuPy, or JAX DLPack support.
    # PyTorch: __dlpack__ method, CuPy: toDlpack method, JAX: __dlpack__ method
    if hasattr(obj, "toDlpack") or hasattr(obj, "to_dlpack") or hasattr(obj, "__dlpack__"):
        # Special handling for TensorFlow to enforce Clause 88
        if "tensorflow" in str(type(obj)):
            try:
                import tensorflow as tf

                # Check TensorFlow version - DLPack is only stable in TF 2.12+
                tf_version = tf.__version__
                major, minor = map(int, tf_version.split(".")[:2])

                if major < 2 or (major == 2 and minor < 12):
                    # Explicitly fail for TF < 2.12 to prevent silent fallbacks
                    raise RuntimeError(
                        f"TensorFlow version {tf_version} does not support "
                        f"stable DLPack operations. "
                        f"Version 2.12.0 or higher is required. "
                        f"Clause 88 violation: Cannot infer DLPack capability."
                    )

                # Check if the tensor is on GPU - CPU tensors might succeed
                # even without proper DLPack support
                device_str = obj.device.lower()
                if "gpu" not in device_str:
                    # Explicitly fail for CPU tensors to prevent deceptive behavior
                    raise RuntimeError(
                        "TensorFlow tensor on CPU cannot use DLPack operations reliably. "
                        "Only GPU tensors are supported for DLPack operations. "
                        "Clause 88 violation: Cannot infer GPU capability."
                    )

                # Check that the experimental.dlpack module exists
                if not hasattr(tf.experimental, "dlpack"):
                    raise RuntimeError(
                        "TensorFlow installation missing experimental.dlpack module. "
                        "Clause 88 violation: Cannot infer DLPack capability."
                    )

                return True
            except (ImportError, AttributeError) as e:
                # Re-raise with a more specific error message
                raise RuntimeError(
                    f"TensorFlow DLPack support check failed: {str(e)}. "
                    f"Clause 88 violation: Cannot infer DLPack capability."
                ) from e

        # For non-TensorFlow types, return True if they have DLPack methods
        return True

    return False
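

# Sketch of the dispatch above, assuming torch is installed:
#
#     torch = optional_import("torch")
#     if torch:
#         _supports_dlpack(torch.tensor([1, 2, 3]))  # True: has __dlpack__
#     _supports_dlpack(object())                     # False: no DLPack hooks
#
# A TensorFlow tensor that exposes a DLPack hook is additionally gated:
# CPU placement or TF < 2.12 raises RuntimeError rather than returning False.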


# NOTE: Device operations are now defined in framework_config.py.
# This eliminates the scattered _DEVICE_OPS dict.


def _get_device_id(data: Any, memory_type: str) -> Optional[int]:
    """
    Get the GPU device ID from a data object using framework config.

    Args:
        data: The data object
        memory_type: The memory type

    Returns:
        The GPU device ID, or None if not applicable or not determinable
    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    get_id_handler = config["get_device_id"]

    # Check if it's a callable handler (pyclesperanto)
    if callable(get_id_handler):
        mod = _ensure_module(mem_type.value)
        return get_id_handler(data, mod)

    # Check if it's None (CPU)
    if get_id_handler is None:
        return None

    # It's an eval expression
    try:
        mod = _ensure_module(mem_type.value)  # noqa: F841 (used in eval)
        return eval(get_id_handler)
    except Exception as e:
        logger.warning(f"Failed to get device ID for {mem_type.value} array: {e}")
        # Try fallback if available
        if "get_device_id_fallback" in config:
            return eval(config["get_device_id_fallback"])
        return None
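

# Sketch, assuming cupy is installed and that MemoryType includes "cupy" and
# "numpy" values (mem_type.value doubles as the import name above):
#
#     import numpy as np
#     _get_device_id(np.zeros(3), "numpy")      # None (CPU memory type)
#     cp = optional_import("cupy")
#     if cp:
#         _get_device_id(cp.zeros(3), "cupy")   # e.g. 0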


def _set_device(memory_type: str, device_id: int) -> None:
    """
    Set the current device for a specific memory type using framework config.

    Args:
        memory_type: The memory type
        device_id: The GPU device ID

    Raises:
        MemoryConversionError: If the device cannot be set
    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    set_device_handler = config["set_device"]

    # Check if it's a callable handler (pyclesperanto)
    if callable(set_device_handler):
        try:
            mod = _ensure_module(mem_type.value)
            set_device_handler(device_id, mod)
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_selection",
                reason=f"Failed to set {mem_type.value} device to {device_id}: {e}",
            ) from e
        return

    # Check if it's None (frameworks that don't need global device setting)
    if set_device_handler is None:
        return

    # It's an eval expression
    try:
        mod = _ensure_module(mem_type.value)  # noqa: F841 (used in eval)
        eval(set_device_handler.format(mod="mod"))
    except Exception as e:
        raise MemoryConversionError(
            source_type=memory_type,
            target_type=memory_type,
            method="device_selection",
            reason=f"Failed to set {mem_type.value} device to {device_id}: {e}",
        ) from e
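

# Sketch, assuming a CUDA-capable cupy install and a "cupy" MemoryType value:
#
#     _set_device("cupy", 0)    # selects GPU 0 via the configured handler
#     _set_device("numpy", 0)   # no-op if numpy's set_device handler is None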


def _move_to_device(data: Any, memory_type: str, device_id: int) -> Any:
    """
    Move data to a specific GPU device using framework config.

    Args:
        data: The data to move
        memory_type: The memory type
        device_id: The target GPU device ID

    Returns:
        The data on the target device

    Raises:
        MemoryConversionError: If the data cannot be moved to the specified device
    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    move_handler = config["move_to_device"]

    # Check if it's a callable handler (pyclesperanto)
    if callable(move_handler):
        try:
            mod = _ensure_module(mem_type.value)
            return move_handler(data, device_id, mod, memory_type)
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_movement",
                reason=f"Failed to move {mem_type.value} array to device {device_id}: {e}",
            ) from e

    # Check if it's None (CPU memory types)
    if move_handler is None:
        return data

    # It's an eval expression
    try:
        mod = _ensure_module(mem_type.value)  # noqa: F841 (used in eval)

        # Handle context managers (CuPy, TensorFlow)
        if "move_context" in config and config["move_context"]:
            context_expr = config["move_context"].format(mod="mod")
            context = eval(context_expr)
            with context:
                return eval(move_handler.format(mod="mod"))
        else:
            return eval(move_handler.format(mod="mod"))
    except Exception as e:
        raise MemoryConversionError(
            source_type=memory_type,
            target_type=memory_type,
            method="device_movement",
            reason=f"Failed to move {mem_type.value} array to device {device_id}: {e}",
        ) from e
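

# End-to-end sketch, assuming cupy is installed with at least two GPUs and
# that MemoryType includes "cupy" and "numpy" values:
#
#     import numpy as np
#     cp = optional_import("cupy")
#     if cp:
#         arr = cp.zeros(3)                        # allocated on the current GPU
#         moved = _move_to_device(arr, "cupy", 1)  # copy placed on GPU 1
#         _get_device_id(moved, "cupy")            # 1
#     _move_to_device(np.zeros(3), "numpy", 0)     # returned unchanged (CPU)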