Coverage for openhcs/core/memory/decorators.py: 58.7%
177 statements
1"""
2Memory type declaration decorators for OpenHCS.
4This module provides decorators for explicitly declaring the memory interface
5of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting
6memory-type-aware dispatching and orchestration.
8These decorators annotate functions with input_memory_type and output_memory_type
9attributes and provide automatic thread-local CUDA stream management for GPU
10frameworks to enable true parallelization across multiple threads.
12REFACTORED: Uses enum-driven metaprogramming to eliminate 79% of code duplication.
13"""
import functools
import inspect
import logging
import threading
from typing import Any, Callable, Optional, TypeVar

from openhcs.constants.constants import VALID_MEMORY_TYPES, MemoryType
from openhcs.core.utils import optional_import
from openhcs.core.memory.oom_recovery import _execute_with_oom_recovery
from openhcs.core.memory.framework_ops import _FRAMEWORK_OPS
from openhcs.core.memory.dtype_scaling import SCALING_FUNCTIONS
from openhcs.core.memory.slice_processing import process_slices

logger = logging.getLogger(__name__)

F = TypeVar('F', bound=Callable[..., Any])

# Dtype conversion enum and utilities for consistent dtype handling across all frameworks
from enum import Enum
import numpy as np


class DtypeConversion(Enum):
    """Data type conversion modes for all memory type functions."""

    PRESERVE_INPUT = "preserve"  # Keep input dtype (default)
    NATIVE_OUTPUT = "native"     # Use framework's native output
    UINT8 = "uint8"              # Force uint8 (0-255 range)
    UINT16 = "uint16"            # Force uint16 (microscopy standard)
    INT16 = "int16"              # Force int16 (signed microscopy data)
    INT32 = "int32"              # Force int32 (large integer values)
    FLOAT32 = "float32"          # Force float32 (GPU performance)
    FLOAT64 = "float64"          # Force float64 (maximum precision)

    @property
    def numpy_dtype(self):
        """Get the corresponding numpy dtype."""
        dtype_map = {
            self.UINT8: np.uint8,
            self.UINT16: np.uint16,
            self.INT16: np.int16,
            self.INT32: np.int32,
            self.FLOAT32: np.float32,
            self.FLOAT64: np.float64,
        }
        return dtype_map.get(self, None)
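# A small sketch of the mapping above (assuming NumPy is importable): the two
# non-forcing modes fall outside dtype_map and therefore return None.
#
#     DtypeConversion.UINT16.numpy_dtype          # np.uint16
#     DtypeConversion.PRESERVE_INPUT.numpy_dtype  # None -- no forced dtype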
# Thread-local cache for lazy-loaded GPU frameworks
_gpu_frameworks_cache = {}


def _create_lazy_getter(framework_name: str):
    """Factory function that creates a lazy import getter for a framework."""
    def getter():
        if framework_name not in _gpu_frameworks_cache:
            _gpu_frameworks_cache[framework_name] = optional_import(framework_name)
            if _gpu_frameworks_cache[framework_name] is not None:
                logger.debug(f"🔧 Lazy imported {framework_name} in thread {threading.current_thread().name}")
        return _gpu_frameworks_cache[framework_name]
    return getter


# Auto-generate lazy getters for all GPU frameworks
for mem_type in MemoryType:
    ops = _FRAMEWORK_OPS[mem_type]
    if ops['lazy_getter'] is not None:
        getter_func = _create_lazy_getter(ops['import_name'])
        globals()[f"_get_{ops['import_name']}"] = getter_func
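# Sketch of the effect of the loop above (assuming _FRAMEWORK_OPS entries look like
# {'import_name': 'cupy', 'lazy_getter': '_get_cupy', ...}): it defines module-level
# getters such as `_get_cupy` and `_get_torch`, each importing its framework on first
# call and returning the cached module (or None if unavailable) afterwards.
#
#     cupy = _get_cupy()        # first call triggers optional_import("cupy")
#     cupy_again = _get_cupy()  # later calls hit _gpu_frameworks_cache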
# Thread-local storage for GPU streams and contexts
_thread_gpu_contexts = threading.local()


class ThreadGPUContext:
    """Thread-local GPU context manager for CUDA streams."""

    def __init__(self):
        self.cupy_stream = None
        self.torch_stream = None
        self.tensorflow_device = None
        self.jax_device = None

    def get_cupy_stream(self):
        """Get or create thread-local CuPy stream."""
        if self.cupy_stream is None:
            cupy = _get_cupy()
            if cupy is not None and hasattr(cupy, 'cuda'):
                self.cupy_stream = cupy.cuda.Stream()
                logger.debug(f"🔧 Created CuPy stream for thread {threading.current_thread().name}")
        return self.cupy_stream

    def get_torch_stream(self):
        """Get or create thread-local PyTorch stream."""
        if self.torch_stream is None:
            torch = _get_torch()
            if torch is not None and hasattr(torch, 'cuda') and torch.cuda.is_available():
                self.torch_stream = torch.cuda.Stream()
                logger.debug(f"🔧 Created PyTorch stream for thread {threading.current_thread().name}")
        return self.torch_stream


def _get_thread_gpu_context():
    """Get or create thread-local GPU context."""
    if not hasattr(_thread_gpu_contexts, 'context'):
        _thread_gpu_contexts.context = ThreadGPUContext()
    return _thread_gpu_contexts.context
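# Per-thread behaviour, a sketch (assumes CuPy with CUDA is available): each worker
# thread lazily receives its own ThreadGPUContext, so CUDA streams are never shared
# across threads.
#
#     def worker():
#         ctx = _get_thread_gpu_context()
#         stream = ctx.get_cupy_stream()  # unique to this thread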
def memory_types(
    input_type: str,
    output_type: str,
    contract: Optional[Callable[[Any], bool]] = None
) -> Callable[[F], F]:
    """
    Base decorator for declaring memory types of a function.

    This is the foundation decorator that all memory-type-specific decorators build upon.
    """
    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)

            # Apply contract validation if provided
            if contract is not None and not contract(result):
                raise ValueError(f"Function {func.__name__} violated its output contract")

            return result

        # Attach memory type metadata
        wrapper.input_memory_type = input_type
        wrapper.output_memory_type = output_type

        return wrapper

    return decorator
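# Direct use of the base decorator, a minimal sketch (the function and contract are
# hypothetical):
#
#     @memory_types(input_type="numpy", output_type="numpy",
#                   contract=lambda out: out.ndim == 3)
#     def identity_stack(stack):
#         return stack
#
#     identity_stack.input_memory_type   # "numpy"
#     identity_stack.output_memory_type  # "numpy"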
def _create_dtype_wrapper(func, mem_type: MemoryType, func_name: str):
    """
    Auto-generate dtype preservation wrapper for any memory type.

    This single function replaces 6 nearly-identical dtype wrapper functions.
    """
    ops = _FRAMEWORK_OPS[mem_type]
    scale_func = SCALING_FUNCTIONS[mem_type.value]

    @functools.wraps(func)
    def dtype_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # Set default dtype_conversion if not provided
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            # Store original dtype
            original_dtype = image.dtype

            # Handle slice_by_slice processing for 3D arrays
            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                result = process_slices(image, func, args, kwargs)
            else:
                # Call the original function normally
                result = func(image, *args, **kwargs)

            # Apply dtype conversion based on enum value
            if hasattr(result, 'dtype') and dtype_conversion is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    # Preserve input dtype
                    if result.dtype != original_dtype:
                        result = scale_func(result, original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    # Return framework's native output dtype
                    pass  # No conversion needed
                else:
                    # Force specific dtype
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        result = scale_func(result, target_dtype)

            return result
        except Exception as e:
            logger.error(f"Error in {mem_type.value} dtype/slice preserving wrapper for {func_name}: {e}")
            # Fall back to re-running the original function without dtype/slice handling
            return func(image, *args, **kwargs)
    # Update function signature to include new parameters
    try:
        original_sig = inspect.signature(func)
        new_params = list(original_sig.parameters.values())

        # Check if parameters already exist
        param_names = [p.name for p in new_params]

        # Add dtype_conversion parameter first (before slice_by_slice)
        if 'dtype_conversion' not in param_names:
            dtype_param = inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=Optional[DtypeConversion]
            )
            new_params.append(dtype_param)

        # Add slice_by_slice parameter
        if 'slice_by_slice' not in param_names:
            slice_param = inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            )
            new_params.append(slice_param)

        # Create new signature
        new_sig = original_sig.replace(parameters=new_params)
        dtype_wrapper.__signature__ = new_sig

        # Update docstring
        if dtype_wrapper.__doc__:
            dtype_wrapper.__doc__ += f"\n\n    Additional Parameters (added by {mem_type.value} decorator):\n"
            dtype_wrapper.__doc__ += "        dtype_conversion (DtypeConversion, optional): How to handle output dtype.\n"
            dtype_wrapper.__doc__ += "            Defaults to PRESERVE_INPUT (match input dtype).\n"
            dtype_wrapper.__doc__ += "        slice_by_slice (bool, optional): Process 3D arrays slice-by-slice.\n"
            dtype_wrapper.__doc__ += "            Defaults to False. Prevents cross-slice contamination.\n"

    except Exception as e:
        logger.warning(f"Could not update signature for {func_name}: {e}")

    return dtype_wrapper
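# Call-time sketch for a function wrapped by _create_dtype_wrapper (the function name is
# hypothetical): both added parameters are keyword-only and default to PRESERVE_INPUT
# and False respectively.
#
#     result = gaussian_filter_3d(stack,
#                                 dtype_conversion=DtypeConversion.UINT16,
#                                 slice_by_slice=True)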
def _create_gpu_wrapper(func, mem_type: MemoryType, oom_recovery: bool):
    """
    Auto-generate GPU stream/device wrapper for any GPU memory type.

    This function creates the GPU-specific wrapper with stream management and OOM recovery.
    """
    ops = _FRAMEWORK_OPS[mem_type]
    framework_name = ops['import_name']
    lazy_getter = globals().get(ops['lazy_getter'])

    @functools.wraps(func)
    def gpu_wrapper(*args, **kwargs):
        framework = lazy_getter()

        # Check if GPU is available for this framework
        if framework is not None:
            gpu_check_expr = ops['gpu_check'].format(mod=framework_name)
            try:
                gpu_available = eval(gpu_check_expr, {framework_name: framework})
            except Exception:
                gpu_available = False
            if gpu_available:
                # Get thread-local context
                ctx = _get_thread_gpu_context()

                # Get stream if framework supports it
                stream = None
                if mem_type == MemoryType.CUPY:
                    stream = ctx.get_cupy_stream()
                elif mem_type == MemoryType.TORCH:
                    stream = ctx.get_torch_stream()

                # Define execution function that captures args/kwargs
                def execute_with_stream():
                    if stream is not None:
                        with stream:
                            return func(*args, **kwargs)
                    else:
                        return func(*args, **kwargs)

                # Execute with OOM recovery if enabled
                if oom_recovery and ops['has_oom_recovery']:
                    return _execute_with_oom_recovery(execute_with_stream, mem_type.value)
                else:
                    return execute_with_stream()

        # CPU fallback or framework not available
        return func(*args, **kwargs)

    # Preserve memory type attributes
    gpu_wrapper.input_memory_type = func.input_memory_type
    gpu_wrapper.output_memory_type = func.output_memory_type

    return gpu_wrapper
def _create_memory_decorator(mem_type: MemoryType):
    """
    Factory function that creates a decorator for a specific memory type.

    This single factory replaces 6 nearly-identical decorator functions.
    """
    ops = _FRAMEWORK_OPS[mem_type]

    def decorator(func=None, *, input_type=mem_type.value, output_type=mem_type.value,
                  oom_recovery=True, contract=None):
        """
        Decorator for {mem_type} memory type functions.

        Args:
            func: Function to decorate (when used as @decorator)
            input_type: Expected input memory type (default: {mem_type})
            output_type: Expected output memory type (default: {mem_type})
            oom_recovery: Enable automatic OOM recovery (default: True)
            contract: Optional validation function for outputs

        Returns:
            Decorated function with memory type metadata and dtype preservation
        """
        def inner_decorator(func):
            # Apply base memory_types decorator
            memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract)
            func = memory_decorator(func)

            # Apply dtype preservation wrapper
            func = _create_dtype_wrapper(func, mem_type, func.__name__)

            # Apply GPU wrapper if this is a GPU memory type
            if ops['gpu_check'] is not None:
                func = _create_gpu_wrapper(func, mem_type, oom_recovery)

            return func

        # Handle both @decorator and @decorator() forms
        if func is None:
            return inner_decorator
        return inner_decorator(func)

    # Set proper function name and docstring
    decorator.__name__ = mem_type.value
    decorator.__doc__ = decorator.__doc__.format(mem_type=ops['display_name'])

    return decorator


# Auto-generate all 6 memory type decorators
for mem_type in MemoryType:
    decorator_func = _create_memory_decorator(mem_type)
    globals()[mem_type.value] = decorator_func
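# Usage sketch for the generated decorators (assuming MemoryType contains CUPY with
# value "cupy"; the decorated functions are hypothetical). Both invocation forms handled
# by _create_memory_decorator are shown:
#
#     from openhcs.core.memory.decorators import cupy
#
#     @cupy                        # bare form
#     def denoise(image):
#         ...
#
#     @cupy(oom_recovery=False)    # parameterized form
#     def denoise_no_retry(image):
#         ...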
# Export all decorators
__all__ = [
    'memory_types',
    'DtypeConversion',
    'numpy',
    'cupy',
    'torch',
    'tensorflow',
    'jax',
    'pyclesperanto',
]