Coverage for openhcs/core/memory/conversion_functions.py: 5.3%

518 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Direct memory conversion functions for OpenHCS. 

3 

4This module provides direct conversion functions between different memory types, 

5enforcing Clause 65 (Fail Loudly), Clause 88 (No Inferred Capabilities), 

6and Clause 251 (Declarative Memory Conversion). 

7""" 

8 

9from typing import Any, Optional 

10 

11from openhcs.constants.constants import MemoryType 

12 

13from .exceptions import MemoryConversionError 

14from .utils import (_ensure_module, _supports_cuda_array_interface, 

15 _supports_dlpack) 

16 

17# NumPy conversion functions 

18 

19def _numpy_to_numpy(data: Any) -> Any: 

20 """Convert numpy array to numpy array (identity operation).""" 

21 return data.copy() 

22 

23 

def _numpy_to_cupy(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to cupy array.

    The data is copied into a new cupy array allocated on GPU ``gpu_id``.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID

    Returns:
        The converted cupy array

    Raises:
        ImportError: If cupy is not installed
        ValueError: If gpu_id is None or negative
    """
    cupy = _ensure_module("cupy")

    # Validate gpu_id. None means no GPU was assigned upstream. Report it
    # explicitly, as the torch/tensorflow converters do, instead of letting
    # the comparison below fail with an opaque TypeError.
    if gpu_id is None:
        raise ValueError("🔥 GPU ID IS NONE! The compiler failed to assign a GPU to this cupy function. This is a GPU registry/compiler bug!")
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # Allocate on the requested device rather than whichever device is current.
    with cupy.cuda.Device(gpu_id):
        return cupy.array(data)

48 

49 

def _numpy_to_torch(data: Any, gpu_id: int) -> Any:
    """
    Copy a numpy array into a torch tensor on a specific CUDA device.

    Args:
        data: The numpy array to convert
        gpu_id: Index of the CUDA device that receives the tensor

    Returns:
        A torch tensor placed on ``cuda:<gpu_id>``

    Raises:
        ImportError: If torch is not installed
        ValueError: If gpu_id is None or negative
    """
    torch = _ensure_module("torch")

    # A missing assignment is reported as a compiler/registry bug (see message).
    if gpu_id is None:
        raise ValueError("🔥 GPU ID IS NONE! The compiler failed to assign a GPU to this torch function. This is a GPU registry/compiler bug!")
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    target = torch.device(f"cuda:{gpu_id}")
    converted = torch.tensor(data, device=target)
    return converted

76 

77 

def _numpy_to_pyclesperanto(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to pyclesperanto array.

    ``gpu_id`` is an index into ``cle.list_available_devices()``. That device
    is selected and the data is pushed to it.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID

    Returns:
        The converted pyclesperanto array

    Raises:
        ImportError: If pyclesperanto is not installed
        ValueError: If gpu_id is None, negative, or not an available device index
    """
    cle = _ensure_module("pyclesperanto")

    # Validate gpu_id. None means no GPU was assigned upstream. Report it
    # explicitly, as the torch/tensorflow converters do, instead of failing
    # with an opaque TypeError on the comparison below.
    if gpu_id is None:
        raise ValueError("🔥 GPU ID IS NONE! The compiler failed to assign a GPU to this pyclesperanto function. This is a GPU registry/compiler bug!")
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # Make sure the requested index exists before selecting it.
    devices = cle.list_available_devices()
    if gpu_id >= len(devices):
        raise ValueError(f"GPU ID {gpu_id} not available. Available devices: {len(devices)}")

    # Select device and push data
    cle.select_device(gpu_id)
    return cle.push(data)

107 

108 

def _numpy_to_tensorflow(data: Any, gpu_id: int) -> Any:
    """
    Build a tensorflow tensor from a numpy array under a specific GPU scope.

    Args:
        data: The numpy array to convert
        gpu_id: Index of the GPU used for the ``tf.device`` scope

    Returns:
        The tensorflow tensor created inside ``/device:GPU:<gpu_id>``

    Raises:
        ImportError: If tensorflow is not installed
        ValueError: If gpu_id is None or negative
    """
    tf = _ensure_module("tensorflow")

    # Reject a missing or negative GPU assignment before touching TensorFlow.
    if gpu_id is None:
        raise ValueError("🔥 GPU ID IS NONE! The compiler failed to assign a GPU to this tensorflow function. This is a GPU registry/compiler bug!")
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    device_spec = f"/device:GPU:{gpu_id}"
    with tf.device(device_spec):
        tensor = tf.convert_to_tensor(data)
    return tensor

135 

136 

137# pyclesperanto conversion functions 

138 

def _pyclesperanto_to_numpy(data: Any) -> Any:
    """
    Pull a pyclesperanto array back to host memory.

    Args:
        data: The pyclesperanto array to convert

    Returns:
        The numpy array produced by ``cle.pull``
    """
    return _ensure_module("pyclesperanto").pull(data)

151 

152 

def _pyclesperanto_to_pyclesperanto(data: Any) -> Any:
    """
    Identity conversion for pyclesperanto arrays.

    The input object itself is returned (no copy is made). The pyclesperanto
    module is still resolved first, so a missing installation is reported
    instead of being silently ignored.
    """
    _ensure_module("pyclesperanto")
    return data

157 

158 

def _pyclesperanto_to_torch(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert pyclesperanto array to torch tensor, staying on GPU when possible.

    A GPU-side conversion through the CUDA Array Interface is attempted first.
    If the array does not expose that interface, or the conversion raises, a
    host roundtrip via ``cle.pull`` is used only when ``allow_cpu_roundtrip``
    is True. Otherwise MemoryConversionError is raised.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional; defaults to 0 on the GPU path)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If torch is not installed
    """
    torch = _ensure_module("torch")
    cle = _ensure_module("pyclesperanto")

    try:
        if _supports_cuda_array_interface(data):
            # as_tensor places the result directly on the requested device,
            # so no further move is needed.
            target_device = f"cuda:{device_id if device_id is not None else 0}"
            return torch.as_tensor(data, device=target_device)
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.TORCH.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e
    else:
        # Reached only when the array lacks the CUDA Array Interface. Do not
        # fall through to an unauthorized CPU roundtrip.
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.TORCH.value,
                method="GPU_conversion",
                reason="pyclesperanto array does not support CUDA Array Interface"
            )

    # Fallback: explicit CPU roundtrip (only when allow_cpu_roundtrip=True)
    numpy_data = cle.pull(data)
    if device_id is not None:
        return torch.tensor(numpy_data, device=f"cuda:{device_id}")
    return torch.tensor(numpy_data)

