Coverage for src/arraybridge/decorators.py: 63%
177 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
"""
Memory type declaration decorators for OpenHCS.

This module provides decorators for explicitly declaring the memory interface
of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting
memory-type-aware dispatching and orchestration.

These decorators annotate functions with input_memory_type and output_memory_type
attributes. For GPU frameworks they also run the decorated function inside a
per-thread CUDA stream (CuPy and PyTorch), so that several threads can issue GPU
work in parallel.

The per-framework decorators (``numpy``, ``cupy``, ``torch``, ...) are not written
out by hand: they are generated from ``MemoryType`` and the ``_FRAMEWORK_OPS``
table by ``_create_memory_decorator``.
"""

import functools
import inspect
import logging
import threading
from enum import Enum
from typing import Any, Callable, Optional, TypeVar

import numpy as np

from arraybridge.dtype_scaling import SCALING_FUNCTIONS
from arraybridge.framework_ops import _FRAMEWORK_OPS
from arraybridge.oom_recovery import _execute_with_oom_recovery
from arraybridge.slice_processing import process_slices
from arraybridge.types import MemoryType
from arraybridge.utils import optional_import

# Module logger; the wrappers below log lazy imports, stream creation and wrapper errors here.
logger = logging.getLogger(__name__)

# Callable type variable so decorator factories can declare "returns the same kind of callable it receives".
F = TypeVar("F", bound=Callable[..., Any])


class DtypeConversion(Enum):
    """How a memory-type decorator should treat the dtype of a function's output.

    PRESERVE_INPUT and NATIVE_OUTPUT are policies; every other member forces the
    output to the numpy dtype it is named after.
    """

    PRESERVE_INPUT = "preserve"  # Keep input dtype (default)
    NATIVE_OUTPUT = "native"  # Use framework's native output
    UINT8 = "uint8"  # Force uint8 (0-255 range)
    UINT16 = "uint16"  # Force uint16 (microscopy standard)
    INT16 = "int16"  # Force int16 (signed microscopy data)
    INT32 = "int32"  # Force int32 (large integer values)
    FLOAT32 = "float32"  # Force float32 (GPU performance)
    FLOAT64 = "float64"  # Force float64 (maximum precision)

    @property
    def numpy_dtype(self):
        """Return the numpy scalar type this member forces, or None for the two policy members."""
        if self in (DtypeConversion.PRESERVE_INPUT, DtypeConversion.NATIVE_OUTPUT):
            return None
        # Forcing members carry the numpy type name as their value, e.g. "uint16" -> np.uint16
        return getattr(np, self.value)


# Process-wide cache of lazily imported GPU framework modules, keyed by import name.
# NOTE: this is a plain module-level dict shared by every thread; it is NOT thread-local.
# Two threads racing on a first import may both call optional_import and both write the
# entry; that is presumably harmless because Python's import system returns the same
# module object (confirm against optional_import's implementation).
_gpu_frameworks_cache = {}


def _create_lazy_getter(framework_name: str):
    """Create a zero-argument getter that imports ``framework_name`` on first use.

    Whatever ``optional_import(framework_name)`` returns is memoised in
    ``_gpu_frameworks_cache``. That includes ``None``, so a framework that failed to
    import is not re-probed on every call.

    Args:
        framework_name: Importable module name (the ``import_name`` entry of _FRAMEWORK_OPS).

    Returns:
        A callable returning the cached result of ``optional_import(framework_name)``.
        This is presumably the module, or None when it cannot be imported.
    """

    def getter():
        # Membership test (not truthiness) so a cached None counts as "already tried"
        if framework_name not in _gpu_frameworks_cache:
            _gpu_frameworks_cache[framework_name] = optional_import(framework_name)
            if _gpu_frameworks_cache[framework_name] is not None:
                logger.debug(
                    f"🔧 Lazy imported {framework_name} in thread "
                    f"{threading.current_thread().name}"
                )
        return _gpu_frameworks_cache[framework_name]

    return getter


