Coverage for openhcs/core/memory/conversion_functions.py: 5.3% (518 statements)
1"""
2Direct memory conversion functions for OpenHCS.
4This module provides direct conversion functions between different memory types,
5enforcing Clause 65 (Fail Loudly), Clause 88 (No Inferred Capabilities),
6and Clause 251 (Declarative Memory Conversion).
7"""
9from typing import Any, Optional
11from openhcs.constants.constants import MemoryType
13from .exceptions import MemoryConversionError
14from .utils import (_ensure_module, _supports_cuda_array_interface,
15 _supports_dlpack)


# NumPy conversion functions

def _numpy_to_numpy(data: Any) -> Any:
    """Convert numpy array to numpy array (identity operation)."""
    return data.copy()


def _numpy_to_cupy(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to cupy array.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID

    Returns:
        The converted cupy array

    Raises:
        ImportError: If cupy is not installed
        ValueError: If gpu_id is negative
    """
    cupy = _ensure_module("cupy")

    # Validate gpu_id
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # Always use the specified GPU device
    with cupy.cuda.Device(gpu_id):
        return cupy.array(data)


def _numpy_to_torch(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to torch tensor.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID

    Returns:
        The converted torch tensor

    Raises:
        ImportError: If torch is not installed
        ValueError: If gpu_id is negative
    """
    torch = _ensure_module("torch")

    # Validate gpu_id
    if gpu_id is None:
        raise ValueError("🔥 GPU ID IS NONE! The compiler failed to assign a GPU to this torch function. This is a GPU registry/compiler bug!")
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # Always use the specified GPU device
    device = torch.device(f"cuda:{gpu_id}")
    return torch.tensor(data, device=device)


def _numpy_to_pyclesperanto(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to pyclesperanto array.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID

    Returns:
        The converted pyclesperanto array

    Raises:
        ImportError: If pyclesperanto is not installed
        ValueError: If gpu_id is negative
    """
    cle = _ensure_module("pyclesperanto")

    # Validate gpu_id
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # Select the appropriate device
    devices = cle.list_available_devices()
    if gpu_id >= len(devices):
        raise ValueError(f"GPU ID {gpu_id} not available. Available devices: {len(devices)}")

    # Select device and push data
    cle.select_device(gpu_id)
    return cle.push(data)


def _numpy_to_tensorflow(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to tensorflow tensor.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID

    Returns:
        The converted tensorflow tensor

    Raises:
        ImportError: If tensorflow is not installed
        ValueError: If gpu_id is negative
    """
    tf = _ensure_module("tensorflow")

    # Validate gpu_id
    if gpu_id is None:
        raise ValueError("🔥 GPU ID IS NONE! The compiler failed to assign a GPU to this tensorflow function. This is a GPU registry/compiler bug!")
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # Always use the specified GPU device
    with tf.device(f"/device:GPU:{gpu_id}"):
        return tf.convert_to_tensor(data)
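

# Illustrative sketch (not part of the original module): how the NumPy entry
# points above might be exercised. Assumes numpy and cupy are installed and
# that GPU 0 exists; the helper name below is hypothetical.
def _example_numpy_round_trip(gpu_id: int = 0) -> None:  # pragma: no cover
    import numpy as np

    stack = np.zeros((4, 64, 64), dtype=np.float32)
    on_gpu = _numpy_to_cupy(stack, gpu_id=gpu_id)  # host -> cuda:<gpu_id>
    back = _cupy_to_numpy(on_gpu)                  # device -> host via .get()
    assert back.shape == stack.shape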


# pyclesperanto conversion functions

def _pyclesperanto_to_numpy(data: Any) -> Any:
    """
    Convert pyclesperanto array to numpy array.

    Args:
        data: The pyclesperanto array to convert

    Returns:
        The converted numpy array
    """
    cle = _ensure_module("pyclesperanto")
    return cle.pull(data)


def _pyclesperanto_to_pyclesperanto(data: Any) -> Any:
    """Convert pyclesperanto array to pyclesperanto array (identity operation)."""
    cle = _ensure_module("pyclesperanto")
    return data


def _pyclesperanto_to_torch(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert pyclesperanto array to torch tensor, staying on GPU.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If torch is not installed
    """
    torch = _ensure_module("torch")
    cle = _ensure_module("pyclesperanto")

    # Try GPU-to-GPU conversion first
    try:
        # Use CUDA array interface for zero-copy conversion
        if _supports_cuda_array_interface(data):
            # Convert via CUDA array interface
            tensor = torch.as_tensor(data, device=f"cuda:{device_id if device_id is not None else 0}")

            # Move to specified device if needed
            if device_id is not None and tensor.device.index != device_id:
                tensor = tensor.to(f"cuda:{device_id}")

            return tensor
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.TORCH.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip
    numpy_data = cle.pull(data)
    if device_id is not None:
        return torch.tensor(numpy_data, device=f"cuda:{device_id}")
    return torch.tensor(numpy_data)


def _pyclesperanto_to_tensorflow(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert pyclesperanto array to tensorflow tensor, staying on GPU.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If tensorflow is not installed
    """
    tf = _ensure_module("tensorflow")
    cle = _ensure_module("pyclesperanto")

    # Try GPU-to-GPU conversion first
    try:
        # Use CUDA array interface for zero-copy conversion
        if _supports_cuda_array_interface(data):
            # Convert via CUDA array interface
            with tf.device(f"/device:GPU:{device_id if device_id is not None else 0}"):
                return tf.experimental.dlpack.from_dlpack(data.__dlpack__())
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.TENSORFLOW.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip
    numpy_data = cle.pull(data)
    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.convert_to_tensor(numpy_data)
    return tf.convert_to_tensor(numpy_data)


def _pyclesperanto_to_jax(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert pyclesperanto array to JAX array, staying on GPU.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If jax is not installed
    """
    jax = _ensure_module("jax")
    cle = _ensure_module("pyclesperanto")

    # Try GPU-to-GPU conversion first
    try:
        # Use DLPack for zero-copy conversion
        if hasattr(data, '__dlpack__'):
            dlpack = data.__dlpack__()
            result = jax.dlpack.from_dlpack(dlpack)

            # Move to specified device if needed
            if device_id is not None:
                result = jax.device_put(result, jax.devices("gpu")[device_id])

            return result
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.JAX.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip
    numpy_data = cle.pull(data)
    result = jax.numpy.array(numpy_data)

    if device_id is not None:
        result = jax.device_put(result, jax.devices("gpu")[device_id])

    return result
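

# Illustrative sketch (not part of the original module): the pyclesperanto
# converters above fail loudly unless the caller explicitly opts in to a CPU
# roundtrip. Assumes pyclesperanto and torch are installed; the helper name is
# hypothetical.
def _example_pyclesperanto_to_torch(image, device_id: int = 0):  # pragma: no cover
    try:
        # Preferred path: stay on the GPU via the CUDA array interface
        return _pyclesperanto_to_torch(image, device_id=device_id)
    except MemoryConversionError:
        # Explicitly authorized CPU roundtrip (cle.pull -> torch.tensor)
        return _pyclesperanto_to_torch(image, allow_cpu_roundtrip=True, device_id=device_id)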


# CuPy conversion functions

def _cupy_to_numpy(data: Any) -> Any:
    """
    Convert cupy array to numpy array.

    Args:
        data: The cupy array to convert

    Returns:
        The converted numpy array
    """
    return data.get()


def _cupy_to_cupy(data: Any) -> Any:
    """Convert cupy array to cupy array (identity operation)."""
    return data.copy()


def _cupy_to_torch(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert cupy array to torch tensor, staying on GPU.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If torch is not installed
    """
    torch = _ensure_module("torch")

    # Use DLPack for zero-copy GPU-to-GPU conversion
    if _supports_dlpack(data):
        try:
            dlpack = data.toDlpack()
            result = torch.from_dlpack(dlpack)

            # Move to specified device if needed
            if device_id is not None:
                target_device = f"cuda:{device_id}"
                if str(result.device) != target_device:
                    result = result.to(target_device)

            return result
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.CUPY.value,
                    target_type=MemoryType.TORCH.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # Fallback to CUDA Array Interface
    elif _supports_cuda_array_interface(data):
        try:
            if device_id is not None:
                result = torch.as_tensor(data, device=f"cuda:{device_id}")
            else:
                result = torch.as_tensor(data, device="cuda")
            return result
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.CUPY.value,
                    target_type=MemoryType.TORCH.value,
                    method="CUDA Array Interface",
                    reason=str(e)
                ) from e
    else:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.CUPY.value,
                target_type=MemoryType.TORCH.value,
                method="CUDA Array Interface",
                reason="CuPy array does not support CUDA Array Interface"
            )

    # Only reach here if allow_cpu_roundtrip=True
    tensor = torch.from_numpy(data.get())

    # Move to specified device if needed
    if device_id is not None:
        tensor = tensor.to(f"cuda:{device_id}")

    return tensor
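

# Illustrative sketch (not part of the original module): the zero-copy CuPy ->
# torch path shares the underlying device buffer, so in-place edits on one side
# are visible on the other. Assumes cupy and torch are installed and a CUDA
# device is present; the helper name is hypothetical.
def _example_cupy_torch_shares_memory(device_id: int = 0):  # pragma: no cover
    import cupy as cp

    with cp.cuda.Device(device_id):
        arr = cp.arange(6, dtype=cp.float32)
    tensor = _cupy_to_torch(arr, device_id=device_id)  # DLPack path, no host copy
    tensor += 1                                        # mutates the shared buffer
    return arr  # reflects the in-place update made through torch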


def _cupy_to_tensorflow(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert cupy array to tensorflow tensor, staying on GPU.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If tensorflow is not installed
        RuntimeError: If TensorFlow version is < 2.12 (unstable DLPack support)
    """
    tf = _ensure_module("tensorflow")

    # Try using DLPack if supported
    # _supports_dlpack will raise RuntimeError if TF version < 2.12 or tensor is on CPU
    # This enforces Clause 88 (No Inferred Capabilities)
    if _supports_dlpack(data):
        try:
            dlpack = data.toDlpack()
            tensor = tf.experimental.dlpack.from_dlpack(dlpack)

            # Move to specified device if needed
            if device_id is not None:
                with tf.device(f"/device:GPU:{device_id}"):
                    return tf.identity(tensor)

            return tensor
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.CUPY.value,
                    target_type=MemoryType.TENSORFLOW.value,
                    method="DLPack",
                    reason=str(e)
                ) from e
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.TENSORFLOW.value,
            method="DLPack",
            reason="DLPack conversion not supported"
        )

    # Only reach here if allow_cpu_roundtrip=True
    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.convert_to_tensor(data.get())

    return tf.convert_to_tensor(data.get())


def _cupy_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert cupy array to pyclesperanto array, staying on GPU.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If pyclesperanto is not installed
    """
    cle = _ensure_module("pyclesperanto")

    # Try direct GPU conversion first
    try:
        # Get current CuPy device
        current_device = data.device.id

        # Select appropriate pyclesperanto device
        if device_id is not None:
            target_device = device_id
        else:
            target_device = current_device

        devices = cle.list_available_devices()
        if target_device >= len(devices):
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.CUPY.value,
                    target_type=MemoryType.PYCLESPERANTO.value,
                    method="device_selection",
                    reason=f"GPU ID {target_device} not available in pyclesperanto"
                )
        else:
            cle.select_device(target_device)

        # Convert via numpy (pyclesperanto doesn't have direct CuPy interop)
        numpy_data = data.get()  # CuPy to NumPy
        return cle.push(numpy_data)  # NumPy to pyclesperanto

    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.CUPY.value,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip
    numpy_data = data.get()
    cle.select_device(device_id if device_id is not None else 0)
    return cle.push(numpy_data)


def _pyclesperanto_to_cupy(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert pyclesperanto array to cupy array, staying on GPU.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If cupy is not installed
    """
    cupy = _ensure_module("cupy")
    cle = _ensure_module("pyclesperanto")

    try:
        # Convert via numpy (pyclesperanto doesn't have direct CuPy interop)
        numpy_data = cle.pull(data)  # pyclesperanto to NumPy

        # Convert to CuPy on specified device
        if device_id is not None:
            with cupy.cuda.Device(device_id):
                return cupy.array(numpy_data)
        else:
            return cupy.array(numpy_data)

    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.CUPY.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip (same as above)
    numpy_data = cle.pull(data)
    if device_id is not None:
        with cupy.cuda.Device(device_id):
            return cupy.array(numpy_data)
    else:
        return cupy.array(numpy_data)


def _cupy_to_jax(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert cupy array to JAX array, staying on GPU with zero-copy DLPack transfer.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip (ignored, always False)
        device_id: The target GPU device ID (optional)

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If conversion fails
        ImportError: If JAX is not installed
    """
    jax = _ensure_module("jax")

    # Check if CuPy array is on GPU (should always be true for CuPy)
    if not hasattr(data, 'device'):
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.JAX.value,
            method="device_detection",
            reason="CuPy array does not have a device attribute"
        )

    # Try using DLPack for direct GPU-to-GPU transfer
    if _supports_dlpack(data):
        try:
            dlpack = data.toDlpack()
            result = jax.dlpack.from_dlpack(dlpack)

            # Move to specified device if needed
            if device_id is not None:
                current_device = None
                try:
                    # Extract device ID from JAX array
                    device_str = str(result.device)
                    if "gpu:" in device_str:
                        current_device = int(device_str.split("gpu:")[-1].split(")")[0])
                except Exception:
                    pass

                # Only move if needed
                if current_device != device_id:
                    result = jax.device_put(result, jax.devices("gpu")[device_id])

            return result
        except Exception as e:
            # No CPU roundtrip allowed, so fail loudly
            raise MemoryConversionError(
                source_type=MemoryType.CUPY.value,
                target_type=MemoryType.JAX.value,
                method="DLPack",
                reason=str(e)
            ) from e
    else:
        # No CPU roundtrip allowed, so fail loudly
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.JAX.value,
            method="DLPack",
            reason="CuPy array does not support DLPack"
        )
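

# Illustrative sketch (not part of the original module): the Clause 65 contract
# means callers catch MemoryConversionError rather than silently receiving a
# host-side copy. Assumes cupy is installed; the helper name is hypothetical.
def _example_fail_loudly(arr):  # pragma: no cover
    try:
        return _cupy_to_jax(arr)  # never falls back to a CPU roundtrip
    except MemoryConversionError as err:
        # err records source_type, target_type, method, and reason
        raise RuntimeError(f"GPU-native conversion failed: {err}") from err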


# PyTorch conversion functions

def _torch_to_numpy(data: Any) -> Any:
    """
    Convert torch tensor to numpy array.

    Args:
        data: The torch tensor to convert

    Returns:
        The converted numpy array
    """
    return data.detach().cpu().numpy()


def _torch_to_cupy(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert torch tensor to cupy array, staying on GPU.

    Args:
        data: The torch tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If cupy is not installed
    """
    cupy = _ensure_module("cupy")
    torch = _ensure_module("torch")

    # Only attempt direct conversion if tensor is on CUDA
    if data.is_cuda:
        # Try using CUDA Array Interface
        if _supports_cuda_array_interface(data):
            try:
                result = cupy.asarray(data)

                # Move to specified device if needed
                if device_id is not None and result.device.id != device_id:
                    with cupy.cuda.Device(device_id):
                        return result.copy()

                return result
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.TORCH.value,
                        target_type=MemoryType.CUPY.value,
                        method="CUDA Array Interface",
                        reason=str(e)
                    ) from e

        # Try using DLPack
        if _supports_dlpack(data):
            try:
                # torch.Tensor has no to_dlpack() method; export via torch.to_dlpack
                dlpack = torch.to_dlpack(data)
                result = cupy.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None and result.device.id != device_id:
                    with cupy.cuda.Device(device_id):
                        return result.copy()

                return result
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.TORCH.value,
                        target_type=MemoryType.CUPY.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.CUPY.value,
            method="GPU-native",
            reason="PyTorch tensor is not on CUDA"
        )

    # Only reach here if allow_cpu_roundtrip=True
    if device_id is not None:
        with cupy.cuda.Device(device_id):
            return cupy.array(data.detach().cpu().numpy())

    return cupy.array(data.detach().cpu().numpy())
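

# Illustrative sketch (not part of the original module): a CPU tensor cannot be
# converted GPU-natively, so the caller must opt in to the host roundtrip.
# Assumes torch and cupy are installed; the helper name is hypothetical.
def _example_torch_cpu_needs_roundtrip(device_id: int = 0):  # pragma: no cover
    import torch

    host_tensor = torch.ones(3, 3)  # lives on the CPU
    # _torch_to_cupy(host_tensor) would raise MemoryConversionError here
    return _torch_to_cupy(host_tensor, allow_cpu_roundtrip=True, device_id=device_id)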


def _torch_to_torch(data: Any, device_id: Optional[int] = None) -> Any:
    """
    Convert torch tensor to torch tensor (identity operation).

    Args:
        data: The torch tensor to convert
        device_id: The target GPU device ID (optional)

    Returns:
        The cloned torch tensor, possibly on a different device
    """
    result = data.clone()

    # Move to specified device if needed
    if device_id is not None:
        if data.is_cuda and data.device.index != device_id:
            result = result.to(f"cuda:{device_id}")
        elif not data.is_cuda:
            result = result.to(f"cuda:{device_id}")

    return result


def _torch_to_tensorflow(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert torch tensor to tensorflow tensor, staying on GPU.

    Args:
        data: The torch tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If tensorflow is not installed
        RuntimeError: If TensorFlow version is < 2.12 (unstable DLPack support)
    """
    tf = _ensure_module("tensorflow")
    torch = _ensure_module("torch")

    # Only attempt direct conversion if tensor is on CUDA
    if data.is_cuda:
        # Check TensorFlow version for DLPack compatibility
        try:
            # This will check TF version and raise RuntimeError if < 2.12
            # Enforces Clause 88 (No Inferred Capabilities)
            tf_version = tf.__version__
            major, minor = map(int, tf_version.split('.')[:2])

            if major < 2 or (major == 2 and minor < 12):
                raise RuntimeError(
                    f"TensorFlow version {tf_version} does not support stable DLPack operations. "
                    f"Version 2.12.0 or higher is required. "
                    f"Clause 88 violation: Cannot infer DLPack capability."
                )

            # Check if experimental.dlpack module exists
            if not hasattr(tf.experimental, "dlpack"):
                raise RuntimeError(
                    "TensorFlow installation missing experimental.dlpack module. "
                    "Clause 88 violation: Cannot infer DLPack capability."
                )

            # Now try the conversion
            try:
                # Export the tensor with torch.to_dlpack (tensors have no to_dlpack() method)
                dlpack = torch.to_dlpack(data)
                tensor = tf.experimental.dlpack.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None:
                    with tf.device(f"/device:GPU:{device_id}"):
                        return tf.identity(tensor)

                return tensor
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.TORCH.value,
                        target_type=MemoryType.TENSORFLOW.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e
        except RuntimeError as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.TORCH.value,
                    target_type=MemoryType.TENSORFLOW.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # If we get here, either the tensor is not on CUDA or there was a DLPack issue
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.TENSORFLOW.value,
            method="GPU-native",
            reason="PyTorch tensor is not on CUDA or TensorFlow DLPack support issue"
        )

    # Only reach here if allow_cpu_roundtrip=True
    # This is an explicit CPU roundtrip, which is only allowed if explicitly requested
    numpy_data = data.detach().cpu().numpy()

    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.convert_to_tensor(numpy_data)

    return tf.convert_to_tensor(numpy_data)
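

# Illustrative sketch (not part of the original module): the same version gate
# used above, shown standalone. Parsing only the first two version components
# mirrors the module's check; the function name is hypothetical.
def _example_tf_supports_stable_dlpack(tf_version: str) -> bool:  # pragma: no cover
    major, minor = map(int, tf_version.split('.')[:2])
    return major > 2 or (major == 2 and minor >= 12)
    # _example_tf_supports_stable_dlpack("2.11.1") -> False
    # _example_tf_supports_stable_dlpack("2.15.0") -> True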


# TensorFlow conversion functions

def _tensorflow_to_numpy(data: Any) -> Any:
    """
    Convert tensorflow tensor to numpy array.

    Args:
        data: The tensorflow tensor to convert

    Returns:
        The converted numpy array
    """
    return data.numpy()


def _torch_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert torch tensor to pyclesperanto array, staying on GPU.

    Args:
        data: The torch tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If pyclesperanto is not installed
    """
    cle = _ensure_module("pyclesperanto")

    # Try GPU-to-GPU conversion first
    if data.is_cuda:
        try:
            # Use CUDA array interface for zero-copy conversion
            if _supports_cuda_array_interface(data):
                # Select target device
                target_device = device_id if device_id is not None else data.device.index
                cle.select_device(target_device)

                # Convert via CUDA array interface
                return cle.asarray(data.detach())
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.TORCH.value,
                    target_type=MemoryType.PYCLESPERANTO.value,
                    method="GPU_conversion",
                    reason=str(e)
                ) from e

    # Fallback: CPU roundtrip
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.PYCLESPERANTO.value,
            method="GPU-native",
            reason="PyTorch tensor is not on CUDA"
        )

    # CPU roundtrip conversion
    numpy_data = data.detach().cpu().numpy()
    cle.select_device(device_id if device_id is not None else 0)
    return cle.push(numpy_data)


def _torch_to_jax(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert PyTorch tensor to JAX array, staying on GPU with zero-copy DLPack transfer.

    Args:
        data: The PyTorch tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip (ignored, always False)
        device_id: The target GPU device ID (optional)

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If conversion fails
        ImportError: If JAX is not installed
    """
    jax = _ensure_module("jax")
    torch = _ensure_module("torch")

    # If tensor is on CPU, move it to GPU first (similar to _numpy_to_jax behavior)
    if not data.is_cuda:
        if device_id is not None:
            # Move CPU tensor to specified GPU device
            data = data.to(f"cuda:{device_id}")
        else:
            # Move to default GPU device
            data = data.cuda()

    # Now attempt direct conversion with tensor on CUDA
    if data.is_cuda:
        # Try using DLPack for direct GPU-to-GPU transfer
        if _supports_dlpack(data):
            try:
                dlpack = torch.to_dlpack(data)
                result = jax.dlpack.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None:
                    current_device = None
                    try:
                        # Extract device ID from JAX array
                        device_str = str(result.device)
                        if "gpu:" in device_str or "cuda:" in device_str:
                            current_device = int(device_str.split("gpu:")[-1].split(")")[0])
                    except Exception:
                        pass

                    # Only move if needed
                    if current_device != device_id:
                        result = jax.device_put(result, jax.devices("gpu")[device_id])

                return result
            except Exception as e:
                # No CPU roundtrip allowed, so fail loudly
                raise MemoryConversionError(
                    source_type=MemoryType.TORCH.value,
                    target_type=MemoryType.JAX.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # If we get here, there was a DLPack issue (tensor should be on CUDA at this point)
    raise MemoryConversionError(
        source_type=MemoryType.TORCH.value,
        target_type=MemoryType.JAX.value,
        method="GPU-native",
        reason="DLPack conversion failed after moving tensor to CUDA"
    )
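

# Illustrative sketch (not part of the original module): _torch_to_jax moves CPU
# tensors onto the GPU before the DLPack hand-off, so callers do not need to.
# Assumes torch and jax are installed with a CUDA device; names are hypothetical.
def _example_torch_to_jax(device_id: int = 0):  # pragma: no cover
    import torch

    tensor = torch.randn(2, 3)                        # CPU tensor
    arr = _torch_to_jax(tensor, device_id=device_id)  # moved to cuda:<device_id>, then DLPack
    return arr.shape  # (2, 3)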


# JAX conversion functions

def _numpy_to_jax(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to JAX array.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID

    Returns:
        The converted JAX array

    Raises:
        ImportError: If JAX is not installed
        ValueError: If gpu_id is negative
    """
    jax = _ensure_module("jax")

    # Validate gpu_id
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # Create JAX array on CPU
    result = jax.numpy.array(data)

    # Always move to the specified GPU device
    # JAX uses different device notation
    result = jax.device_put(result, jax.devices("gpu")[gpu_id])

    return result


def _jax_to_numpy(data: Any) -> Any:
    """
    Convert JAX array to numpy array.

    Args:
        data: The JAX array to convert

    Returns:
        The converted numpy array
    """
    # JAX arrays can be converted to numpy with .copy()
    return data.copy()


def _jax_to_cupy(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert JAX array to cupy array, staying on GPU if possible.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If cupy is not installed
    """
    jax = _ensure_module("jax")
    cupy = _ensure_module("cupy")

    # Check if JAX array is on GPU
    device_str = str(data.device).lower()
    is_on_gpu = device_str.startswith("gpu") or device_str.startswith("cuda")

    if is_on_gpu:
        # Try using DLPack for direct GPU-to-GPU transfer
        if _supports_dlpack(data):
            try:
                dlpack = jax.dlpack.to_dlpack(data)
                result = cupy.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None and result.device.id != device_id:
                    with cupy.cuda.Device(device_id):
                        return result.copy()

                return result
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.JAX.value,
                        target_type=MemoryType.CUPY.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.JAX.value,
            target_type=MemoryType.CUPY.value,
            method="GPU-native",
            reason="JAX array is not on GPU"
        )

    # Only reach here if allow_cpu_roundtrip=True or DLPack failed
    numpy_data = _jax_to_numpy(data)

    if device_id is not None:
        with cupy.cuda.Device(device_id):
            return cupy.array(numpy_data)

    return cupy.array(numpy_data)


def _jax_to_torch(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert JAX array to torch tensor, staying on GPU if possible.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If torch is not installed
    """
    jax = _ensure_module("jax")
    torch = _ensure_module("torch")

    # Check if JAX array is on GPU
    device_str = str(data.device).lower()
    is_on_gpu = device_str.startswith("gpu") or device_str.startswith("cuda")

    if is_on_gpu:
        # Try using DLPack for direct GPU-to-GPU transfer
        if _supports_dlpack(data):
            try:
                dlpack = jax.dlpack.to_dlpack(data)
                tensor = torch.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None and tensor.device.index != device_id:
                    tensor = tensor.to(f"cuda:{device_id}")

                return tensor
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.JAX.value,
                        target_type=MemoryType.TORCH.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.JAX.value,
            target_type=MemoryType.TORCH.value,
            method="GPU-native",
            reason="JAX array is not on GPU"
        )

    # Only reach here if allow_cpu_roundtrip=True or DLPack failed
    numpy_data = _jax_to_numpy(data)

    if device_id is not None:
        return torch.tensor(numpy_data, device=f"cuda:{device_id}")

    return torch.tensor(numpy_data)
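

# Illustrative sketch (not part of the original module): a GPU-resident JAX array
# is handed to torch via DLPack and landed on the requested device. Assumes jax
# and torch are installed with a CUDA device; the helper name is hypothetical.
def _example_jax_to_torch(arr, device_id: int = 0):  # pragma: no cover
    tensor = _jax_to_torch(arr, device_id=device_id)  # DLPack when arr is on GPU
    assert tensor.device.index == device_id
    return tensor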


def _jax_to_tensorflow(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert JAX array to tensorflow tensor, staying on GPU if possible.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If tensorflow is not installed
    """
    jax = _ensure_module("jax")
    tf = _ensure_module("tensorflow")

    # Check if JAX array is on GPU
    device_str = str(data.device).lower()
    is_on_gpu = device_str.startswith("gpu") or device_str.startswith("cuda")

    if is_on_gpu:
        # Try using DLPack for direct GPU-to-GPU transfer
        if _supports_dlpack(data):
            try:
                dlpack = jax.dlpack.to_dlpack(data)
                tensor = tf.experimental.dlpack.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None:
                    with tf.device(f"/device:GPU:{device_id}"):
                        return tf.identity(tensor)

                return tensor
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.JAX.value,
                        target_type=MemoryType.TENSORFLOW.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.JAX.value,
            target_type=MemoryType.TENSORFLOW.value,
            method="GPU-native",
            reason="JAX array is not on GPU"
        )

    # Only reach here if allow_cpu_roundtrip=True or DLPack failed
    numpy_data = _jax_to_numpy(data)

    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.convert_to_tensor(numpy_data)

    return tf.convert_to_tensor(numpy_data)


def _tensorflow_to_cupy(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert tensorflow tensor to cupy array, staying on GPU.

    Args:
        data: The tensorflow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If cupy is not installed
        RuntimeError: If TensorFlow version is < 2.12 (unstable DLPack support)
    """
    cupy = _ensure_module("cupy")
    tf = _ensure_module("tensorflow")

    # _supports_dlpack will raise RuntimeError if TF version < 2.12 or tensor is on CPU
    # This enforces Clause 88 (No Inferred Capabilities)
    try:
        if _supports_dlpack(data):
            try:
                dlpack = tf.experimental.dlpack.to_dlpack(data)
                result = cupy.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None and result.device.id != device_id:
                    with cupy.cuda.Device(device_id):
                        return result.copy()

                return result
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.TENSORFLOW.value,
                        target_type=MemoryType.CUPY.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e
    except RuntimeError as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.TENSORFLOW.value,
                target_type=MemoryType.CUPY.value,
                method="DLPack",
                reason=str(e)
            ) from e

    # Only reach here if allow_cpu_roundtrip=True or _supports_dlpack raised an exception
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TENSORFLOW.value,
            target_type=MemoryType.CUPY.value,
            method="GPU-native",
            reason="TensorFlow tensor is not on GPU or DLPack not supported"
        )

    # Only reach here if allow_cpu_roundtrip=True
    if device_id is not None:
        with cupy.cuda.Device(device_id):
            return cupy.array(data.numpy())

    return cupy.array(data.numpy())


def _tensorflow_to_torch(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert tensorflow tensor to torch tensor, staying on GPU.

    Args:
        data: The tensorflow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If torch is not installed
        RuntimeError: If TensorFlow version is < 2.12 (unstable DLPack support)
    """
    torch = _ensure_module("torch")
    tf = _ensure_module("tensorflow")

    # _supports_dlpack will raise RuntimeError if TF version < 2.12 or tensor is on CPU
    # This enforces Clause 88 (No Inferred Capabilities)
    try:
        if _supports_dlpack(data):
            try:
                dlpack = tf.experimental.dlpack.to_dlpack(data)
                tensor = torch.from_dlpack(dlpack)

                # Move to specified device if needed
                if device_id is not None and tensor.device.index != device_id:
                    tensor = tensor.to(f"cuda:{device_id}")

                return tensor
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.TENSORFLOW.value,
                        target_type=MemoryType.TORCH.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e
    except RuntimeError as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.TENSORFLOW.value,
                target_type=MemoryType.TORCH.value,
                method="DLPack",
                reason=str(e)
            ) from e

    # Only reach here if allow_cpu_roundtrip=True or _supports_dlpack raised an exception
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TENSORFLOW.value,
            target_type=MemoryType.TORCH.value,
            method="GPU-native",
            reason="TensorFlow tensor is not on GPU or DLPack not supported"
        )

    # Only reach here if allow_cpu_roundtrip=True
    tensor = torch.from_numpy(data.numpy())

    # Move to specified device if needed
    if device_id is not None:
        tensor = tensor.to(f"cuda:{device_id}")

    return tensor
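

# Illustrative sketch (not part of the original module): batch-converting the
# outputs of a TensorFlow step into torch tensors on one device. Assumes
# tensorflow and torch are installed; the helper name is hypothetical.
def _example_tensorflow_batch_to_torch(tensors, device_id: int = 0):  # pragma: no cover
    return [
        _tensorflow_to_torch(t, allow_cpu_roundtrip=True, device_id=device_id)
        for t in tensors
    ]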


def _tensorflow_to_jax(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert TensorFlow tensor to JAX array, staying on GPU with zero-copy DLPack transfer.

    Args:
        data: The TensorFlow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip (ignored, always False)
        device_id: The target GPU device ID (optional)

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If conversion fails
        ImportError: If JAX is not installed
        RuntimeError: If TensorFlow version is < 2.12 (unstable DLPack support)
    """
    jax = _ensure_module("jax")
    tf = _ensure_module("tensorflow")

    # Check TensorFlow version for DLPack compatibility
    tf_version = tf.__version__
    major, minor = map(int, tf_version.split('.')[:2])

    if major < 2 or (major == 2 and minor < 12):
        raise RuntimeError(
            f"TensorFlow version {tf_version} does not support stable DLPack operations. "
            f"Version 2.12.0 or higher is required. "
            f"Clause 88 violation: Cannot infer DLPack capability."
        )

    # Check if experimental.dlpack module exists
    if not hasattr(tf.experimental, "dlpack"):
        raise RuntimeError(
            "TensorFlow installation missing experimental.dlpack module. "
            "Clause 88 violation: Cannot infer DLPack capability."
        )

    # Check if tensor is on GPU
    device_str = data.device.lower()
    is_on_gpu = "gpu" in device_str

    if is_on_gpu:
        # Try using DLPack for direct GPU-to-GPU transfer
        try:
            dlpack = tf.experimental.dlpack.to_dlpack(data)
            result = jax.dlpack.from_dlpack(dlpack)

            # Move to specified device if needed
            if device_id is not None:
                current_device = None
                try:
                    # Extract device ID from JAX array
                    device_str = str(result.device)
                    if "gpu:" in device_str:
                        current_device = int(device_str.rsplit('gpu:', maxsplit=1)[-1].split(")")[0])
                except (ValueError, IndexError):
                    pass

                # Only move if needed
                if current_device != device_id:
                    result = jax.device_put(result, jax.devices("gpu")[device_id])

            return result
        except Exception as e:
            # No CPU roundtrip allowed, so fail loudly
            raise MemoryConversionError(
                source_type=MemoryType.TENSORFLOW.value,
                target_type=MemoryType.JAX.value,
                method="DLPack",
                reason=str(e)
            ) from e

    # If we get here, the tensor is not on GPU
    # No CPU roundtrip allowed, so fail loudly
    raise MemoryConversionError(
        source_type=MemoryType.TENSORFLOW.value,
        target_type=MemoryType.JAX.value,
        method="GPU-native",
        reason="TensorFlow tensor is not on GPU"
    )


def _tensorflow_to_tensorflow(data: Any, device_id: Optional[int] = None) -> Any:
    """
    Convert tensorflow tensor to tensorflow tensor (identity operation).

    Args:
        data: The tensorflow tensor to convert
        device_id: The target GPU device ID (optional)

    Returns:
        The copied tensorflow tensor, possibly on a different device
    """
    tf = _ensure_module("tensorflow")

    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.identity(data)

    return tf.identity(data)


def _tensorflow_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert tensorflow tensor to pyclesperanto array, staying on GPU.

    Args:
        data: The tensorflow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If pyclesperanto is not installed
    """
    tf = _ensure_module("tensorflow")
    cle = _ensure_module("pyclesperanto")

    # Try GPU-to-GPU conversion first
    try:
        # Use DLPack for zero-copy conversion
        if hasattr(tf.experimental, 'dlpack') and hasattr(tf.experimental.dlpack, 'to_dlpack'):
            dlpack = tf.experimental.dlpack.to_dlpack(data)

            # Select target device
            target_device = device_id if device_id is not None else 0
            cle.select_device(target_device)

            # Convert from DLPack
            return cle.from_dlpack(dlpack)
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.TENSORFLOW.value,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip
    numpy_data = data.numpy()
    cle.select_device(device_id if device_id is not None else 0)
    return cle.push(numpy_data)


def _jax_to_jax(data: Any, device_id: Optional[int] = None) -> Any:
    """
    Convert JAX array to JAX array (identity operation).

    Args:
        data: The JAX array to convert
        device_id: The target GPU device ID (optional)

    Returns:
        The cloned JAX array, possibly on a different device
    """
    jax = _ensure_module("jax")

    result = data.copy()

    # Move to specified device if needed
    if device_id is not None:
        result = jax.device_put(result, jax.devices("gpu")[device_id])

    return result


def _jax_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert JAX array to pyclesperanto array, staying on GPU.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If pyclesperanto is not installed
    """
    jax = _ensure_module("jax")
    cle = _ensure_module("pyclesperanto")

    # Try GPU-to-GPU conversion first
    try:
        # Use DLPack for zero-copy conversion
        if hasattr(data, '__dlpack__'):
            dlpack = data.__dlpack__()

            # Select target device
            target_device = device_id if device_id is not None else 0
            cle.select_device(target_device)

            # Convert from DLPack
            return cle.from_dlpack(dlpack)
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.JAX.value,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip
    numpy_data = _jax_to_numpy(data)
    cle.select_device(device_id if device_id is not None else 0)
    return cle.push(numpy_data)
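

# Illustrative sketch (not part of the original module): how these private
# converters could be keyed by (source, target) MemoryType pairs. The actual
# OpenHCS registry is defined elsewhere; this table and its name are hypothetical.
_EXAMPLE_CONVERSION_TABLE = {
    (MemoryType.CUPY, MemoryType.TORCH): _cupy_to_torch,
    (MemoryType.TORCH, MemoryType.TENSORFLOW): _torch_to_tensorflow,
    (MemoryType.TENSORFLOW, MemoryType.JAX): _tensorflow_to_jax,
    (MemoryType.JAX, MemoryType.PYCLESPERANTO): _jax_to_pyclesperanto,
}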