204 

205 

def _pyclesperanto_to_tensorflow(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert pyclesperanto array to tensorflow tensor, staying on GPU when possible.

    The GPU path hands the array to TensorFlow through DLPack
    (``data.__dlpack__()``) inside ``/device:GPU:<device_id or 0>``. If the
    array does not implement ``__dlpack__``, or the transfer raises, a host
    roundtrip via ``cle.pull`` is used only when ``allow_cpu_roundtrip`` is
    True. Otherwise MemoryConversionError is raised.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If tensorflow is not installed
    """
    tf = _ensure_module("tensorflow")
    cle = _ensure_module("pyclesperanto")

    try:
        # Gate on the protocol that is actually used below (DLPack).
        if hasattr(data, "__dlpack__"):
            with tf.device(f"/device:GPU:{device_id if device_id is not None else 0}"):
                return tf.experimental.dlpack.from_dlpack(data.__dlpack__())
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.TENSORFLOW.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e
    else:
        # No DLPack support: refuse the implicit CPU roundtrip unless authorized.
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.TENSORFLOW.value,
                method="GPU_conversion",
                reason="pyclesperanto array does not support DLPack"
            )

    # Fallback: explicit CPU roundtrip (only when allow_cpu_roundtrip=True)
    numpy_data = cle.pull(data)
    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.convert_to_tensor(numpy_data)
    return tf.convert_to_tensor(numpy_data)

247 

248 

def _pyclesperanto_to_jax(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert pyclesperanto array to JAX array, staying on GPU when possible.

    The GPU path imports the array into JAX through DLPack. If the array does
    not implement ``__dlpack__``, or the import raises, a host roundtrip via
    ``cle.pull`` is used only when ``allow_cpu_roundtrip`` is True. Otherwise
    MemoryConversionError is raised.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional); indexes jax.devices("gpu")

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If jax is not installed
    """
    jax = _ensure_module("jax")
    cle = _ensure_module("pyclesperanto")

    try:
        # Use DLPack for zero-copy conversion
        if hasattr(data, '__dlpack__'):
            result = jax.dlpack.from_dlpack(data.__dlpack__())

            # Move to specified device if needed
            if device_id is not None:
                result = jax.device_put(result, jax.devices("gpu")[device_id])

            return result
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.JAX.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e
    else:
        # No DLPack support: refuse the implicit CPU roundtrip unless authorized.
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.JAX.value,
                method="GPU_conversion",
                reason="pyclesperanto array does not support DLPack"
            )

    # Fallback: explicit CPU roundtrip (only when allow_cpu_roundtrip=True)
    numpy_data = cle.pull(data)
    result = jax.numpy.array(numpy_data)

    if device_id is not None:
        result = jax.device_put(result, jax.devices("gpu")[device_id])

    return result

297 

298 

299# CuPy conversion functions 

300 

301def _cupy_to_numpy(data: Any) -> Any: 

302 """ 

303 Convert cupy array to numpy array. 

304 

305 Args: 

306 data: The cupy array to convert 

307 

308 Returns: 

309 The converted numpy array 

310 """ 

311 return data.get() 

312 

313 

314def _cupy_to_cupy(data: Any) -> Any: 

315 """Convert cupy array to cupy array (identity operation).""" 

316 return data.copy() 

317 

318 

def _cupy_to_torch(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert cupy array to torch tensor, staying on GPU.

    DLPack is tried first. The CUDA Array Interface is used only when DLPack
    is not supported. A host roundtrip (``data.get()``) happens only when
    ``allow_cpu_roundtrip`` is True.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If torch is not installed
    """
    torch = _ensure_module("torch")

    if _supports_dlpack(data):
        # Zero-copy GPU-to-GPU conversion via DLPack
        try:
            result = torch.from_dlpack(data.toDlpack())

            # Move to specified device if needed
            if device_id is not None:
                target_device = f"cuda:{device_id}"
                if str(result.device) != target_device:
                    result = result.to(target_device)

            return result
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.CUPY.value,
                    target_type=MemoryType.TORCH.value,
                    method="DLPack",
                    reason=str(e)
                ) from e
    elif _supports_cuda_array_interface(data):
        # Fallback GPU path: CUDA Array Interface. Failures are reported
        # through MemoryConversionError rather than debug prints.
        try:
            target_device = f"cuda:{device_id}" if device_id is not None else "cuda"
            return torch.as_tensor(data, device=target_device)
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.CUPY.value,
                    target_type=MemoryType.TORCH.value,
                    method="CUDA Array Interface",
                    reason=str(e)
                ) from e
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.TORCH.value,
            method="CUDA Array Interface",
            reason="CuPy array does not support CUDA Array Interface"
        )

    # Only reach here if allow_cpu_roundtrip=True
    tensor = torch.from_numpy(data.get())

    # Move to specified device if needed
    if device_id is not None:
        tensor = tensor.to(f"cuda:{device_id}")

    return tensor

396 

397 

def _cupy_to_tensorflow(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Move a CuPy array into TensorFlow, preferring a zero-copy DLPack handoff.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: If True, copy through host memory when DLPack is
            unsupported or the DLPack transfer fails
        device_id: Optional GPU index. When given, the result is produced
            under ``tf.device("/device:GPU:<device_id>")``

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If the GPU path fails and CPU fallback is not authorized
        ImportError: If tensorflow is not installed
        RuntimeError: If TensorFlow version is < 2.12 (unstable DLPack support)
    """
    tf = _ensure_module("tensorflow")

    # _supports_dlpack is expected to raise RuntimeError for TF < 2.12 or CPU
    # data (Clause 88: No Inferred Capabilities). That error is not caught here.
    dlpack_ok = _supports_dlpack(data)

    if not dlpack_ok and not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.TENSORFLOW.value,
            method="DLPack",
            reason="DLPack conversion not supported"
        )

    if dlpack_ok:
        try:
            tensor = tf.experimental.dlpack.from_dlpack(data.toDlpack())
            if device_id is None:
                return tensor
            with tf.device(f"/device:GPU:{device_id}"):
                return tf.identity(tensor)
        except Exception as err:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.CUPY.value,
                    target_type=MemoryType.TENSORFLOW.value,
                    method="DLPack",
                    reason=str(err)
                ) from err

    # Host roundtrip: reached only when allow_cpu_roundtrip is True.
    if device_id is None:
        return tf.convert_to_tensor(data.get())
    with tf.device(f"/device:GPU:{device_id}"):
        return tf.convert_to_tensor(data.get())