def _register_lazy_getters() -> None:
    """Publish a ``_get_<import_name>`` lazy importer in this module for each framework.

    Only frameworks whose _FRAMEWORK_OPS entry declares a ``lazy_getter`` get one. The
    getters are stored as module globals because other code in this module looks them
    up by name (e.g. ``globals().get("_get_cupy")``). Running the loop inside a function
    keeps its temporaries out of the module namespace. At module level they would leak
    as ``mem_type``/``ops``/``getter_func`` attributes.
    """
    module_namespace = globals()
    for memory_type in MemoryType:
        framework_ops = _FRAMEWORK_OPS[memory_type]
        if framework_ops["lazy_getter"] is None:
            continue
        import_name = framework_ops["import_name"]
        module_namespace[f"_get_{import_name}"] = _create_lazy_getter(import_name)


# Auto-generate lazy getters for all GPU frameworks
_register_lazy_getters()


# Thread-local storage for GPU streams and contexts.
# Each thread lazily receives its own ThreadGPUContext, stored under the ``context``
# attribute by _get_thread_gpu_context(), so streams are never shared between threads.
_thread_gpu_contexts = threading.local()


class ThreadGPUContext:
    """Per-thread holder of lazily created GPU execution state (CUDA streams).

    One instance is kept per thread so that each worker thread owns its own CuPy and
    PyTorch stream. ``tensorflow_device`` and ``jax_device`` are initialised to None and
    are not otherwise assigned by this class.
    """

    def __init__(self):
        self.cupy_stream = None
        self.torch_stream = None
        self.tensorflow_device = None
        self.jax_device = None

    def get_cupy_stream(self):
        """Return this thread's CuPy stream, creating it on the first successful request.

        Returns None, and tries again on the next call, when the ``_get_cupy`` lazy getter
        is absent or yields nothing with a ``cuda`` attribute.
        """
        if self.cupy_stream is not None:
            return self.cupy_stream

        lazy_cupy = globals().get("_get_cupy", lambda: None)
        cupy_module = lazy_cupy()
        if cupy_module is None or not hasattr(cupy_module, "cuda"):
            return None

        self.cupy_stream = cupy_module.cuda.Stream()
        logger.debug(f"🔧 Created CuPy stream for thread {threading.current_thread().name}")
        return self.cupy_stream

    def get_torch_stream(self):
        """Return this thread's PyTorch CUDA stream, creating it on the first successful request.

        Returns None, and tries again on the next call, unless the ``_get_torch`` lazy getter
        yields a module whose ``cuda.is_available()`` reports true.
        """
        if self.torch_stream is not None:
            return self.torch_stream

        torch_module = globals().get("_get_torch", lambda: None)()
        cuda_ready = (
            torch_module is not None
            and hasattr(torch_module, "cuda")
            and torch_module.cuda.is_available()
        )
        if not cuda_ready:
            return None

        self.torch_stream = torch_module.cuda.Stream()
        logger.debug(f"🔧 Created PyTorch stream for thread {threading.current_thread().name}")
        return self.torch_stream


def _get_thread_gpu_context():
    """Return the calling thread's ThreadGPUContext, creating and storing it on first use."""
    try:
        return _thread_gpu_contexts.context
    except AttributeError:
        # First request from this thread: its threading.local slot is still empty
        fresh_context = ThreadGPUContext()
        _thread_gpu_contexts.context = fresh_context
        return fresh_context


def memory_types(
    input_type: str, output_type: str, contract: Optional[Callable[[Any], bool]] = None
) -> Callable[[F], F]:
    """
    Declare the input and output memory types of a function.

    This is the foundation decorator that all memory-type-specific decorators build upon.
    The decorator it returns wraps ``func`` (metadata preserved via functools.wraps) and
    tags the wrapper with ``input_memory_type`` / ``output_memory_type``. When
    ``contract`` is given, it also checks every return value.

    Args:
        input_type: Stored as the wrapper's ``input_memory_type`` attribute.
        output_type: Stored as the wrapper's ``output_memory_type`` attribute.
        contract: Optional predicate applied to each result.

    Returns:
        A decorator producing the tagged wrapper.

    Raises:
        ValueError: Raised by the wrapper when ``contract`` returns a falsy value.
    """

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            output = func(*args, **kwargs)
            # No contract means every result is accepted
            if contract is None or contract(output):
                return output
            raise ValueError(f"Function {func.__name__} violated its output contract")

        wrapper.input_memory_type, wrapper.output_memory_type = input_type, output_type
        return wrapper

    return decorator


