Coverage for src/arraybridge/framework_config.py: 74%

80 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1""" 

2Single source of truth for ALL framework-specific behavior. 

3 

4This module consolidates all framework-specific logic that was previously 

5scattered across utils.py, stack_utils.py, gpu_cleanup.py, dtype_scaling.py, 

6and framework_ops.py. 

7 

8Architecture: 

9- Framework handlers: Custom logic for special cases (pyclesperanto, JAX, TensorFlow) 

10- Unified config: Single _FRAMEWORK_CONFIG dict with all framework metadata 

11- Polymorphic dispatch: Handlers can be callables or eval expressions 

12""" 

13 

14import logging 

15from typing import Any, Callable 

16 

17from arraybridge.types import MemoryType 

18 

# Module-level logger shared by the framework handler functions below.
logger = logging.getLogger(__name__)

20 

21 

22# ============================================================================ 

23# FRAMEWORK HANDLERS - All special-case logic lives here 

24# ============================================================================ 

25 

26 

27def _pyclesperanto_get_device_id(data: Any, mod: Any) -> int: 

28 """Get device ID for pyclesperanto array.""" 

29 if mod is None: 

30 return 0 

31 try: 

32 current_device = mod.get_device() 

33 if hasattr(current_device, "id"): 

34 return current_device.id 

35 devices = mod.list_available_devices() 

36 for i, device in enumerate(devices): 

37 if str(device) == str(current_device): 

38 return i 

39 return 0 

40 except Exception as e: 

41 logger.warning(f"Failed to get device ID for pyclesperanto: {e}") 

42 return 0 

43 

44 

45def _pyclesperanto_set_device(device_id: int, mod: Any) -> None: 

46 """Set device for pyclesperanto.""" 

47 if mod is None: 

48 return 

49 devices = mod.list_available_devices() 

50 if device_id >= len(devices): 

51 raise ValueError(f"Device {device_id} not available. Available: {len(devices)}") 

52 mod.select_device(device_id) 

53 

54 

55def _pyclesperanto_move_to_device(data: Any, device_id: int, mod: Any, memory_type: str) -> Any: 

56 """Move pyclesperanto array to device.""" 

57 if mod is None: 

58 return data 

59 # Import here to avoid circular dependency 

60 from arraybridge.utils import _get_device_id 

61 

62 current_device_id = _get_device_id(data, memory_type) 

63 

64 if current_device_id != device_id: 

65 mod.select_device(device_id) 

66 result = mod.create_like(data) 

67 mod.copy(data, result) 

68 return result 

69 return data 

70 

71 

72def _pyclesperanto_stack_slices(slices: list, memory_type: str, gpu_id: int, mod: Any) -> Any: 

73 """Stack slices using pyclesperanto's concatenate_along_z.""" 

74 if mod is None: 

75 return None 

76 from arraybridge.converters import convert_memory, detect_memory_type 

77 

78 converted_slices = [] 

79 conversion_count = 0 

80 

81 for slice_data in slices: 

82 source_type = detect_memory_type(slice_data) 

83 

84 if source_type != memory_type: 

85 conversion_count += 1 

86 

87 if source_type == memory_type: 

88 converted_slices.append(slice_data) 

89 else: 

90 converted = convert_memory(slice_data, source_type, memory_type, gpu_id) 

91 converted_slices.append(converted) 

92 

93 # Log batch conversion 

94 if conversion_count > 0: 

95 logger.debug( 

96 f"🔄 MEMORY CONVERSION: Converted {conversion_count}/{len(slices)} slices " 

97 f"to {memory_type} for pyclesperanto stacking" 

98 ) 

99 

100 return mod.concatenate_along_z(converted_slices) 

101 

102 

103def _jax_assign_slice(result: Any, index: int, slice_data: Any) -> Any: 

104 """Assign slice to JAX array (immutable).""" 

105 if result is None: 

106 return None 

107 return result.at[index].set(slice_data) 

108 

109 

110def _tensorflow_validate_dlpack(obj: Any, mod: Any) -> bool: 

111 """Validate TensorFlow DLPack support.""" 

112 if mod is None: 

113 return False 

114 # Check version 

115 major, minor = map(int, mod.__version__.split(".")[:2]) 

116 if major < 2 or (major == 2 and minor < 12): 

117 raise RuntimeError( 

118 f"TensorFlow {mod.__version__} does not support stable DLPack. " 

119 f"Version 2.12.0+ required. " 

120 f"Clause 88 violation: Cannot infer DLPack capability." 

121 ) 

122 

123 # Check GPU 

124 """Validate TensorFlow DLPack support.""" 

125 # Check version 