453 

454 

def _cupy_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert cupy array to pyclesperanto array.

    pyclesperanto has no direct CuPy interop, so the data is always copied
    through host memory (``data.get()`` followed by ``cle.push``), whatever
    the value of ``allow_cpu_roundtrip``. That flag controls only how
    failures are handled.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: If True, an unavailable target device is tolerated
            (the currently selected pyclesperanto device is used), and a failed
            transfer is retried on device ``device_id`` (or 0)
        device_id: The target pyclesperanto device index (optional; defaults to
            the CuPy array's device index)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If the target device is unavailable, or the
            conversion fails, and CPU fallback is not authorized
        ImportError: If pyclesperanto is not installed
    """
    cle = _ensure_module("pyclesperanto")

    try:
        # Default to the pyclesperanto device whose index matches the CuPy device
        current_device = data.device.id
        target_device = device_id if device_id is not None else current_device

        devices = cle.list_available_devices()
        if target_device < len(devices):
            cle.select_device(target_device)
        elif not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.CUPY.value,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="device_selection",
                reason=f"GPU ID {target_device} not available in pyclesperanto"
            )
        # Otherwise (roundtrip authorized) keep the currently selected device.

        # CuPy -> NumPy -> pyclesperanto
        return cle.push(data.get())

    except MemoryConversionError:
        # Already describes the failure precisely; do not re-wrap it as a
        # generic GPU_conversion error.
        raise
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.CUPY.value,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip on an explicitly selected device
    numpy_data = data.get()
    cle.select_device(device_id if device_id is not None else 0)
    return cle.push(numpy_data)

513 

514 

def _pyclesperanto_to_cupy(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Transfer a pyclesperanto array into a CuPy array.

    pyclesperanto offers no direct CuPy interop, so the data is pulled into a
    host NumPy array and uploaded with ``cupy.array``. The upload goes to
    ``device_id`` when given, and to CuPy's current device otherwise. If the
    transfer raises, it is attempted once more when ``allow_cpu_roundtrip`` is
    True. Otherwise the failure is reported as MemoryConversionError.

    Args:
        data: The pyclesperanto array to convert
        allow_cpu_roundtrip: Whether a failed transfer may be retried
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If the transfer fails and CPU fallback is not authorized
        ImportError: If cupy is not installed
    """
    cupy = _ensure_module("cupy")
    cle = _ensure_module("pyclesperanto")

    def _transfer() -> Any:
        # pyclesperanto -> host NumPy -> CuPy (optionally on a chosen device)
        host_array = cle.pull(data)
        if device_id is None:
            return cupy.array(host_array)
        with cupy.cuda.Device(device_id):
            return cupy.array(host_array)

    try:
        return _transfer()
    except Exception as err:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.PYCLESPERANTO.value,
                target_type=MemoryType.CUPY.value,
                method="GPU_conversion",
                reason=str(err)
            ) from err

    # Authorized fallback: repeat the same host roundtrip once.
    return _transfer()

561 

562 