def _convert_result_dtype(result, original_dtype, dtype_conversion, scale_func):
    """
    Apply a DtypeConversion policy to a function result.

    Args:
        result: Value returned by the wrapped function. Anything without a ``dtype``
            attribute is returned unchanged.
        original_dtype: dtype of the input image, or None if the input had no ``dtype``.
        dtype_conversion: DtypeConversion member selecting the policy.
        scale_func: Framework scaling function from SCALING_FUNCTIONS, called as
            ``scale_func(result, target_dtype)``.

    Returns:
        ``result`` itself, or whatever ``scale_func`` returns.
    """
    if not hasattr(result, "dtype"):
        return result

    if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
        # Rescale only when there is an input dtype to match and the output differs from it
        if original_dtype is not None and result.dtype != original_dtype:
            return scale_func(result, original_dtype)
        return result

    if dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
        # Keep whatever dtype the framework produced
        return result

    # Remaining members force a concrete dtype
    target_dtype = dtype_conversion.numpy_dtype
    if target_dtype is not None:
        return scale_func(result, target_dtype)
    return result


def _extend_wrapper_signature(wrapper, func, mem_type: MemoryType, func_name: str) -> None:
    """
    Advertise ``dtype_conversion`` / ``slice_by_slice`` on ``wrapper``'s signature and docstring.

    Parameters that ``func`` already declares are not duplicated. Any failure is logged as
    a warning and otherwise ignored; the wrapper keeps the metadata it already has.
    """
    try:
        original_sig = inspect.signature(func)
        params = list(original_sig.parameters.values())
        existing_names = {p.name for p in params}

        added = []
        # dtype_conversion first, then slice_by_slice (same order as the wrapper's own keywords)
        if "dtype_conversion" not in existing_names:
            added.append(
                inspect.Parameter(
                    "dtype_conversion",
                    inspect.Parameter.KEYWORD_ONLY,
                    default=DtypeConversion.PRESERVE_INPUT,
                    annotation=Optional[DtypeConversion],
                )
            )
        if "slice_by_slice" not in existing_names:
            added.append(
                inspect.Parameter(
                    "slice_by_slice", inspect.Parameter.KEYWORD_ONLY, default=False, annotation=bool
                )
            )

        # Keyword-only parameters must precede a **kwargs parameter. Appending them after
        # it makes Signature.replace() raise ValueError ("wrong parameter order"), which
        # would leave every function that accepts **kwargs without the extended signature.
        insert_at = next(
            (i for i, p in enumerate(params) if p.kind == inspect.Parameter.VAR_KEYWORD),
            len(params),
        )
        params[insert_at:insert_at] = added
        wrapper.__signature__ = original_sig.replace(parameters=params)

        if wrapper.__doc__:
            # NOTE(review): runs of spaces inside these literals may have been collapsed in
            # the copy this was reviewed from; confirm the indentation against the original.
            wrapper.__doc__ += (
                f"\n\n Additional Parameters (added by {mem_type.value} decorator):\n"
                " dtype_conversion (DtypeConversion, optional): How to handle output dtype.\n"
                " Defaults to PRESERVE_INPUT (match input dtype).\n"
                " slice_by_slice (bool, optional): Process 3D arrays slice-by-slice.\n"
                " Defaults to False. Prevents cross-slice contamination.\n"
            )
    except Exception as e:
        logger.warning(f"Could not update signature for {func_name}: {e}")