126 major, minor = map(int, mod.__version__.split(".")[:2]) 

127 if major < 2 or (major == 2 and minor < 12): 

128 raise RuntimeError( 

129 f"TensorFlow {mod.__version__} does not support stable DLPack. " 

130 f"Version 2.12.0+ required. " 

131 f"Clause 88 violation: Cannot infer DLPack capability." 

132 ) 

133 

134 # Check GPU 

135 device_str = obj.device.lower() 

136 if "gpu" not in device_str: 

137 raise RuntimeError( 

138 "TensorFlow tensor on CPU cannot use DLPack operations reliably. " 

139 "Only GPU tensors are supported for DLPack operations. " 

140 "Clause 88 violation: Cannot infer GPU capability." 

141 ) 

142 

143 # Check module 

144 if not hasattr(mod.experimental, "dlpack"): 

145 raise RuntimeError( 

146 "TensorFlow installation missing experimental.dlpack module. " 

147 "Clause 88 violation: Cannot infer DLPack capability." 

148 ) 

149 

150 return True 

151 

152 

def _numpy_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool:
    """Return True when stacking into NumPy requires a dtype conversion.

    Conversion is only required when the incoming slice originates from a
    torch tensor, per the supplied memory-type detector.
    """
    return detect_memory_type_func(first_slice) == MemoryType.TORCH.value

157 

158 

159def _torch_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool: 

160 """Torch always needs dtype conversion to get correct torch dtype.""" 

161 return True 

162 

163 

164# ============================================================================ 

165# UNIFIED FRAMEWORK CONFIGURATION 

166# ============================================================================ 

167 