def _cupy_to_jax(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert cupy array to JAX array, staying on GPU with zero-copy DLPack transfer.

    Args:
        data: The cupy array to convert
        allow_cpu_roundtrip: Accepted for signature compatibility but ignored;
            this conversion never falls back to a CPU roundtrip
        device_id: The target GPU device ID (optional); indexes jax.devices("gpu")

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If conversion fails
        ImportError: If JAX is not installed
    """
    jax = _ensure_module("jax")

    # Check if CuPy array is on GPU (should always be true for CuPy)
    if not hasattr(data, 'device'):
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.JAX.value,
            method="device_detection",
            reason="CuPy array does not have a device attribute"
        )

    # No CPU roundtrip allowed, so fail loudly without DLPack
    if not _supports_dlpack(data):
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.JAX.value,
            method="DLPack",
            reason="CuPy array does not support DLPack"
        )

    try:
        result = jax.dlpack.from_dlpack(data.toDlpack())

        # Move to specified device if needed
        if device_id is not None:
            current_device = None
            try:
                # Extract the device index from the JAX array's device string.
                # Both "gpu:N" and "cuda:N" spellings are accepted.
                device_str = str(result.device)
                for prefix in ("gpu:", "cuda:"):
                    if prefix in device_str:
                        current_device = int(device_str.split(prefix)[-1].split(")")[0])
                        break
            except Exception:
                # Best effort: an unknown index just forces the device_put below.
                pass

            # Only move if needed
            if current_device != device_id:
                result = jax.device_put(result, jax.devices("gpu")[device_id])

        return result
    except Exception as e:
        # No CPU roundtrip allowed, so fail loudly
        raise MemoryConversionError(
            source_type=MemoryType.CUPY.value,
            target_type=MemoryType.JAX.value,
            method="DLPack",
            reason=str(e)
        ) from e

628 

629 

630# PyTorch conversion functions 

631 

632def _torch_to_numpy(data: Any) -> Any: 

633 """ 

634 Convert torch tensor to numpy array. 

635 

636 Args: 

637 data: The torch tensor to convert 

638 

639 Returns: 

640 The converted numpy array 

641 """ 

642 return data.detach().cpu().numpy() 

643 

644 

def _torch_to_cupy(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert torch tensor to cupy array, staying on GPU.

    For CUDA tensors the CUDA Array Interface is tried first, then DLPack.
    CPU tensors, or CUDA tensors for which neither GPU path works, go through
    a host roundtrip only when ``allow_cpu_roundtrip`` is True.

    Args:
        data: The torch tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If cupy is not installed
    """
    cupy = _ensure_module("cupy")

    if data.is_cuda:
        # Work on a detached view. PyTorch refuses to export tensors that
        # require grad through the CUDA Array Interface.
        source = data.detach()

        # Try using CUDA Array Interface
        if _supports_cuda_array_interface(source):
            try:
                result = cupy.asarray(source)

                # ndarray.copy() allocates on the current device, so copying
                # inside the Device context moves the data to device_id.
                if device_id is not None and result.device.id != device_id:
                    with cupy.cuda.Device(device_id):
                        return result.copy()

                return result
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.TORCH.value,
                        target_type=MemoryType.CUPY.value,
                        method="CUDA Array Interface",
                        reason=str(e)
                    ) from e

        # Try using DLPack. torch.Tensor has no ``to_dlpack()`` method, so
        # the tensor itself is passed to cupy.from_dlpack, which consumes it
        # through the ``__dlpack__`` protocol.
        if _supports_dlpack(source):
            try:
                result = cupy.from_dlpack(source)

                # Move to specified device if needed
                if device_id is not None and result.device.id != device_id:
                    with cupy.cuda.Device(device_id):
                        return result.copy()

                return result
            except Exception as e:
                if not allow_cpu_roundtrip:
                    raise MemoryConversionError(
                        source_type=MemoryType.TORCH.value,
                        target_type=MemoryType.CUPY.value,
                        method="DLPack",
                        reason=str(e)
                    ) from e

        # Neither GPU path was available: do not silently roundtrip via CPU.
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.TORCH.value,
                target_type=MemoryType.CUPY.value,
                method="GPU-native",
                reason="PyTorch CUDA tensor supports neither CUDA Array Interface nor DLPack"
            )
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.CUPY.value,
            method="GPU-native",
            reason="PyTorch tensor is not on CUDA"
        )

    # Only reach here if allow_cpu_roundtrip=True
    if device_id is not None:
        with cupy.cuda.Device(device_id):
            return cupy.array(data.detach().cpu().numpy())

    return cupy.array(data.detach().cpu().numpy())

719 

720 

721def _torch_to_torch(data: Any, device_id: Optional[int] = None) -> Any: 

722 """ 

723 Convert torch tensor to torch tensor (identity operation). 

724 

725 Args: 

726 data: The torch tensor to convert 

727 device_id: The target GPU device ID (optional) 

728 

729 Returns: 

730 The cloned torch tensor, possibly on a different device 

731 """ 

732 result = data.clone() 

733 

734 # Move to specified device if needed 

735 if device_id is not None: 

736 if data.is_cuda and data.device.index != device_id: 

737 result = result.to(f"cuda:{device_id}") 

738 elif not data.is_cuda: 

739 result = result.to(f"cuda:{device_id}") 

740 

741 return result 

742 

743 

def _torch_to_tensorflow(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert torch tensor to tensorflow tensor, staying on GPU.

    For CUDA tensors, TensorFlow's DLPack support is verified first (version
    >= 2.12 and a ``tf.experimental.dlpack`` module). The tensor is then
    exported with ``torch.utils.dlpack.to_dlpack`` and imported with
    ``tf.experimental.dlpack.from_dlpack``. CPU tensors, and CUDA tensors
    whose DLPack path fails, are copied through NumPy only when
    ``allow_cpu_roundtrip`` is True.

    Args:
        data: The torch tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not
            authorized. This includes TensorFlow < 2.12 and a missing
            ``tf.experimental.dlpack`` module.
        ImportError: If tensorflow is not installed
    """
    tf = _ensure_module("tensorflow")
    torch = _ensure_module("torch")

    # Only attempt direct conversion if tensor is on CUDA
    if data.is_cuda:
        try:
            # Verify DLPack support explicitly instead of assuming it
            # (Clause 88: No Inferred Capabilities).
            tf_version = tf.__version__
            major, minor = map(int, tf_version.split('.')[:2])

            if major < 2 or (major == 2 and minor < 12):
                raise RuntimeError(
                    f"TensorFlow version {tf_version} does not support stable DLPack operations. "
                    "Version 2.12.0 or higher is required. "
                    "Clause 88 violation: Cannot infer DLPack capability."
                )

            if not hasattr(tf.experimental, "dlpack"):
                raise RuntimeError(
                    "TensorFlow installation missing experimental.dlpack module. "
                    "Clause 88 violation: Cannot infer DLPack capability."
                )

            # torch.Tensor has no ``to_dlpack()`` method; the exporter is
            # torch.utils.dlpack.to_dlpack. Export a detached view so no
            # autograd history is attached.
            dlpack = torch.utils.dlpack.to_dlpack(data.detach())
            tensor = tf.experimental.dlpack.from_dlpack(dlpack)

            # Move to specified device if needed
            if device_id is not None:
                with tf.device(f"/device:GPU:{device_id}"):
                    return tf.identity(tensor)

            return tensor
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.TORCH.value,
                    target_type=MemoryType.TENSORFLOW.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # Reached when the tensor is on the CPU, or after a DLPack failure with
    # the roundtrip authorized.
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.TENSORFLOW.value,
            method="GPU-native",
            reason="PyTorch tensor is not on CUDA or TensorFlow DLPack support issue"
        )

    # Explicit CPU roundtrip, only when allow_cpu_roundtrip=True
    numpy_data = data.detach().cpu().numpy()

    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.convert_to_tensor(numpy_data)

    return tf.convert_to_tensor(numpy_data)

832 

833 

834# TensorFlow conversion functions 

835 

836def _tensorflow_to_numpy(data: Any) -> Any: 

837 """ 