def _create_dtype_wrapper(func, mem_type: MemoryType, func_name: str):
    """
    Wrap ``func`` with dtype conversion and optional slice-by-slice processing.

    A single generic implementation is shared by every memory type. The returned wrapper
    accepts two extra keyword-only arguments:

    * ``dtype_conversion`` (DtypeConversion, default PRESERVE_INPUT): how the result's
      dtype is adjusted, using the memory type's function from SCALING_FUNCTIONS.
    * ``slice_by_slice`` (bool, default False): for inputs with ``ndim == 3``, run
      ``func`` through process_slices instead of on the whole array.

    Error handling:

    * Exceptions raised by ``func`` propagate unchanged. ``func`` is never re-executed
      after it has run or failed.
    * If slice-by-slice processing fails, the error is logged and ``func`` is called once
      on the whole input. That result is returned without dtype conversion.
    * If dtype conversion fails, the error is logged and the unconverted result is returned.

    Args:
        func: Function to wrap; its first positional argument is the image.
        mem_type: Memory type selecting the scaling function.
        func_name: Name used in log messages.

    Returns:
        The wrapper, with its signature and docstring extended to advertise the extra
        keyword arguments.
    """
    scale_func = SCALING_FUNCTIONS[mem_type.value]

    @functools.wraps(func)
    def dtype_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # Set default dtype_conversion if not provided
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        # Inputs without a dtype simply skip PRESERVE_INPUT rescaling
        original_dtype = getattr(image, "dtype", None)

        if slice_by_slice and getattr(image, "ndim", None) == 3:
            try:
                result = process_slices(image, func, args, kwargs)
            except Exception as e:
                logger.error(
                    f"Error in {mem_type.value} dtype/slice preserving wrapper "
                    f"for {func_name}: {e}"
                )
                # Fallback kept from the previous behavior: one whole-volume call, unconverted
                return func(image, *args, **kwargs)
        else:
            result = func(image, *args, **kwargs)

        try:
            return _convert_result_dtype(result, original_dtype, dtype_conversion, scale_func)
        except Exception as e:
            logger.error(
                f"Error in {mem_type.value} dtype/slice preserving wrapper "
                f"for {func_name}: {e}"
            )
            # Return the already-computed result instead of running func a second time
            return result

    _extend_wrapper_signature(dtype_wrapper, func, mem_type, func_name)
    return dtype_wrapper


def _create_gpu_wrapper(func, mem_type: MemoryType, oom_recovery: bool):
    """
    Wrap ``func`` so it runs on the framework's GPU (in a per-thread stream) when available.

    On each call the wrapper does the following:

    1. Obtains the framework module through its lazy getter.
    2. Evaluates the framework's ``gpu_check`` expression from _FRAMEWORK_OPS.
    3. If a GPU is available, runs ``func`` inside this thread's CuPy or PyTorch stream
       (other frameworks run without a stream). This goes through
       _execute_with_oom_recovery when ``oom_recovery`` is set and the framework's ops
       declare ``has_oom_recovery``.
    4. Otherwise, calls ``func`` directly (CPU fallback).

    Args:
        func: Function to wrap; must carry ``input_memory_type`` / ``output_memory_type``.
        mem_type: GPU memory type whose _FRAMEWORK_OPS entry drives the checks.
        oom_recovery: Whether to request OOM recovery for GPU execution.

    Returns:
        The wrapping function, with the memory-type attributes copied from ``func``.
    """
    ops = _FRAMEWORK_OPS[mem_type]
    framework_name = ops["import_name"]
    lazy_getter = globals().get(ops["lazy_getter"])

    # ``gpu_check`` is an expression template from the internal _FRAMEWORK_OPS table, not
    # user input. Compile it once here rather than re-parsing it on every call. An
    # expression that does not compile means "no GPU", just as a failing eval did before.
    gpu_check_expr = ops["gpu_check"].format(mod=framework_name)
    try:
        gpu_check_code = compile(gpu_check_expr, f"<{framework_name} gpu_check>", "eval")
    except Exception:
        gpu_check_code = None

    def _gpu_available(framework):
        """Evaluate the GPU check with the framework module bound to its import name."""
        if gpu_check_code is None:
            return False
        try:
            return bool(eval(gpu_check_code, {framework_name: framework}))
        except Exception:
            return False

    @functools.wraps(func)
    def gpu_wrapper(*args, **kwargs):
        framework = lazy_getter()

        # CPU fallback or framework not available
        if framework is None or not _gpu_available(framework):
            return func(*args, **kwargs)

        ctx = _get_thread_gpu_context()
        # Only CuPy and PyTorch get a per-thread stream
        if mem_type == MemoryType.CUPY:
            stream = ctx.get_cupy_stream()
        elif mem_type == MemoryType.TORCH:
            stream = ctx.get_torch_stream()
        else:
            stream = None

        # Captures args/kwargs so OOM recovery can re-invoke it
        def execute_with_stream():
            if stream is None:
                return func(*args, **kwargs)
            if mem_type == MemoryType.TORCH:
                # A torch.cuda.Stream is made current via the torch.cuda.stream() context
                # manager; the Stream object is not itself a context manager in older
                # PyTorch releases.
                with framework.cuda.stream(stream):
                    return func(*args, **kwargs)
            # CuPy streams are context managers that make themselves current
            with stream:
                return func(*args, **kwargs)

        # Execute with OOM recovery if enabled
        if oom_recovery and ops["has_oom_recovery"]:
            return _execute_with_oom_recovery(execute_with_stream, mem_type.value)
        return execute_with_stream()

    # Preserve memory type attributes
    gpu_wrapper.input_memory_type = func.input_memory_type
    gpu_wrapper.output_memory_type = func.output_memory_type

    return gpu_wrapper


