Coverage for openhcs/core/memory/utils.py: 5.6%
152 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Memory conversion utility functions for OpenHCS.
4This module provides utility functions for memory conversion operations,
5supporting Clause 251 (Declarative Memory Conversion Interface) and
6Clause 65 (Fail Loudly).
7"""
9import importlib
10import logging
11from typing import Any, Optional
13from openhcs.constants.constants import MemoryType
15from .exceptions import MemoryConversionError
17logger = logging.getLogger(__name__)
20def _ensure_module(module_name: str) -> Any:
21 """
22 Ensure a module is imported and meets version requirements.
24 Args:
25 module_name: The name of the module to import
27 Returns:
28 The imported module
30 Raises:
31 ImportError: If the module cannot be imported or does not meet version requirements
32 RuntimeError: If the module has known issues with specific versions
33 """
34 try:
35 module = importlib.import_module(module_name)
37 # Check TensorFlow version for DLPack compatibility
38 if module_name == "tensorflow":
39 import pkg_resources
40 tf_version = pkg_resources.parse_version(module.__version__)
41 min_version = pkg_resources.parse_version("2.12.0")
43 if tf_version < min_version:
44 raise RuntimeError(
45 f"TensorFlow version {module.__version__} is not supported for DLPack operations. "
46 f"Version 2.12.0 or higher is required for stable DLPack support. "
47 f"Clause 88 (No Inferred Capabilities) violation: Cannot infer DLPack capability."
48 )
50 return module
51 except ImportError:
52 raise ImportError(f"Module {module_name} is required for this operation but is not installed")
55def _supports_cuda_array_interface(obj: Any) -> bool:
56 """
57 Check if an object supports the CUDA Array Interface.
59 Args:
60 obj: The object to check
62 Returns:
63 True if the object supports the CUDA Array Interface, False otherwise
64 """
65 return hasattr(obj, "__cuda_array_interface__")
def _supports_dlpack(obj: Any) -> bool:
    """
    Check if an object supports DLPack.

    Args:
        obj: The object to check

    Returns:
        True if the object supports DLPack, False otherwise

    Raises:
        RuntimeError: For TensorFlow tensors that fail any capability check
            (TF < 2.12, tensor not on GPU, or missing tf.experimental.dlpack).
            This predicate deliberately raises rather than returning False in
            those cases, per Clause 88 / Clause 65 (Fail Loudly).

    Note:
        For TensorFlow tensors, this function enforces Clause 88 (No Inferred Capabilities)
        by explicitly checking:
        1. TensorFlow version must be 2.12+ for stable DLPack support
        2. Tensor must be on GPU (CPU tensors might succeed even without proper DLPack support)
        3. tf.experimental.dlpack module must exist
    """
    # Duck-typed capability probe:
    #   CuPy exposes ``toDlpack``; PyTorch and JAX expose ``__dlpack__``;
    #   ``to_dlpack`` covers libraries using that alternate spelling.
    if hasattr(obj, "toDlpack") or hasattr(obj, "to_dlpack") or hasattr(obj, "__dlpack__"):
        # TensorFlow is detected from the type's repr rather than by import,
        # so non-TF objects never trigger a tensorflow import here.
        if 'tensorflow' in str(type(obj)):
            try:
                import tensorflow as tf

                # DLPack is only stable in TF 2.12+; parse major/minor ints
                # straight from the version string.
                tf_version = tf.__version__
                major, minor = map(int, tf_version.split('.')[:2])

                if major < 2 or (major == 2 and minor < 12):
                    # Explicitly fail for TF < 2.12 to prevent silent fallbacks
                    raise RuntimeError(
                        f"TensorFlow version {tf_version} does not support stable DLPack operations. "
                        f"Version 2.12.0 or higher is required. "
                        f"Clause 88 violation: Cannot infer DLPack capability."
                    )

                # Check if tensor is on GPU - CPU tensors might succeed even without proper DLPack support
                # NOTE(review): assumes ``obj.device`` is a string such as
                # "/device:GPU:0" — confirm for all TF tensor variants.
                device_str = obj.device.lower()
                if "gpu" not in device_str:
                    # Explicitly fail for CPU tensors to prevent deceptive behavior
                    raise RuntimeError(
                        "TensorFlow tensor on CPU cannot use DLPack operations reliably. "
                        "Only GPU tensors are supported for DLPack operations. "
                        "Clause 88 violation: Cannot infer GPU capability."
                    )

                # The experimental.dlpack module must actually exist.
                if not hasattr(tf.experimental, "dlpack"):
                    raise RuntimeError(
                        "TensorFlow installation missing experimental.dlpack module. "
                        "Clause 88 violation: Cannot infer DLPack capability."
                    )

                return True
            except (ImportError, AttributeError) as e:
                # Re-raise with more specific error message. The RuntimeErrors
                # raised above are NOT caught here and propagate unchanged.
                raise RuntimeError(
                    f"TensorFlow DLPack support check failed: {str(e)}. "
                    f"Clause 88 violation: Cannot infer DLPack capability."
                ) from e

        # For non-TensorFlow types, having a DLPack method is sufficient.
        return True

    return False
def _get_device_id(data: Any, memory_type: str) -> Optional[int]:
    """
    Get the GPU device ID from a data object.

    Args:
        data: The data object
        memory_type: The memory type

    Returns:
        The GPU device ID or None if not applicable

    Raises:
        MemoryConversionError: If the device ID cannot be determined for a GPU memory type
    """
    # NumPy arrays always live on the host; there is no device to report.
    if memory_type == MemoryType.NUMPY.value:
        return None

    if memory_type == MemoryType.CUPY.value:
        try:
            return data.device.id
        except AttributeError:
            # CuPy arrays are always GPU-resident, so fall back to device 0
            # when the device attribute is missing.
            return 0
        except Exception as err:
            logger.warning(f"Failed to get device ID for CuPy array: {str(err)}")
            return 0

    if memory_type == MemoryType.TORCH.value:
        try:
            # Only CUDA tensors carry a device index; CPU tensors yield None.
            return data.device.index if data.is_cuda else None
        except Exception as err:
            logger.warning(f"Failed to get device ID for PyTorch tensor: {str(err)}")
            return None

    if memory_type == MemoryType.TENSORFLOW.value:
        try:
            placement = data.device.lower()
            if "gpu" not in placement:
                # CPU tensor: no device ID.
                return None
            # Placement strings look like "/device:gpu:0"; the ID is the
            # final colon-separated field.
            return int(placement.rpartition(":")[2])
        except Exception as err:
            logger.warning(f"Failed to get device ID for TensorFlow tensor: {str(err)}")
            return None

    if memory_type == MemoryType.JAX.value:
        try:
            placement = str(data.device).lower()
            if "gpu" not in placement:
                # CPU array: no device ID.
                return None
            # Placement strings look like "gpu:0"; the ID is the final field.
            return int(placement.rpartition(":")[2])
        except Exception as err:
            logger.warning(f"Failed to get device ID for JAX array: {str(err)}")
            return None

    if memory_type == MemoryType.PYCLESPERANTO.value:
        try:
            cle = _ensure_module("pyclesperanto")
            active = cle.get_device()
            # The device object may expose its ID directly.
            if hasattr(active, 'id'):
                return active.id
            # Otherwise match the active device against the enumerated list.
            for index, candidate in enumerate(cle.list_available_devices()):
                if str(candidate) == str(active):
                    return index
            # Unknown device: default to 0 rather than failing.
            return 0
        except Exception as err:
            logger.warning(f"Failed to get device ID for pyclesperanto array: {str(err)}")
            return 0

    return None
def _set_device(memory_type: str, device_id: int) -> None:
    """
    Set the current device for a specific memory type.

    Args:
        memory_type: The memory type
        device_id: The GPU device ID

    Raises:
        MemoryConversionError: If the device cannot be set
    """
    if memory_type == MemoryType.CUPY.value:
        try:
            cupy_mod = _ensure_module("cupy")
            # Make the requested CUDA device current for subsequent CuPy work.
            cupy_mod.cuda.Device(device_id).use()
        except Exception as exc:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_selection",
                reason=f"Failed to set CuPy device to {device_id}: {str(exc)}"
            ) from exc

    if memory_type == MemoryType.PYCLESPERANTO.value:
        try:
            cle_mod = _ensure_module("pyclesperanto")
            available = cle_mod.list_available_devices()
            # Validate the index before selecting, so the error names the
            # actual number of available devices.
            if device_id >= len(available):
                raise ValueError(f"Device ID {device_id} not available. Available devices: {len(available)}")
            cle_mod.select_device(device_id)
        except Exception as exc:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_selection",
                reason=f"Failed to set pyclesperanto device to {device_id}: {str(exc)}"
            ) from exc

    # JAX has no global device-selection mechanism; placement happens at
    # array creation or via device_put.
    # PyTorch and TensorFlow bind devices per-tensor at creation time, so no
    # global state needs updating for them either.
def _move_to_device(data: Any, memory_type: str, device_id: int) -> Any:
    """
    Move data to a specific GPU device.

    Args:
        data: The data to move
        memory_type: The memory type
        device_id: The target GPU device ID

    Returns:
        The data on the target device (the original object is returned
        unchanged when it is already on the target device, except for
        TensorFlow, which always produces a new tensor via tf.identity)

    Raises:
        MemoryConversionError: If the data cannot be moved to the specified device
    """
    if memory_type == MemoryType.CUPY.value:
        # Note: a failure to import cupy here raises ImportError from
        # _ensure_module, not MemoryConversionError.
        cupy = _ensure_module("cupy")
        try:
            if data.device.id != device_id:
                # Copying inside the target Device context allocates the new
                # array on that device.
                with cupy.cuda.Device(device_id):
                    return data.copy()
            return data
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_movement",
                reason=f"Failed to move CuPy array to device {device_id}: {str(e)}"
            ) from e

    if memory_type == MemoryType.TORCH.value:
        try:
            # CUDA tensor on the wrong device: move it.
            if data.is_cuda and data.device.index != device_id:
                return data.to(f"cuda:{device_id}")
            # CPU tensor: always move to the requested GPU.
            if not data.is_cuda:
                return data.to(f"cuda:{device_id}")
            # Already on the target CUDA device.
            return data
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_movement",
                reason=f"Failed to move PyTorch tensor to device {device_id}: {str(e)}"
            ) from e

    if memory_type == MemoryType.TENSORFLOW.value:
        try:
            tf = _ensure_module("tensorflow")
            # tf.identity under a device scope places the result on that
            # device; this copies even if the tensor is already there.
            with tf.device(f"/device:GPU:{device_id}"):
                return tf.identity(data)
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_movement",
                reason=f"Failed to move TensorFlow tensor to device {device_id}: {str(e)}"
            ) from e

    if memory_type == MemoryType.JAX.value:
        try:
            jax = _ensure_module("jax")
            # JAX uses explicit placement: index into the GPU device list.
            # An out-of-range device_id raises here and is wrapped below.
            return jax.device_put(data, jax.devices("gpu")[device_id])
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_movement",
                reason=f"Failed to move JAX array to device {device_id}: {str(e)}"
            ) from e

    if memory_type == MemoryType.PYCLESPERANTO.value:
        try:
            cle = _ensure_module("pyclesperanto")
            # Determine which device currently holds the array.
            current_device_id = _get_device_id(data, memory_type)

            if current_device_id != device_id:
                # Select target device, then allocate and copy there.
                # NOTE(review): select_device changes global pyclesperanto
                # state as a side effect and is not restored afterwards.
                cle.select_device(device_id)
                result = cle.create_like(data)
                cle.copy(data, result)
                return result
            return data
        except Exception as e:
            raise MemoryConversionError(
                source_type=memory_type,
                target_type=memory_type,
                method="device_movement",
                reason=f"Failed to move pyclesperanto array to device {device_id}: {str(e)}"
            ) from e

    # Unknown or host-only memory type: return the data untouched.
    return data