838 Convert tensorflow tensor to numpy array. 

839 

840 Args: 

841 data: The tensorflow tensor to convert 

842 

843 Returns: 

844 The converted numpy array 

845 """ 

846 return data.numpy() 

847 

848 

def _torch_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert torch tensor to pyclesperanto array, staying on GPU when possible.

    CUDA tensors that expose the CUDA Array Interface are handled on the GPU
    path. The target pyclesperanto device is selected, then ``cle.asarray``
    is applied to the detached tensor. Otherwise the data is copied through
    NumPy, but only when ``allow_cpu_roundtrip`` is True.

    Args:
        data: The torch tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional; defaults to the tensor's
            CUDA index on the GPU path and to 0 on the CPU roundtrip)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If pyclesperanto is not installed
    """
    cle = _ensure_module("pyclesperanto")

    if data.is_cuda:
        try:
            if _supports_cuda_array_interface(data):
                # Select target device
                target_device = device_id if device_id is not None else data.device.index
                cle.select_device(target_device)

                # Convert via CUDA array interface
                return cle.asarray(data.detach())
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.TORCH.value,
                    target_type=MemoryType.PYCLESPERANTO.value,
                    method="GPU_conversion",
                    reason=str(e)
                ) from e
        else:
            # The tensor is on CUDA but lacks the CUDA Array Interface. Report
            # that precisely instead of claiming it is not on CUDA.
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.TORCH.value,
                    target_type=MemoryType.PYCLESPERANTO.value,
                    method="GPU-native",
                    reason="PyTorch CUDA tensor does not support CUDA Array Interface"
                )
    elif not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.PYCLESPERANTO.value,
            method="GPU-native",
            reason="PyTorch tensor is not on CUDA"
        )

    # CPU roundtrip conversion (only when allow_cpu_roundtrip=True)
    numpy_data = data.detach().cpu().numpy()
    cle.select_device(device_id if device_id is not None else 0)
    return cle.push(numpy_data)

900 

901 

def _torch_to_jax(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert PyTorch tensor to JAX array, staying on GPU with zero-copy DLPack transfer.

    CPU tensors are first moved to CUDA: to ``cuda:<device_id>`` when given,
    otherwise via ``.cuda()``. No CPU roundtrip fallback is ever used.

    Args:
        data: The PyTorch tensor to convert
        allow_cpu_roundtrip: Accepted for signature compatibility but ignored
        device_id: The target GPU device ID (optional); indexes jax.devices("gpu")

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If conversion fails
        ImportError: If JAX is not installed
    """
    jax = _ensure_module("jax")
    torch = _ensure_module("torch")

    # If tensor is on CPU, move it to GPU first (similar to _numpy_to_jax behavior)
    if not data.is_cuda:
        data = data.to(f"cuda:{device_id}") if device_id is not None else data.cuda()

    if not _supports_dlpack(data):
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.JAX.value,
            method="GPU-native",
            reason="DLPack conversion failed after moving tensor to CUDA"
        )

    try:
        dlpack = torch.to_dlpack(data)
        result = jax.dlpack.from_dlpack(dlpack)

        # Move to specified device if needed
        if device_id is not None:
            current_device = None
            try:
                # Extract the device index from the JAX array's device string.
                # Split on whichever prefix is actually present ("gpu:" or
                # "cuda:"), so "cuda:N" strings are parsed as well.
                device_str = str(result.device)
                for prefix in ("gpu:", "cuda:"):
                    if prefix in device_str:
                        current_device = int(device_str.split(prefix)[-1].split(")")[0])
                        break
            except Exception:
                # Best effort: an unknown index just forces the device_put below.
                pass

            # Only move if needed
            if current_device != device_id:
                result = jax.device_put(result, jax.devices("gpu")[device_id])

        return result
    except Exception as e:
        # No CPU roundtrip allowed, so fail loudly
        raise MemoryConversionError(
            source_type=MemoryType.TORCH.value,
            target_type=MemoryType.JAX.value,
            method="DLPack",
            reason=str(e)
        ) from e

970 

971 

972# JAX conversion functions 

973 

def _numpy_to_jax(data: Any, gpu_id: int) -> Any:
    """
    Convert numpy array to JAX array.

    Args:
        data: The numpy array to convert
        gpu_id: The target GPU device ID (index into jax.devices("gpu"))

    Returns:
        The converted JAX array, placed on GPU ``gpu_id``

    Raises:
        ImportError: If JAX is not installed
        ValueError: If gpu_id is None or negative
    """
    jax = _ensure_module("jax")

    # Validate gpu_id. None means no GPU was assigned upstream. Report it
    # explicitly, as the torch/tensorflow converters do, instead of failing
    # with an opaque TypeError on the comparison below.
    if gpu_id is None:
        raise ValueError("🔥 GPU ID IS NONE! The compiler failed to assign a GPU to this jax function. This is a GPU registry/compiler bug!")
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

    # jnp.array places the result on JAX's default device, which is not
    # necessarily the CPU. device_put then pins it to the requested GPU.
    result = jax.numpy.array(data)
    result = jax.device_put(result, jax.devices("gpu")[gpu_id])

    return result

1003 

1004 

def _jax_to_numpy(data: Any) -> Any:
    """
    Convert JAX array to numpy array.

    ``jax.Array.copy()`` returns another JAX array, so it cannot be used here.
    ``numpy.array`` pulls the data to the host and always allocates a new,
    writable numpy array. Mutating the result therefore never aliases the
    (immutable) JAX buffer.

    Args:
        data: The JAX array to convert (a numpy array is also accepted and copied)

    Returns:
        A new ``numpy.ndarray`` holding the same values
    """
    import numpy as np

    # np.array defaults to copy=True, giving an independent, writable host copy
    return np.array(data)

1017 

1018 

def _jax_to_cupy(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert JAX array to cupy array, staying on GPU if possible.

    A GPU-resident array that supports DLPack is handed to cupy through DLPack.
    In every other situation (array not on GPU, no DLPack support, DLPack
    failure) the data is copied through host memory only when
    ``allow_cpu_roundtrip`` is True. Otherwise a MemoryConversionError is raised.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If jax or cupy is not installed
    """
    jax = _ensure_module("jax")
    cupy = _ensure_module("cupy")

    # JAX GPU devices stringify as "gpu:N" (older releases) or "cuda:N" (newer)
    device_str = str(data.device).lower()
    is_on_gpu = device_str.startswith(("gpu", "cuda"))

    if is_on_gpu and _supports_dlpack(data):
        try:
            dlpack = jax.dlpack.to_dlpack(data)
            result = cupy.from_dlpack(dlpack)

            # The DLPack result lives on the source device; cupy's copy() allocates
            # on the *current* device, so copy under the requested device context.
            if device_id is not None and result.device.id != device_id:
                with cupy.cuda.Device(device_id):
                    return result.copy()

            return result
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.JAX.value,
                    target_type=MemoryType.CUPY.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # Reached when the array is not on GPU, lacks DLPack support, or DLPack failed.
    # Without explicit authorization, never fall back to a host copy silently.
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.JAX.value,
            target_type=MemoryType.CUPY.value,
            method="GPU-native",
            reason=(
                "JAX array is not on GPU"
                if not is_on_gpu
                else "JAX array does not support DLPack"
            )
        )

    # CPU roundtrip, explicitly authorized by the caller
    numpy_data = _jax_to_numpy(data)

    if device_id is not None:
        with cupy.cuda.Device(device_id):
            return cupy.array(numpy_data)

    return cupy.array(numpy_data)