def _create_memory_decorator(mem_type: MemoryType):
    """
    Build the public decorator for one memory type (e.g. ``@cupy`` / ``@cupy(...)``).

    The decorator stacks up to three layers on the target function, innermost first:

    1. ``memory_types``: metadata plus the optional output contract.
    2. The dtype/slice wrapper from ``_create_dtype_wrapper``.
    3. The GPU stream/OOM wrapper from ``_create_gpu_wrapper``, added only when this
       memory type's _FRAMEWORK_OPS entry has a ``gpu_check``.

    Args:
        mem_type: Memory type the decorator is generated for.

    Returns:
        A decorator usable bare (``@name``) or with keyword arguments
        (``@name(oom_recovery=False)``). Its ``__name__`` and ``__qualname__`` are
        ``mem_type.value``, so it reprs, and pickles by reference, as the module-level
        name it is published under.
    """
    ops = _FRAMEWORK_OPS[mem_type]

    def decorator(
        func=None,
        *,
        input_type=mem_type.value,
        output_type=mem_type.value,
        oom_recovery=True,
        contract=None,
    ):
        """
        Decorator for {mem_type} memory type functions.

        Args:
            func: Function to decorate (when used as @decorator)
            input_type: Expected input memory type (default: {default!r})
            output_type: Expected output memory type (default: {default!r})
            oom_recovery: Enable automatic OOM recovery on GPU (default: True)
            contract: Optional validation function for outputs

        Returns:
            Decorated function with memory type metadata and dtype preservation
        """

        def inner_decorator(func):
            # Apply base memory_types decorator
            memory_decorator = memory_types(
                input_type=input_type, output_type=output_type, contract=contract
            )
            func = memory_decorator(func)

            # Apply dtype preservation wrapper
            func = _create_dtype_wrapper(func, mem_type, func.__name__)

            # Apply GPU wrapper if this is a GPU memory type
            if ops["gpu_check"] is not None:
                func = _create_gpu_wrapper(func, mem_type, oom_recovery)

            return func

        # Handle both @decorator and @decorator() forms
        if func is None:
            return inner_decorator
        return inner_decorator(func)

    # Name the decorator after the module-level global it is published as, so that repr()
    # and pickle-by-reference resolve it instead of "<locals>.decorator".
    decorator.__name__ = mem_type.value
    decorator.__qualname__ = mem_type.value
    # Defaults are documented as the actual default value (mem_type.value), not the display name
    decorator.__doc__ = decorator.__doc__.format(
        mem_type=ops["display_name"], default=mem_type.value
    )

    return decorator


def _register_memory_decorators() -> None:
    """Publish one decorator per MemoryType as a module global named after the member's value.

    These globals are the ``numpy``/``cupy``/... names listed in ``__all__``. Running the
    loop inside a function keeps its temporaries out of the module namespace. At module
    level they would leak as ``mem_type``/``decorator_func`` attributes.
    """
    module_namespace = globals()
    for memory_type in MemoryType:
        module_namespace[memory_type.value] = _create_memory_decorator(memory_type)


# Auto-generate all memory type decorators
_register_memory_decorators()


# Export all decorators.
# The per-framework names below are injected into globals() at import time (one
# decorator per MemoryType, named after the member's value). Static analysers cannot see
# those definitions, hence the F822 ("undefined name in __all__") suppressions.
# Keep this list in sync with MemoryType.
__all__ = [
    "memory_types",
    "DtypeConversion",
    "numpy",  # noqa: F822
    "cupy",  # noqa: F822
    "torch",  # noqa: F822
    "tensorflow",  # noqa: F822
    "jax",  # noqa: F822
    "pyclesperanto",  # noqa: F822
]