Coverage for src/arraybridge/framework_config.py: 74% (80 statements)

1"""
2Single source of truth for ALL framework-specific behavior.
4This module consolidates all framework-specific logic that was previously
5scattered across utils.py, stack_utils.py, gpu_cleanup.py, dtype_scaling.py,
6and framework_ops.py.
8Architecture:
9- Framework handlers: Custom logic for special cases (pyclesperanto, JAX, TensorFlow)
10- Unified config: Single _FRAMEWORK_CONFIG dict with all framework metadata
11- Polymorphic dispatch: Handlers can be callables or eval expressions
12"""
import logging
from typing import Any, Callable

from arraybridge.types import MemoryType

logger = logging.getLogger(__name__)


# ============================================================================
# FRAMEWORK HANDLERS - All special-case logic lives here
# ============================================================================


def _pyclesperanto_get_device_id(data: Any, mod: Any) -> int:
    """Get device ID for pyclesperanto array."""
    if mod is None:
        return 0
    try:
        current_device = mod.get_device()
        if hasattr(current_device, "id"):
            return current_device.id
        devices = mod.list_available_devices()
        for i, device in enumerate(devices):
            if str(device) == str(current_device):
                return i
        return 0
    except Exception as e:
        logger.warning(f"Failed to get device ID for pyclesperanto: {e}")
        return 0


def _pyclesperanto_set_device(device_id: int, mod: Any) -> None:
    """Set device for pyclesperanto."""
    if mod is None:
        return
    devices = mod.list_available_devices()
    if device_id >= len(devices):
        raise ValueError(
            f"Device {device_id} not available. Only {len(devices)} device(s) detected."
        )
    mod.select_device(device_id)


def _pyclesperanto_move_to_device(data: Any, device_id: int, mod: Any, memory_type: str) -> Any:
    """Move pyclesperanto array to device."""
    if mod is None:
        return data
    # Import here to avoid circular dependency
    from arraybridge.utils import _get_device_id

    current_device_id = _get_device_id(data, memory_type)

    if current_device_id != device_id:
        mod.select_device(device_id)
        result = mod.create_like(data)
        mod.copy(data, result)
        return result
    return data


def _pyclesperanto_stack_slices(slices: list, memory_type: str, gpu_id: int, mod: Any) -> Any:
    """Stack slices using pyclesperanto's concatenate_along_z."""
    if mod is None:
        return None
    from arraybridge.converters import convert_memory, detect_memory_type

    converted_slices = []
    conversion_count = 0

    for slice_data in slices:
        source_type = detect_memory_type(slice_data)
        if source_type == memory_type:
            converted_slices.append(slice_data)
        else:
            conversion_count += 1
            converted = convert_memory(slice_data, source_type, memory_type, gpu_id)
            converted_slices.append(converted)

    # Log batch conversion
    if conversion_count > 0:
        logger.debug(
            f"🔄 MEMORY CONVERSION: Converted {conversion_count}/{len(slices)} slices "
            f"to {memory_type} for pyclesperanto stacking"
        )

    return mod.concatenate_along_z(converted_slices)


def _jax_assign_slice(result: Any, index: int, slice_data: Any) -> Any:
    """Assign slice to JAX array (immutable)."""
    if result is None:
        return None
    return result.at[index].set(slice_data)


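# JAX arrays are immutable: plain `result[index] = slice_data` raises a
# TypeError, so assignment goes through the functional `.at[...]` update API,
# which returns a new array. A minimal illustration (assumes jax is installed):
#
#   >>> import jax.numpy as jnp
#   >>> a = jnp.zeros(3)
#   >>> b = a.at[0].set(1.0)  # `a` is unchanged; `b` carries the update

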
def _tensorflow_validate_dlpack(obj: Any, mod: Any) -> bool:
    """Validate TensorFlow DLPack support."""
    if mod is None:
        return False

    # Check version
    major, minor = map(int, mod.__version__.split(".")[:2])
    if major < 2 or (major == 2 and minor < 12):
        raise RuntimeError(
            f"TensorFlow {mod.__version__} does not support stable DLPack. "
            f"Version 2.12.0+ required. "
            f"Clause 88 violation: Cannot infer DLPack capability."
        )

    # Check GPU
    device_str = obj.device.lower()
    if "gpu" not in device_str:
        raise RuntimeError(
            "TensorFlow tensor on CPU cannot use DLPack operations reliably. "
            "Only GPU tensors are supported for DLPack operations. "
            "Clause 88 violation: Cannot infer GPU capability."
        )

    # Check module
    if not hasattr(mod.experimental, "dlpack"):
        raise RuntimeError(
            "TensorFlow installation missing experimental.dlpack module. "
            "Clause 88 violation: Cannot infer DLPack capability."
        )

    return True


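# Sketch of the call pattern this validator guards (illustrative only; the
# real conversion path lives in the converters module). Validation runs before
# the tensor is handed to TensorFlow's DLPack API, so unsupported setups fail
# with an explicit Clause 88 error instead of an opaque DLPack crash:
#
#   >>> import tensorflow as tf                              # doctest: +SKIP
#   >>> with tf.device("/GPU:0"):                            # doctest: +SKIP
#   ...     t = tf.constant([1.0, 2.0])
#   >>> if _tensorflow_validate_dlpack(t, tf):               # doctest: +SKIP
#   ...     capsule = tf.experimental.dlpack.to_dlpack(t)

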
def _numpy_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool:
    """Check if NumPy needs dtype conversion (only for torch sources)."""
    source_type = detect_memory_type_func(first_slice)
    return source_type == MemoryType.TORCH.value


def _torch_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool:
    """Torch always needs dtype conversion to get correct torch dtype."""
    return True


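# PyTorch dtypes are distinct objects from NumPy dtypes (np.float32 is not
# torch.float32), so stack allocation always routes through a dtype map.
# Minimal illustration (assumes torch is installed):
#
#   >>> import numpy as np, torch
#   >>> torch.empty(2, dtype=np.float32)   # doctest: +SKIP
#   TypeError: empty() received an invalid combination of arguments
#   >>> torch.empty(2, dtype=torch.float32).dtype
#   torch.float32

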
# ============================================================================
# UNIFIED FRAMEWORK CONFIGURATION
# ============================================================================

_FRAMEWORK_CONFIG = {
    MemoryType.NUMPY: {
        # Metadata
        "import_name": "numpy",
        "display_name": "NumPy",
        "is_gpu": False,
        # Device operations
        "get_device_id": None,  # CPU
        "set_device": None,  # CPU
        "move_to_device": None,  # CPU
        # Stack operations
        "allocate_stack": "np.empty(stack_shape, dtype=dtype)",
        "allocate_context": None,
        "needs_dtype_conversion": _numpy_dtype_conversion_needed,  # Callable
        "assign_slice": None,  # Standard: result[i] = slice
        "stack_handler": None,  # Standard stacking
        # Dtype scaling
        "scaling_ops": {
            "min": "result.min()",
            "max": "result.max()",
            "astype": "result.astype(target_dtype)",
            "check_float": "np.issubdtype(result.dtype, np.floating)",
            "check_int": "target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]",  # noqa: E501
            "clamp": "np.clip(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data",
            "from_numpy": "data",
            "from_dlpack": None,
            "move_to_device": "data",
        },
        # DLPack
        "supports_dlpack": False,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": None,
        "gpu_check": None,
        "stream_context": None,
        "device_context": None,
        "cleanup_ops": None,
        "has_oom_recovery": False,
        "oom_exception_types": [],
        "oom_string_patterns": ["cannot allocate memory", "memory exhausted"],
        "oom_clear_cache": "import gc; gc.collect()",
    },
    MemoryType.CUPY: {
        # Metadata
        "import_name": "cupy",
        "display_name": "CuPy",
        "is_gpu": True,
        # Device operations (eval expressions)
        "get_device_id": "data.device.id",
        "get_device_id_fallback": "0",
        "set_device": "{mod}.cuda.Device(device_id).use()",
        "move_to_device": "data.copy() if data.device.id != device_id else data",
        "move_context": "{mod}.cuda.Device(device_id)",
        # Stack operations
        "allocate_stack": "cupy.empty(stack_shape, dtype=first_slice.dtype)",
        "allocate_context": "cupy.cuda.Device(gpu_id)",
        "needs_dtype_conversion": False,
        "assign_slice": None,  # Standard
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "mod.min(result)",
            "max": "mod.max(result)",
            "astype": "result.astype(target_dtype)",
            "check_float": "mod.issubdtype(result.dtype, mod.floating)",
            "check_int": "not mod.issubdtype(target_dtype, mod.floating)",
            "clamp": "mod.clip(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data.get()",
            "from_numpy": "({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]",
            "from_dlpack": "{mod}.from_dlpack(data)",
            "move_to_device": "data if data.device.id == gpu_id else ({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]",  # noqa: E501
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": "_get_cupy",
        "gpu_check": '{mod} is not None and hasattr({mod}, "cuda")',
        "stream_context": "{mod}.cuda.Stream()",
        "device_context": "{mod}.cuda.Device({device_id})",
        "cleanup_ops": "{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()",  # noqa: E501
        "has_oom_recovery": True,
        "oom_exception_types": [
            "{mod}.cuda.memory.OutOfMemoryError",
            "{mod}.cuda.runtime.CUDARuntimeError",
        ],
        "oom_string_patterns": ["out of memory", "cuda_error_out_of_memory"],
        "oom_clear_cache": "{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()",  # noqa: E501
    },
    MemoryType.TORCH: {
        # Metadata
        "import_name": "torch",
        "display_name": "PyTorch",
        "is_gpu": True,
        # Device operations
        "get_device_id": "data.device.index if data.is_cuda else None",
        "get_device_id_fallback": "None",
        "set_device": None,  # PyTorch handles device at tensor creation
        "move_to_device": 'data.to(f"cuda:{device_id}") if (not data.is_cuda or data.device.index != device_id) else data',  # noqa: E501
        # Stack operations
        "allocate_stack": "torch.empty(stack_shape, dtype=sample_converted.dtype, device=sample_converted.device)",  # noqa: E501
        "allocate_context": None,
        "needs_dtype_conversion": _torch_dtype_conversion_needed,  # Callable
        "assign_slice": None,  # Standard
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "result.min()",
            "max": "result.max()",
            "astype": "result.to(target_dtype_mapped)",
            "check_float": "result.dtype in [mod.float16, mod.float32, mod.float64]",
            "check_int": "target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]",  # noqa: E501
            "needs_dtype_map": True,
            "clamp": "mod.clamp(result, min=min_val, max=max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data.cpu().numpy()",
            "from_numpy": "{mod}.from_numpy(data).cuda(gpu_id)",
            "from_dlpack": "{mod}.from_dlpack(data)",
            "move_to_device": "data if data.device.index == gpu_id else data.cuda(gpu_id)",
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": "_get_torch",
        "gpu_check": '{mod} is not None and hasattr({mod}, "cuda") and {mod}.cuda.is_available()',
        "stream_context": "{mod}.cuda.Stream()",
        "device_context": "{mod}.cuda.device({device_id})",
        "cleanup_ops": "{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()",
        "has_oom_recovery": True,
        "oom_exception_types": ["{mod}.cuda.OutOfMemoryError"],
        "oom_string_patterns": ["out of memory", "cuda_error_out_of_memory"],
        "oom_clear_cache": "{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()",
    },
    MemoryType.TENSORFLOW: {
        # Metadata
        "import_name": "tensorflow",
        "display_name": "TensorFlow",
        "is_gpu": True,
        # Device operations
        "get_device_id": 'int(data.device.lower().split(":")[-1]) if "gpu" in data.device.lower() else None',  # noqa: E501
        "get_device_id_fallback": "None",
        "set_device": None,  # TensorFlow handles device at tensor creation
        "move_to_device": "{mod}.identity(data)",
        "move_context": '{mod}.device(f"/device:GPU:{device_id}")',
        # Stack operations
        "allocate_stack": "tf.zeros(stack_shape, dtype=first_slice.dtype)",  # TF doesn't have empty()  # noqa: E501
        "allocate_context": 'tf.device(f"/device:GPU:{gpu_id}")',
        "needs_dtype_conversion": False,
        "assign_slice": None,  # Standard
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "mod.reduce_min(result)",
            "max": "mod.reduce_max(result)",
            "astype": "mod.cast(result, target_dtype_mapped)",
            "check_float": "result.dtype in [mod.float16, mod.float32, mod.float64]",
            "check_int": "target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]",  # noqa: E501
            "needs_dtype_map": True,
            "clamp": "mod.clip_by_value(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "data.numpy()",
            "from_numpy": "{mod}.convert_to_tensor(data)",
            "from_dlpack": "{mod}.experimental.dlpack.from_dlpack(data)",
            "move_to_device": "data",
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": _tensorflow_validate_dlpack,  # Custom validation
        # GPU/Cleanup
        "lazy_getter": "_get_tensorflow",
        "gpu_check": '{mod} is not None and {mod}.config.list_physical_devices("GPU")',
        "stream_context": None,  # TensorFlow manages streams internally
        "device_context": '{mod}.device("/GPU:0")',
        "cleanup_ops": None,  # TensorFlow has no explicit cache clearing API
        "has_oom_recovery": True,
        "oom_exception_types": [
            "{mod}.errors.ResourceExhaustedError",
            "{mod}.errors.InvalidArgumentError",
        ],
        "oom_string_patterns": ["out of memory", "resource_exhausted"],
        "oom_clear_cache": None,  # TensorFlow has no explicit cache clearing API
    },
    MemoryType.JAX: {
        # Metadata
        "import_name": "jax",
        "display_name": "JAX",
        "is_gpu": True,
        # Device operations
        "get_device_id": 'int(str(data.device).lower().split(":")[-1]) if "gpu" in str(data.device).lower() else None',  # noqa: E501
        "get_device_id_fallback": "None",
        "set_device": None,  # JAX handles device at array creation
        "move_to_device": '{mod}.device_put(data, {mod}.devices("gpu")[device_id])',
        # Stack operations
        "allocate_stack": "jnp.empty(stack_shape, dtype=first_slice.dtype)",
        "allocate_context": None,
        "needs_dtype_conversion": False,
        "assign_slice": _jax_assign_slice,  # Custom handler for immutability
        "stack_handler": None,  # Standard
        # Dtype scaling
        "scaling_ops": {
            "min": "jnp.min(result)",
            "max": "jnp.max(result)",
            "astype": "result.astype(target_dtype_mapped)",
            "check_float": "result.dtype in [jnp.float16, jnp.float32, jnp.float64]",
            "check_int": "target_dtype_mapped in [jnp.uint8, jnp.int8, jnp.int16, jnp.int32, jnp.int64]",  # noqa: E501
            "needs_dtype_map": True,
            "extra_import": "jax.numpy",
            "clamp": "jnp.clip(result, min_val, max_val)",
        },
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "np.asarray(data)",
            "from_numpy": "{mod}.device_put(data, {mod}.devices()[gpu_id])",
            "from_dlpack": "{mod}.dlpack.from_dlpack(data)",
            "move_to_device": "data",
        },
        # DLPack
        "supports_dlpack": True,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": "_get_jax",
        "gpu_check": '{mod} is not None and any(d.platform == "gpu" for d in {mod}.devices())',
        "stream_context": None,  # JAX/XLA manages streams internally
        "device_context": '{mod}.default_device([d for d in {mod}.devices() if d.platform == "gpu"][0])',  # noqa: E501
        "cleanup_ops": "{mod}.clear_caches()",
        "has_oom_recovery": True,
        "oom_exception_types": [],
        "oom_string_patterns": ["out of memory", "oom when allocating", "allocation failure"],
        "oom_clear_cache": "{mod}.clear_caches()",
    },
    MemoryType.PYCLESPERANTO: {
        # Metadata
        "import_name": "pyclesperanto",
        "display_name": "pyclesperanto",
        "is_gpu": True,
        # Device operations (custom handlers)
        "get_device_id": _pyclesperanto_get_device_id,  # Callable
        "get_device_id_fallback": "0",
        "set_device": _pyclesperanto_set_device,  # Callable
        "move_to_device": _pyclesperanto_move_to_device,  # Callable
        # Stack operations (custom handler)
        "allocate_stack": None,  # Uses concatenate_along_z
        "allocate_context": None,
        "needs_dtype_conversion": False,
        "assign_slice": None,  # Not used (custom stacking)
        "stack_handler": _pyclesperanto_stack_slices,  # Custom stacking
        # Conversion operations
        "conversion_ops": {
            "to_numpy": "{mod}.pull(data)",
            "from_numpy": "{mod}.push(data)",
            "from_dlpack": None,
            "move_to_device": "data",
        },
        # Dtype scaling (custom implementation in dtype_scaling.py)
        "scaling_ops": None,  # Custom _scale_pyclesperanto function
        # DLPack
        "supports_dlpack": False,
        "validate_dlpack": None,
        # GPU/Cleanup
        "lazy_getter": None,
        "gpu_check": None,  # pyclesperanto always uses GPU if available
        "stream_context": None,  # OpenCL manages streams internally
        "device_context": None,  # OpenCL device selection is global
        "cleanup_ops": None,  # pyclesperanto/OpenCL has no explicit cache clearing API
        "has_oom_recovery": True,
        "oom_exception_types": [],
        "oom_string_patterns": [
            "cl_mem_object_allocation_failure",
            "cl_out_of_resources",
            "out of memory",
        ],
        "oom_clear_cache": None,  # pyclesperanto/OpenCL has no explicit cache clearing API
    },
}
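

# ============================================================================
# EXAMPLE: consuming the polymorphic dispatch (illustrative sketch)
# ============================================================================
# The real dispatcher lives in framework_ops.py; `_resolve_op` below is a
# documentation sketch, not part of the public API. It only covers the
# handler-style entries: config values that are callables are returned for the
# caller to invoke with framework-specific arguments, while string values are
# eval expressions in which "{mod}" stands in for the framework module.


def _resolve_op(memory_type: MemoryType, key: str, **namespace: Any) -> Any:
    """Illustrative only: resolve one handler-style config entry."""
    import importlib

    config = _FRAMEWORK_CONFIG[memory_type]
    handler = config.get(key)
    if handler is None or callable(handler):
        # None means "no-op"; custom handlers are invoked by the caller.
        return handler
    mod = importlib.import_module(config["import_name"])
    # Substitute the module placeholder, then evaluate with the caller's
    # variables (e.g. data, device_id) in scope.
    expression = handler.replace("{mod}", "mod")
    return eval(expression, {"mod": mod}, dict(namespace))  # noqa: S307


# For example, _resolve_op(MemoryType.CUPY, "set_device", device_id=0) would
# evaluate "mod.cuda.Device(device_id).use()" with mod bound to cupy.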