1081 

1082 

def _jax_to_torch(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert JAX array to torch tensor, staying on GPU if possible.

    A GPU-resident array that supports DLPack is handed to torch through DLPack.
    In every other situation (array not on GPU, no DLPack support, DLPack
    failure) the data is copied through host memory only when
    ``allow_cpu_roundtrip`` is True. Otherwise a MemoryConversionError is raised.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If jax or torch is not installed
    """
    jax = _ensure_module("jax")
    torch = _ensure_module("torch")

    # JAX GPU devices stringify as "gpu:N" (older releases) or "cuda:N" (newer)
    device_str = str(data.device).lower()
    is_on_gpu = device_str.startswith(("gpu", "cuda"))

    if is_on_gpu and _supports_dlpack(data):
        try:
            dlpack = jax.dlpack.to_dlpack(data)
            tensor = torch.from_dlpack(dlpack)

            # The DLPack result lives on the source device; move only if a different GPU was requested
            if device_id is not None and tensor.device.index != device_id:
                tensor = tensor.to(f"cuda:{device_id}")

            return tensor
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.JAX.value,
                    target_type=MemoryType.TORCH.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # Reached when the array is not on GPU, lacks DLPack support, or DLPack failed.
    # Without explicit authorization, never fall back to a host copy silently.
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.JAX.value,
            target_type=MemoryType.TORCH.value,
            method="GPU-native",
            reason=(
                "JAX array is not on GPU"
                if not is_on_gpu
                else "JAX array does not support DLPack"
            )
        )

    # CPU roundtrip, explicitly authorized by the caller
    numpy_data = _jax_to_numpy(data)

    if device_id is not None:
        return torch.tensor(numpy_data, device=f"cuda:{device_id}")

    return torch.tensor(numpy_data)

1143 

1144 

def _jax_to_tensorflow(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert JAX array to tensorflow tensor, staying on GPU if possible.

    A GPU-resident array that supports DLPack is handed to TensorFlow through
    DLPack. In every other situation (array not on GPU, no DLPack support,
    DLPack failure) the data is copied through host memory only when
    ``allow_cpu_roundtrip`` is True. Otherwise a MemoryConversionError is raised.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted tensorflow tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If jax or tensorflow is not installed
    """
    jax = _ensure_module("jax")
    tf = _ensure_module("tensorflow")

    # JAX GPU devices stringify as "gpu:N" (older releases) or "cuda:N" (newer)
    device_str = str(data.device).lower()
    is_on_gpu = device_str.startswith(("gpu", "cuda"))

    if is_on_gpu and _supports_dlpack(data):
        try:
            dlpack = jax.dlpack.to_dlpack(data)
            tensor = tf.experimental.dlpack.from_dlpack(dlpack)

            # tf.identity under a device scope produces a copy placed on that GPU
            if device_id is not None:
                with tf.device(f"/device:GPU:{device_id}"):
                    return tf.identity(tensor)

            return tensor
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.JAX.value,
                    target_type=MemoryType.TENSORFLOW.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # Reached when the array is not on GPU, lacks DLPack support, or DLPack failed.
    # Without explicit authorization, never fall back to a host copy silently.
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.JAX.value,
            target_type=MemoryType.TENSORFLOW.value,
            method="GPU-native",
            reason=(
                "JAX array is not on GPU"
                if not is_on_gpu
                else "JAX array does not support DLPack"
            )
        )

    # CPU roundtrip, explicitly authorized by the caller
    numpy_data = _jax_to_numpy(data)

    if device_id is not None:
        with tf.device(f"/device:GPU:{device_id}"):
            return tf.convert_to_tensor(numpy_data)

    return tf.convert_to_tensor(numpy_data)

1207 

