Coverage for openhcs/core/memory/framework_config.py: 18.0%
65 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2Single source of truth for ALL framework-specific behavior.
4This module consolidates all framework-specific logic that was previously
5scattered across utils.py, stack_utils.py, gpu_cleanup.py, dtype_scaling.py,
6and framework_ops.py.
8Architecture:
9- Framework handlers: Custom logic for special cases (pyclesperanto, JAX, TensorFlow)
10- Unified config: Single _FRAMEWORK_CONFIG dict with all framework metadata
11- Polymorphic dispatch: Handlers can be callables or eval expressions
12"""
14import gc
15import logging
16from typing import Any, Optional, Callable
17from openhcs.constants.constants import MemoryType
19logger = logging.getLogger(__name__)
22# ============================================================================
23# FRAMEWORK HANDLERS - All special-case logic lives here
24# ============================================================================
26def _pyclesperanto_get_device_id(data: Any, mod: Any) -> int:
27 """Get device ID for pyclesperanto array."""
28 try:
29 current_device = mod.get_device()
30 if hasattr(current_device, 'id'):
31 return current_device.id
32 devices = mod.list_available_devices()
33 for i, device in enumerate(devices):
34 if str(device) == str(current_device):
35 return i
36 return 0
37 except Exception as e:
38 logger.warning(f"Failed to get device ID for pyclesperanto: {e}")
39 return 0
42def _pyclesperanto_set_device(device_id: int, mod: Any) -> None:
43 """Set device for pyclesperanto."""
44 devices = mod.list_available_devices()
45 if device_id >= len(devices):
46 raise ValueError(f"Device {device_id} not available. Available: {len(devices)}")
47 mod.select_device(device_id)
def _pyclesperanto_move_to_device(data: Any, device_id: int, mod: Any, memory_type: str) -> Any:
    """Copy a pyclesperanto array onto ``device_id`` if not already there."""
    # Import here to avoid circular dependency
    from openhcs.core.memory.utils import _get_device_id

    # Already resident on the requested device: hand back the input as-is.
    if _get_device_id(data, memory_type) == device_id:
        return data

    mod.select_device(device_id)
    copied = mod.create_like(data)
    mod.copy(data, copied)
    return copied
def _pyclesperanto_stack_slices(slices: list, memory_type: str, gpu_id: int, mod: Any) -> Any:
    """Stack slices using pyclesperanto's concatenate_along_z.

    Each slice whose detected memory type differs from ``memory_type`` is
    converted (targeting ``gpu_id``) before stacking; slices already in the
    target type are passed through untouched.

    Args:
        slices: Sequence of 2D slices, possibly from mixed frameworks.
        memory_type: Target memory type for all slices.
        gpu_id: GPU index used for any required conversions.
        mod: The imported pyclesperanto module.

    Returns:
        The stacked 3D array produced by ``mod.concatenate_along_z``.
    """
    from openhcs.core.memory.converters import convert_memory, detect_memory_type

    converted_slices = []
    conversion_count = 0

    for slice_data in slices:
        source_type = detect_memory_type(slice_data)
        # Single comparison (the original checked source_type against
        # memory_type twice): convert only when the types differ.
        if source_type == memory_type:
            converted_slices.append(slice_data)
        else:
            conversion_count += 1
            converted_slices.append(
                convert_memory(slice_data, source_type, memory_type, gpu_id)
            )

    # Log batch conversion
    if conversion_count > 0:
        logger.debug(
            f"🔄 MEMORY CONVERSION: Converted {conversion_count}/{len(slices)} slices "
            f"to {memory_type} for pyclesperanto stacking"
        )

    return mod.concatenate_along_z(converted_slices)
94def _jax_assign_slice(result: Any, index: int, slice_data: Any) -> Any:
95 """Assign slice to JAX array (immutable)."""
96 return result.at[index].set(slice_data)
99def _tensorflow_validate_dlpack(obj: Any, mod: Any) -> bool:
100 """Validate TensorFlow DLPack support."""
101 # Check version
102 major, minor = map(int, mod.__version__.split('.')[:2])
103 if major < 2 or (major == 2 and minor < 12):
104 raise RuntimeError(
105 f"TensorFlow {mod.__version__} does not support stable DLPack. "
106 f"Version 2.12.0+ required. "
107 f"Clause 88 violation: Cannot infer DLPack capability."
108 )
110 # Check GPU
111 device_str = obj.device.lower()
112 if "gpu" not in device_str:
113 raise RuntimeError(
114 "TensorFlow tensor on CPU cannot use DLPack operations reliably. "
115 "Only GPU tensors are supported for DLPack operations. "
116 "Clause 88 violation: Cannot infer GPU capability."
117 )
119 # Check module
120 if not hasattr(mod.experimental, "dlpack"):
121 raise RuntimeError(
122 "TensorFlow installation missing experimental.dlpack module. "
123 "Clause 88 violation: Cannot infer DLPack capability."
124 )
126 return True
def _numpy_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool:
    """Check if NumPy needs dtype conversion (only for torch sources)."""
    # NumPy only needs an explicit dtype-conversion pass when the incoming
    # slice originated from torch; every other source maps directly.
    return detect_memory_type_func(first_slice) == MemoryType.TORCH.value
135def _torch_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool:
136 """Torch always needs dtype conversion to get correct torch dtype."""
137 return True
140# ============================================================================
141# UNIFIED FRAMEWORK CONFIGURATION
142# ============================================================================
# Schema (per MemoryType entry):
#   import_name / display_name / is_gpu      -- static framework metadata
#   get_device_id / set_device / move_to_device / allocate_stack / ...
#     Each is one of: None (not applicable, or use the standard code path),
#     a Python callable (custom handler defined above in this module), or a
#     string expression.
#     NOTE(review): string values appear in two flavors -- plain expressions
#     over names such as `data`/`device_id`/`mod`, and templates containing
#     `{mod}` / `{device_id}` placeholders, presumably `.format()`ed with the
#     lazily-imported module before eval. Confirm against the dispatch code
#     that consumes this dict (framework_ops / gpu_cleanup / dtype_scaling).
#   scaling_ops / conversion_ops              -- nested dicts of expressions
#   supports_dlpack / validate_dlpack         -- DLPack capability + optional check
#   lazy_getter ... oom_clear_cache           -- GPU availability, stream/device
#     contexts, cache cleanup, and out-of-memory detection/recovery hooks.
_FRAMEWORK_CONFIG = {
    MemoryType.NUMPY: {
        # Metadata
        'import_name': 'numpy',
        'display_name': 'NumPy',
        'is_gpu': False,

        # Device operations
        'get_device_id': None,  # CPU
        'set_device': None,  # CPU
        'move_to_device': None,  # CPU

        # Stack operations
        'allocate_stack': 'np.empty(stack_shape, dtype=dtype)',
        'allocate_context': None,
        'needs_dtype_conversion': _numpy_dtype_conversion_needed,  # Callable
        'assign_slice': None,  # Standard: result[i] = slice
        'stack_handler': None,  # Standard stacking

        # Dtype scaling
        'scaling_ops': {
            'min': 'result.min()',
            'max': 'result.max()',
            'astype': 'result.astype(target_dtype)',
            'check_float': 'np.issubdtype(result.dtype, np.floating)',
            'check_int': 'target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]',
        },

        # Conversion operations (identity: numpy is the interchange format)
        'conversion_ops': {
            'to_numpy': 'data',
            'from_numpy': 'data',
            'from_dlpack': None,
            'move_to_device': 'data',
        },

        # DLPack
        'supports_dlpack': False,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': None,
        'gpu_check': None,
        'stream_context': None,
        'device_context': None,
        'cleanup_ops': None,
        'has_oom_recovery': False,
        'oom_exception_types': [],
        'oom_string_patterns': ['cannot allocate memory', 'memory exhausted'],
        'oom_clear_cache': 'import gc; gc.collect()',
    },

    MemoryType.CUPY: {
        # Metadata
        'import_name': 'cupy',
        'display_name': 'CuPy',
        'is_gpu': True,

        # Device operations (eval expressions)
        'get_device_id': 'data.device.id',
        'get_device_id_fallback': '0',
        'set_device': '{mod}.cuda.Device(device_id).use()',
        'move_to_device': 'data.copy() if data.device.id != device_id else data',
        'move_context': '{mod}.cuda.Device(device_id)',

        # Stack operations
        'allocate_stack': 'cupy.empty(stack_shape, dtype=first_slice.dtype)',
        'allocate_context': 'cupy.cuda.Device(gpu_id)',
        'needs_dtype_conversion': False,
        'assign_slice': None,  # Standard
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'mod.min(result)',
            'max': 'mod.max(result)',
            'astype': 'result.astype(target_dtype)',
            'check_float': 'mod.issubdtype(result.dtype, mod.floating)',
            'check_int': 'not mod.issubdtype(target_dtype, mod.floating)',
        },

        # Conversion operations
        # NOTE(review): the `(Device(gpu_id), array(data))[1]` tuple trick
        # evaluates Device(gpu_id) before array(data) and discards it, but
        # constructing a cupy Device does not by itself make that device
        # current (that needs .use() or a `with` block) -- confirm the
        # intended device is actually active when these expressions run.
        'conversion_ops': {
            'to_numpy': 'data.get()',
            'from_numpy': '({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]',
            'from_dlpack': '{mod}.from_dlpack(data)',
            'move_to_device': 'data if data.device.id == gpu_id else ({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': '_get_cupy',
        'gpu_check': '{mod} is not None and hasattr({mod}, "cuda")',
        'stream_context': '{mod}.cuda.Stream()',
        'device_context': '{mod}.cuda.Device({device_id})',
        'cleanup_ops': '{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()',
        'has_oom_recovery': True,
        'oom_exception_types': ['{mod}.cuda.memory.OutOfMemoryError', '{mod}.cuda.runtime.CUDARuntimeError'],
        'oom_string_patterns': ['out of memory', 'cuda_error_out_of_memory'],
        'oom_clear_cache': '{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()',
    },

    MemoryType.TORCH: {
        # Metadata
        'import_name': 'torch',
        'display_name': 'PyTorch',
        'is_gpu': True,

        # Device operations
        'get_device_id': 'data.device.index if data.is_cuda else None',
        'get_device_id_fallback': 'None',
        'set_device': None,  # PyTorch handles device at tensor creation
        'move_to_device': 'data.to(f"cuda:{device_id}") if (not data.is_cuda or data.device.index != device_id) else data',

        # Stack operations
        'allocate_stack': 'torch.empty(stack_shape, dtype=sample_converted.dtype, device=sample_converted.device)',
        'allocate_context': None,
        'needs_dtype_conversion': _torch_dtype_conversion_needed,  # Callable
        'assign_slice': None,  # Standard
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'result.min()',
            'max': 'result.max()',
            'astype': 'result.to(target_dtype_mapped)',
            'check_float': 'result.dtype in [mod.float16, mod.float32, mod.float64]',
            'check_int': 'target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]',
            'needs_dtype_map': True,  # numpy dtypes must be mapped to torch dtypes first
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'data.cpu().numpy()',
            'from_numpy': '{mod}.from_numpy(data).cuda(gpu_id)',
            'from_dlpack': '{mod}.from_dlpack(data)',
            'move_to_device': 'data if data.device.index == gpu_id else data.cuda(gpu_id)',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': '_get_torch',
        'gpu_check': '{mod} is not None and hasattr({mod}, "cuda") and {mod}.cuda.is_available()',
        'stream_context': '{mod}.cuda.Stream()',
        'device_context': '{mod}.cuda.device({device_id})',
        'cleanup_ops': '{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()',
        'has_oom_recovery': True,
        'oom_exception_types': ['{mod}.cuda.OutOfMemoryError'],
        'oom_string_patterns': ['out of memory', 'cuda_error_out_of_memory'],
        'oom_clear_cache': '{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()',
    },

    MemoryType.TENSORFLOW: {
        # Metadata
        'import_name': 'tensorflow',
        'display_name': 'TensorFlow',
        'is_gpu': True,

        # Device operations
        'get_device_id': 'int(data.device.lower().split(":")[-1]) if "gpu" in data.device.lower() else None',
        'get_device_id_fallback': 'None',
        'set_device': None,  # TensorFlow handles device at tensor creation
        'move_to_device': '{mod}.identity(data)',
        'move_context': '{mod}.device(f"/device:GPU:{device_id}")',

        # Stack operations
        'allocate_stack': 'tf.zeros(stack_shape, dtype=first_slice.dtype)',  # TF doesn't have empty()
        'allocate_context': 'tf.device(f"/device:GPU:{gpu_id}")',
        'needs_dtype_conversion': False,
        'assign_slice': None,  # Standard
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'mod.reduce_min(result)',
            'max': 'mod.reduce_max(result)',
            'astype': 'mod.cast(result, target_dtype_mapped)',
            'check_float': 'result.dtype in [mod.float16, mod.float32, mod.float64]',
            'check_int': 'target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]',
            'needs_dtype_map': True,  # numpy dtypes must be mapped to tf dtypes first
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'data.numpy()',
            'from_numpy': '{mod}.convert_to_tensor(data)',
            'from_dlpack': '{mod}.experimental.dlpack.from_dlpack(data)',
            'move_to_device': 'data',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': _tensorflow_validate_dlpack,  # Custom validation

        # GPU/Cleanup
        'lazy_getter': '_get_tensorflow',
        'gpu_check': '{mod} is not None and {mod}.config.list_physical_devices("GPU")',
        'stream_context': None,  # TensorFlow manages streams internally
        'device_context': '{mod}.device("/GPU:0")',
        'cleanup_ops': None,  # TensorFlow has no explicit cache clearing API
        'has_oom_recovery': True,
        'oom_exception_types': ['{mod}.errors.ResourceExhaustedError', '{mod}.errors.InvalidArgumentError'],
        'oom_string_patterns': ['out of memory', 'resource_exhausted'],
        'oom_clear_cache': None,  # TensorFlow has no explicit cache clearing API
    },

    MemoryType.JAX: {
        # Metadata
        'import_name': 'jax',
        'display_name': 'JAX',
        'is_gpu': True,

        # Device operations
        # NOTE(review): device-id parsing keys on "gpu" appearing in
        # str(data.device) -- confirm this matches the device string format
        # of the JAX version actually in use.
        'get_device_id': 'int(str(data.device).lower().split(":")[-1]) if "gpu" in str(data.device).lower() else None',
        'get_device_id_fallback': 'None',
        'set_device': None,  # JAX handles device at array creation
        'move_to_device': '{mod}.device_put(data, {mod}.devices("gpu")[device_id])',

        # Stack operations
        'allocate_stack': 'jnp.empty(stack_shape, dtype=first_slice.dtype)',
        'allocate_context': None,
        'needs_dtype_conversion': False,
        'assign_slice': _jax_assign_slice,  # Custom handler for immutability
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'jnp.min(result)',
            'max': 'jnp.max(result)',
            'astype': 'result.astype(target_dtype_mapped)',
            'check_float': 'result.dtype in [jnp.float16, jnp.float32, jnp.float64]',
            'check_int': 'target_dtype_mapped in [jnp.uint8, jnp.int8, jnp.int16, jnp.int32, jnp.int64]',
            'needs_dtype_map': True,
            'extra_import': 'jax.numpy',  # expressions above need jnp in scope
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'np.asarray(data)',
            'from_numpy': '{mod}.device_put(data, {mod}.devices()[gpu_id])',
            'from_dlpack': '{mod}.dlpack.from_dlpack(data)',
            'move_to_device': 'data',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': '_get_jax',
        'gpu_check': '{mod} is not None and any(d.platform == "gpu" for d in {mod}.devices())',
        'stream_context': None,  # JAX/XLA manages streams internally
        'device_context': '{mod}.default_device([d for d in {mod}.devices() if d.platform == "gpu"][0])',
        'cleanup_ops': '{mod}.clear_caches()',
        'has_oom_recovery': True,
        'oom_exception_types': [],  # detected by string patterns only
        'oom_string_patterns': ['out of memory', 'oom when allocating', 'allocation failure'],
        'oom_clear_cache': '{mod}.clear_caches()',
    },

    MemoryType.PYCLESPERANTO: {
        # Metadata
        'import_name': 'pyclesperanto',
        'display_name': 'pyclesperanto',
        'is_gpu': True,

        # Device operations (custom handlers)
        'get_device_id': _pyclesperanto_get_device_id,  # Callable
        'get_device_id_fallback': '0',
        'set_device': _pyclesperanto_set_device,  # Callable
        'move_to_device': _pyclesperanto_move_to_device,  # Callable

        # Stack operations (custom handler)
        'allocate_stack': None,  # Uses concatenate_along_z
        'allocate_context': None,
        'needs_dtype_conversion': False,
        'assign_slice': None,  # Not used (custom stacking)
        'stack_handler': _pyclesperanto_stack_slices,  # Custom stacking

        # Conversion operations
        'conversion_ops': {
            'to_numpy': '{mod}.pull(data)',
            'from_numpy': '{mod}.push(data)',
            'from_dlpack': None,
            'move_to_device': 'data',
        },

        # Dtype scaling (custom implementation in dtype_scaling.py)
        'scaling_ops': None,  # Custom _scale_pyclesperanto function

        # DLPack
        'supports_dlpack': False,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': None,
        'gpu_check': None,  # pyclesperanto always uses GPU if available
        'stream_context': None,  # OpenCL manages streams internally
        'device_context': None,  # OpenCL device selection is global
        'cleanup_ops': None,  # pyclesperanto/OpenCL has no explicit cache clearing API
        'has_oom_recovery': True,
        'oom_exception_types': [],  # detected by string patterns only
        'oom_string_patterns': ['cl_mem_object_allocation_failure', 'cl_out_of_resources', 'out of memory'],
        'oom_clear_cache': None,  # pyclesperanto/OpenCL has no explicit cache clearing API
    },
}