# Per-framework behavior table, keyed by MemoryType. Each entry follows the
# same schema:
#   Metadata:     import_name / display_name / is_gpu
#   Device ops:   get_device_id / set_device / move_to_device — each is an
#                 eval-expression string, a callable handler, or None
#                 (CPU framework, or device handled at tensor creation)
#   Stack ops:    allocate_stack / allocate_context / needs_dtype_conversion
#                 (bool or callable) / assign_slice / stack_handler
#   Scaling:      scaling_ops — dict of eval expressions (None = custom impl)
#   Conversion:   conversion_ops — dict of eval expressions
#   DLPack:       supports_dlpack / validate_dlpack
#   GPU/cleanup:  lazy_getter / gpu_check / stream_context / device_context /
#                 cleanup_ops, plus OOM recovery fields
# NOTE(review): strings contain {mod}/{device_id} placeholders — presumably
# substituted (str.format) by the dispatch layer before eval; confirm in
# framework_ops.py.
_FRAMEWORK_CONFIG = {
    MemoryType.NUMPY: {
        # Metadata
        "import_name": "numpy",
        "display_name": "NumPy",
        "is_gpu": False,
        # Device operations
        "get_device_id": None,  # CPU
        "set_device": None,  # CPU
        "move_to_device": None,  # CPU
        # Stack operations
        "allocate_stack": "np.empty(stack_shape, dtype=dtype)",
        "allocate_context": None,
        "needs_dtype_conversion": _numpy_dtype_conversion_needed,  # Callable
        "assign_slice": None,  # Standard: result[i] = slice
        "stack_handler": None,  # Standard stacking
        # Dtype scaling
        "scaling_ops": {
            "min": "result.min()",
            "max": "result.max()",
            "astype": "result.astype(target_dtype)",
            "check_float": "np.issubdtype(result.dtype, np.floating)",
            "check_int": "target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]",  # noqa: E501
            "clamp": "np.clip(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data",
            "from_numpy": "data",
            "from_dlpack": None,
            "move_to_device": "data",
        },
        # DLPack
        "supports_dlpack": False,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": None,
        "gpu_check": None,
        "stream_context": None,
        "device_context": None,
        "cleanup_ops": None,
        "has_oom_recovery": False,
        "oom_exception_types": [],
        "oom_string_patterns": ["cannot allocate memory", "memory exhausted"],
        "oom_clear_cache": "import gc; gc.collect()",
    },
    MemoryType.CUPY: {
        # Metadata
        "import_name": "cupy",
        "display_name": "CuPy",
        "is_gpu": True,
        # Device operations (eval expressions)
        "get_device_id": "data.device.id",
        "get_device_id_fallback": "0",
        "set_device": "{mod}.cuda.Device(device_id).use()",
        "move_to_device": "data.copy() if data.device.id != device_id else data",
        "move_context": "{mod}.cuda.Device(device_id)",
        # Stack operations
        "allocate_stack": "cupy.empty(stack_shape, dtype=first_slice.dtype)",
        "allocate_context": "cupy.cuda.Device(gpu_id)",
        "needs_dtype_conversion": False,
        "assign_slice": None,  # Standard
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "mod.min(result)",
            "max": "mod.max(result)",
            "astype": "result.astype(target_dtype)",
            "check_float": "mod.issubdtype(result.dtype, mod.floating)",
            "check_int": "not mod.issubdtype(target_dtype, mod.floating)",
            "clamp": "mod.clip(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data.get()",
            "from_numpy": "({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]",
            "from_dlpack": "{mod}.from_dlpack(data)",
            "move_to_device": "data if data.device.id == gpu_id else ({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]",  # noqa: E501
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": "_get_cupy",
        "gpu_check": '{mod} is not None and hasattr({mod}, "cuda")',
        "stream_context": "{mod}.cuda.Stream()",
        "device_context": "{mod}.cuda.Device({device_id})",
        "cleanup_ops": "{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()",  # noqa: E501
        "has_oom_recovery": True,
        "oom_exception_types": [
            "{mod}.cuda.memory.OutOfMemoryError",
            "{mod}.cuda.runtime.CUDARuntimeError",
        ],  # noqa: E501
        "oom_string_patterns": ["out of memory", "cuda_error_out_of_memory"],
        "oom_clear_cache": "{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()",  # noqa: E501
    },
    MemoryType.TORCH: {
        # Metadata
        "import_name": "torch",
        "display_name": "PyTorch",
        "is_gpu": True,
        # Device operations
        "get_device_id": "data.device.index if data.is_cuda else None",
        "get_device_id_fallback": "None",
        "set_device": None,  # PyTorch handles device at tensor creation
        "move_to_device": 'data.to(f"cuda:{device_id}") if (not data.is_cuda or data.device.index != device_id) else data',  # noqa: E501
        # Stack operations
        "allocate_stack": "torch.empty(stack_shape, dtype=sample_converted.dtype, device=sample_converted.device)",  # noqa: E501
        "allocate_context": None,
        "needs_dtype_conversion": _torch_dtype_conversion_needed,  # Callable
        "assign_slice": None,  # Standard
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "result.min()",
            "max": "result.max()",
            "astype": "result.to(target_dtype_mapped)",
            "check_float": "result.dtype in [mod.float16, mod.float32, mod.float64]",
            "check_int": "target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]",  # noqa: E501
            "needs_dtype_map": True,
            "clamp": "mod.clamp(result, min=min_val, max=max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data.cpu().numpy()",
            "from_numpy": "{mod}.from_numpy(data).cuda(gpu_id)",
            "from_dlpack": "{mod}.from_dlpack(data)",
            "move_to_device": "data if data.device.index == gpu_id else data.cuda(gpu_id)",
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": "_get_torch",
        "gpu_check": '{mod} is not None and hasattr({mod}, "cuda") and {mod}.cuda.is_available()',
        "stream_context": "{mod}.cuda.Stream()",
        "device_context": "{mod}.cuda.device({device_id})",
        "cleanup_ops": "{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()",
        "has_oom_recovery": True,
        "oom_exception_types": ["{mod}.cuda.OutOfMemoryError"],
        "oom_string_patterns": ["out of memory", "cuda_error_out_of_memory"],
        "oom_clear_cache": "{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()",
    },
    MemoryType.TENSORFLOW: {
        # Metadata
        "import_name": "tensorflow",
        "display_name": "TensorFlow",
        "is_gpu": True,
        # Device operations
        "get_device_id": 'int(data.device.lower().split(":")[-1]) if "gpu" in data.device.lower() else None',  # noqa: E501
        "get_device_id_fallback": "None",
        "set_device": None,  # TensorFlow handles device at tensor creation
        "move_to_device": "{mod}.identity(data)",
        "move_context": '{mod}.device(f"/device:GPU:{device_id}")',
        # Stack operations
        "allocate_stack": "tf.zeros(stack_shape, dtype=first_slice.dtype)",  # TF doesn't have empty() # noqa: E501
        "allocate_context": 'tf.device(f"/device:GPU:{gpu_id}")',
        "needs_dtype_conversion": False,
        "assign_slice": None,  # Standard
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "mod.reduce_min(result)",
            "max": "mod.reduce_max(result)",
            "astype": "mod.cast(result, target_dtype_mapped)",
            "check_float": "result.dtype in [mod.float16, mod.float32, mod.float64]",
            "check_int": "target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]",  # noqa: E501
            "needs_dtype_map": True,
            "clamp": "mod.clip_by_value(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data.numpy()",
            "from_numpy": "{mod}.convert_to_tensor(data)",
            "from_dlpack": "{mod}.experimental.dlpack.from_dlpack(data)",
            "move_to_device": "data",
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": _tensorflow_validate_dlpack,  # Custom validation
        # GPU/Cleanup
        "lazy_getter": "_get_tensorflow",
        "gpu_check": '{mod} is not None and {mod}.config.list_physical_devices("GPU")',
        "stream_context": None,  # TensorFlow manages streams internally
        "device_context": '{mod}.device("/GPU:0")',
        "cleanup_ops": None,  # TensorFlow has no explicit cache clearing API
        "has_oom_recovery": True,
        "oom_exception_types": [
            "{mod}.errors.ResourceExhaustedError",
            "{mod}.errors.InvalidArgumentError",
        ],
        "oom_string_patterns": ["out of memory", "resource_exhausted"],
        "oom_clear_cache": None,  # TensorFlow has no explicit cache clearing API
    },
    MemoryType.JAX: {
        # Metadata
        "import_name": "jax",
        "display_name": "JAX",
        "is_gpu": True,
        # Device operations
        "get_device_id": 'int(str(data.device).lower().split(":")[-1]) if "gpu" in str(data.device).lower() else None',  # noqa: E501
        "get_device_id_fallback": "None",
        "set_device": None,  # JAX handles device at array creation
        "move_to_device": '{mod}.device_put(data, {mod}.devices("gpu")[device_id])',
        # Stack operations
        "allocate_stack": "jnp.empty(stack_shape, dtype=first_slice.dtype)",
        "allocate_context": None,
        "needs_dtype_conversion": False,
        "assign_slice": _jax_assign_slice,  # Custom handler for immutability
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "jnp.min(result)",
            "max": "jnp.max(result)",
            "astype": "result.astype(target_dtype_mapped)",
            "check_float": "result.dtype in [jnp.float16, jnp.float32, jnp.float64]",
            "check_int": "target_dtype_mapped in [jnp.uint8, jnp.int8, jnp.int16, jnp.int32, jnp.int64]",  # noqa: E501
            "needs_dtype_map": True,
            "extra_import": "jax.numpy",
            "clamp": "jnp.clip(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "np.asarray(data)",
            "from_numpy": "{mod}.device_put(data, {mod}.devices()[gpu_id])",
            "from_dlpack": "{mod}.dlpack.from_dlpack(data)",
            "move_to_device": "data",
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": "_get_jax",
        "gpu_check": '{mod} is not None and any(d.platform == "gpu" for d in {mod}.devices())',
        "stream_context": None,  # JAX/XLA manages streams internally
        "device_context": '{mod}.default_device([d for d in {mod}.devices() if d.platform == "gpu"][0])',  # noqa: E501
        "cleanup_ops": "{mod}.clear_caches()",
        "has_oom_recovery": True,
        "oom_exception_types": [],
        "oom_string_patterns": ["out of memory", "oom when allocating", "allocation failure"],
        "oom_clear_cache": "{mod}.clear_caches()",
    },
    MemoryType.PYCLESPERANTO: {
        # Metadata
        "import_name": "pyclesperanto",
        "display_name": "pyclesperanto",
        "is_gpu": True,
        # Device operations (custom handlers)
        "get_device_id": _pyclesperanto_get_device_id,  # Callable
        "get_device_id_fallback": "0",
        "set_device": _pyclesperanto_set_device,  # Callable
        "move_to_device": _pyclesperanto_move_to_device,  # Callable
        # Stack operations (custom handler)
        "allocate_stack": None,  # Uses concatenate_along_z
        "allocate_context": None,
        "needs_dtype_conversion": False,
        "assign_slice": None,  # Not used (custom stacking)
        "stack_handler": _pyclesperanto_stack_slices,  # Custom stacking
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "{mod}.pull(data)",
            "from_numpy": "{mod}.push(data)",
            "from_dlpack": None,
            "move_to_device": "data",
        },
        # Dtype scaling (custom implementation in dtype_scaling.py)
        "scaling_ops": None,  # Custom _scale_pyclesperanto function
        # DLPack
        "supports_dlpack": False,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": None,
        "gpu_check": None,  # pyclesperanto always uses GPU if available
        "stream_context": None,  # OpenCL manages streams internally
        "device_context": None,  # OpenCL device selection is global
        "cleanup_ops": None,  # pyclesperanto/OpenCL has no explicit cache clearing API
        "has_oom_recovery": True,
        "oom_exception_types": [],
        "oom_string_patterns": [
            "cl_mem_object_allocation_failure",
            "cl_out_of_resources",
            "out of memory",
        ],  # noqa: E501
        "oom_clear_cache": None,  # pyclesperanto/OpenCL has no explicit cache clearing API
    },
}