1208 

def _tensorflow_to_cupy(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert tensorflow tensor to cupy array, staying on GPU.

    DLPack capability is probed with ``_supports_dlpack``. Per the original
    design, that probe raises RuntimeError for TF < 2.12 or CPU tensors
    (Clause 88, No Inferred Capabilities). A failed probe or failed DLPack
    transfer only falls back to a host copy when ``allow_cpu_roundtrip`` is True.

    Args:
        data: The tensorflow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted cupy array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
            (including when the DLPack capability probe raises RuntimeError)
        ImportError: If cupy or tensorflow is not installed
    """
    cupy = _ensure_module("cupy")
    tf = _ensure_module("tensorflow")

    # Only the capability probe is guarded here. Keeping the conversion outside
    # this try prevents a MemoryConversionError raised below from being caught
    # and re-wrapped.
    try:
        dlpack_supported = _supports_dlpack(data)
    except RuntimeError as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.TENSORFLOW.value,
                target_type=MemoryType.CUPY.value,
                method="DLPack",
                reason=str(e)
            ) from e
        dlpack_supported = False

    if dlpack_supported:
        try:
            dlpack = tf.experimental.dlpack.to_dlpack(data)
            result = cupy.from_dlpack(dlpack)

            # cupy's copy() allocates on the *current* device, so copy under the requested device
            if device_id is not None and result.device.id != device_id:
                with cupy.cuda.Device(device_id):
                    return result.copy()

            return result
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.TENSORFLOW.value,
                    target_type=MemoryType.CUPY.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # DLPack reported unsupported: refuse a silent host copy unless authorized
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TENSORFLOW.value,
            target_type=MemoryType.CUPY.value,
            method="GPU-native",
            reason="TensorFlow tensor is not on GPU or DLPack not supported"
        )

    # CPU roundtrip, explicitly authorized by the caller
    if device_id is not None:
        with cupy.cuda.Device(device_id):
            return cupy.array(data.numpy())

    return cupy.array(data.numpy())

1277 

1278 

def _tensorflow_to_torch(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert tensorflow tensor to torch tensor, staying on GPU.

    DLPack capability is probed with ``_supports_dlpack``. Per the original
    design, that probe raises RuntimeError for TF < 2.12 or CPU tensors
    (Clause 88, No Inferred Capabilities). A failed probe or failed DLPack
    transfer only falls back to a host copy when ``allow_cpu_roundtrip`` is True.

    Args:
        data: The tensorflow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional)

    Returns:
        The converted torch tensor

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
            (including when the DLPack capability probe raises RuntimeError)
        ImportError: If torch or tensorflow is not installed
    """
    torch = _ensure_module("torch")
    tf = _ensure_module("tensorflow")

    # Only the capability probe is guarded here. Keeping the conversion outside
    # this try prevents a MemoryConversionError raised below from being caught
    # and re-wrapped.
    try:
        dlpack_supported = _supports_dlpack(data)
    except RuntimeError as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.TENSORFLOW.value,
                target_type=MemoryType.TORCH.value,
                method="DLPack",
                reason=str(e)
            ) from e
        dlpack_supported = False

    if dlpack_supported:
        try:
            dlpack = tf.experimental.dlpack.to_dlpack(data)
            tensor = torch.from_dlpack(dlpack)

            # Move only if the DLPack result is not already on the requested GPU
            if device_id is not None and tensor.device.index != device_id:
                tensor = tensor.to(f"cuda:{device_id}")

            return tensor
        except Exception as e:
            if not allow_cpu_roundtrip:
                raise MemoryConversionError(
                    source_type=MemoryType.TENSORFLOW.value,
                    target_type=MemoryType.TORCH.value,
                    method="DLPack",
                    reason=str(e)
                ) from e

    # DLPack reported unsupported: refuse a silent host copy unless authorized
    if not allow_cpu_roundtrip:
        raise MemoryConversionError(
            source_type=MemoryType.TENSORFLOW.value,
            target_type=MemoryType.TORCH.value,
            method="GPU-native",
            reason="TensorFlow tensor is not on GPU or DLPack not supported"
        )

    # CPU roundtrip, explicitly authorized by the caller
    tensor = torch.from_numpy(data.numpy())

    if device_id is not None:
        tensor = tensor.to(f"cuda:{device_id}")

    return tensor

1348 

1349 

def _tensorflow_to_jax(
    data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None
) -> Any:
    """
    Convert TensorFlow tensor to JAX array, staying on GPU with zero-copy DLPack transfer.

    Args:
        data: The TensorFlow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip (ignored, always False)
        device_id: The target GPU device ID (optional)

    Returns:
        The converted JAX array

    Raises:
        MemoryConversionError: If the tensor is not on GPU or the DLPack transfer fails
        ImportError: If JAX or TensorFlow is not installed
        RuntimeError: If TensorFlow version is < 2.12 (unstable DLPack support)
            or tf.experimental.dlpack is missing
    """
    jax = _ensure_module("jax")
    tf = _ensure_module("tensorflow")

    # Check TensorFlow version for DLPack compatibility
    tf_version = tf.__version__
    major, minor = map(int, tf_version.split('.')[:2])

    if major < 2 or (major == 2 and minor < 12):
        raise RuntimeError(
            f"TensorFlow version {tf_version} does not support stable DLPack operations. "
            f"Version 2.12.0 or higher is required. "
            f"Clause 88 violation: Cannot infer DLPack capability."
        )

    # Check if experimental.dlpack module exists
    if not hasattr(tf.experimental, "dlpack"):
        raise RuntimeError(
            "TensorFlow installation missing experimental.dlpack module. "
            "Clause 88 violation: Cannot infer DLPack capability."
        )

    # TensorFlow device strings embed the device type (e.g. ".../device:GPU:0").
    # No CPU roundtrip is allowed, so a non-GPU tensor fails loudly.
    if "gpu" not in data.device.lower():
        raise MemoryConversionError(
            source_type=MemoryType.TENSORFLOW.value,
            target_type=MemoryType.JAX.value,
            method="GPU-native",
            reason="TensorFlow tensor is not on GPU"
        )

    try:
        dlpack = tf.experimental.dlpack.to_dlpack(data)
        result = jax.dlpack.from_dlpack(dlpack)

        if device_id is not None:
            # Best-effort parse of the JAX device index. JAX has named GPU devices
            # both "gpu:N" and "cuda:N"; matching only "gpu:" made every call
            # perform a redundant device_put on newer JAX.
            current_device = None
            jax_device_str = str(result.device)
            for prefix in ("gpu:", "cuda:"):
                if prefix in jax_device_str:
                    try:
                        current_device = int(
                            jax_device_str.rsplit(prefix, maxsplit=1)[-1].split(")")[0]
                        )
                    except (ValueError, IndexError):
                        pass
                    break

            # Only move if needed
            if current_device != device_id:
                result = jax.device_put(result, jax.devices("gpu")[device_id])

        return result
    except Exception as e:
        # No CPU roundtrip allowed, so fail loudly
        raise MemoryConversionError(
            source_type=MemoryType.TENSORFLOW.value,
            target_type=MemoryType.JAX.value,
            method="DLPack",
            reason=str(e)
        ) from e

1433 

1434 

def _tensorflow_to_tensorflow(data: Any, device_id: Optional[int] = None) -> Any:
    """
    Produce a TensorFlow copy of a TensorFlow tensor.

    Args:
        data: Source tensorflow tensor
        device_id: GPU index to create the copy under; None applies no explicit device scope

    Returns:
        The result of ``tf.identity(data)``, created inside ``/device:GPU:<device_id>``
        when a device is given
    """
    tf = _ensure_module("tensorflow")

    # No target requested: copy without an explicit device scope
    if device_id is None:
        return tf.identity(data)

    with tf.device(f"/device:GPU:{device_id}"):
        return tf.identity(data)

1453 

1454 

def _tensorflow_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert tensorflow tensor to pyclesperanto array, staying on GPU.

    The GPU path exports the tensor with ``tf.experimental.dlpack.to_dlpack`` and
    imports it with ``cle.from_dlpack``. If that exporter is unavailable or the
    transfer fails, a host copy is used only when ``allow_cpu_roundtrip`` is True.

    Args:
        data: The tensorflow tensor to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional, defaults to device 0)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If tensorflow or pyclesperanto is not installed
    """
    tf = _ensure_module("tensorflow")
    cle = _ensure_module("pyclesperanto")

    target_device = device_id if device_id is not None else 0

    # Try GPU-to-GPU conversion first
    try:
        dlpack_api = getattr(tf.experimental, "dlpack", None)
        if dlpack_api is None or not hasattr(dlpack_api, "to_dlpack"):
            # Without an exporter there is no GPU path. Route through the same
            # error handling as a failed transfer instead of silently falling
            # through to the CPU roundtrip below.
            raise RuntimeError("tf.experimental.dlpack.to_dlpack is not available")

        dlpack = dlpack_api.to_dlpack(data)
        cle.select_device(target_device)
        return cle.from_dlpack(dlpack)
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.TENSORFLOW.value,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip (only reached when allow_cpu_roundtrip=True)
    numpy_data = data.numpy()
    cle.select_device(target_device)
    return cle.push(numpy_data)

1499 

1500 

def _jax_to_jax(data: Any, device_id: Optional[int] = None) -> Any:
    """
    Duplicate a JAX array, optionally transferring the duplicate to a GPU.

    Args:
        data: Source JAX array
        device_id: Index into ``jax.devices("gpu")`` for the destination; None leaves
            the copy where ``data.copy()`` placed it

    Returns:
        A copy of ``data``, moved with ``jax.device_put`` when ``device_id`` is given
    """
    jax = _ensure_module("jax")

    duplicate = data.copy()

    # Without a target device the plain copy is the answer
    if device_id is None:
        return duplicate

    return jax.device_put(duplicate, jax.devices("gpu")[device_id])

1521 

1522 

def _jax_to_pyclesperanto(data: Any, allow_cpu_roundtrip: bool = False, device_id: Optional[int] = None) -> Any:
    """
    Convert JAX array to pyclesperanto array, staying on GPU.

    The GPU path exports the array through its ``__dlpack__`` protocol method and
    imports it with ``cle.from_dlpack``. If the array does not implement
    ``__dlpack__`` or the transfer fails, a host copy is used only when
    ``allow_cpu_roundtrip`` is True.

    Args:
        data: The JAX array to convert
        allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip
        device_id: The target GPU device ID (optional, defaults to device 0)

    Returns:
        The converted pyclesperanto array

    Raises:
        MemoryConversionError: If conversion fails and CPU fallback is not authorized
        ImportError: If jax or pyclesperanto is not installed
    """
    # Called only to fail early with ImportError when JAX is unavailable
    _ensure_module("jax")
    cle = _ensure_module("pyclesperanto")

    target_device = device_id if device_id is not None else 0

    # Try GPU-to-GPU conversion first
    try:
        if not hasattr(data, "__dlpack__"):
            # Without DLPack there is no GPU path. Route through the same error
            # handling as a failed transfer instead of silently falling through
            # to the CPU roundtrip below.
            raise TypeError(f"{type(data).__name__} does not implement __dlpack__")

        dlpack = data.__dlpack__()
        cle.select_device(target_device)
        return cle.from_dlpack(dlpack)
    except Exception as e:
        if not allow_cpu_roundtrip:
            raise MemoryConversionError(
                source_type=MemoryType.JAX.value,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="GPU_conversion",
                reason=str(e)
            ) from e

    # Fallback: CPU roundtrip (only reached when allow_cpu_roundtrip=True)
    numpy_data = _jax_to_numpy(data)
    cle.select_device(target_device)
    return cle.push(numpy_data)