Coverage for openhcs/core/memory/decorators.py: 28.9%
763 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 18:33 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 18:33 +0000
"""
Memory type declaration decorators for OpenHCS.

This module provides decorators for explicitly declaring the memory interface
of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting
memory-type-aware dispatching and orchestration.

These decorators annotate functions with input_memory_type and output_memory_type
attributes and provide automatic thread-local CUDA stream management for GPU
frameworks to enable true parallelization across multiple threads.
"""
13import functools
14import logging
15import threading
16from typing import Any, Callable, Optional, TypeVar
18from openhcs.constants.constants import VALID_MEMORY_TYPES
19from openhcs.core.utils import optional_import
20from openhcs.core.memory.oom_recovery import _execute_with_oom_recovery
22# Direct import for default contract (inlined single-use method per RST principle)
24logger = logging.getLogger(__name__)
26F = TypeVar('F', bound=Callable[..., Any])
28# Dtype conversion enum and utilities for consistent dtype handling across all frameworks
29from enum import Enum
30import numpy as np
class DtypeConversion(Enum):
    """Output dtype policies accepted by the memory-type wrapped functions.

    ``PRESERVE_INPUT`` and ``NATIVE_OUTPUT`` are policies without a fixed
    dtype; every other member forces one specific numpy dtype whose name is
    the member's value.
    """

    PRESERVE_INPUT = "preserve"  # Keep input dtype (default)
    NATIVE_OUTPUT = "native"     # Use framework's native output
    UINT8 = "uint8"              # Force uint8 (0-255 range)
    UINT16 = "uint16"            # Force uint16 (microscopy standard)
    INT16 = "int16"              # Force int16 (signed microscopy data)
    INT32 = "int32"              # Force int32 (large integer values)
    FLOAT32 = "float32"          # Force float32 (GPU performance)
    FLOAT64 = "float64"          # Force float64 (maximum precision)

    @property
    def numpy_dtype(self):
        """Return the numpy scalar type this member forces, or ``None``.

        Returns:
            ``np.uint8``, ``np.uint16``, ``np.int16``, ``np.int32``,
            ``np.float32`` or ``np.float64`` for the forcing members;
            ``None`` for ``PRESERVE_INPUT`` and ``NATIVE_OUTPUT``.
        """
        policy_only = (DtypeConversion.PRESERVE_INPUT, DtypeConversion.NATIVE_OUTPUT)
        if self in policy_only:
            return None
        # Each forcing member's value is exactly the numpy attribute name.
        return getattr(np, self.value)
58def _scale_and_convert_numpy(result, target_dtype):
59 """Scale numpy results to target integer range and convert dtype."""
60 if not hasattr(result, 'dtype'):
61 return result
63 # Check if result is floating point and target is integer
64 result_is_float = np.issubdtype(result.dtype, np.floating)
65 target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]
67 if result_is_float and target_is_int:
68 # Scale floating point results to integer range
69 result_min = result.min()
70 result_max = result.max()
72 if result_max > result_min: # Avoid division by zero
73 # Normalize to [0, 1] range
74 normalized = (result - result_min) / (result_max - result_min)
76 # Scale to target dtype range
77 if target_dtype == np.uint8:
78 scaled = normalized * 255.0
79 elif target_dtype == np.uint16:
80 scaled = normalized * 65535.0
81 elif target_dtype == np.uint32:
82 scaled = normalized * 4294967295.0
83 elif target_dtype == np.int16:
84 scaled = normalized * 65535.0 - 32768.0
85 elif target_dtype == np.int32:
86 scaled = normalized * 4294967295.0 - 2147483648.0
87 else:
88 scaled = normalized
90 return scaled.astype(target_dtype)
91 else:
92 # Constant image, just convert dtype
93 return result.astype(target_dtype)
94 else:
95 # Direct conversion for compatible types
96 return result.astype(target_dtype)
99def _scale_and_convert_pyclesperanto(result, target_dtype):
100 """Scale pyclesperanto results to target integer range and convert dtype."""
101 try:
102 import pyclesperanto as cle
103 except ImportError:
104 return result
106 if not hasattr(result, 'dtype'):
107 return result
109 # Check if result is floating point and target is integer
110 result_is_float = np.issubdtype(result.dtype, np.floating)
111 target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]
113 if result_is_float and target_is_int:
114 # Get min/max of result for proper scaling
115 result_min = float(cle.minimum_of_all_pixels(result))
116 result_max = float(cle.maximum_of_all_pixels(result))
118 if result_max > result_min: # Avoid division by zero
119 # Normalize to [0, 1] range
120 normalized = cle.subtract_image_from_scalar(result, scalar=result_min)
121 range_val = result_max - result_min
122 normalized = cle.multiply_image_and_scalar(normalized, scalar=1.0/range_val)
124 # Scale to target dtype range
125 if target_dtype == np.uint8:
126 scaled = cle.multiply_image_and_scalar(normalized, scalar=255.0)
127 elif target_dtype == np.uint16:
128 scaled = cle.multiply_image_and_scalar(normalized, scalar=65535.0)
129 elif target_dtype == np.uint32:
130 scaled = cle.multiply_image_and_scalar(normalized, scalar=4294967295.0)
131 elif target_dtype == np.int16:
132 scaled = cle.multiply_image_and_scalar(normalized, scalar=65535.0)
133 scaled = cle.subtract_image_from_scalar(scaled, scalar=32768.0)
134 elif target_dtype == np.int32:
135 scaled = cle.multiply_image_and_scalar(normalized, scalar=4294967295.0)
136 scaled = cle.subtract_image_from_scalar(scaled, scalar=2147483648.0)
137 else:
138 scaled = normalized
140 # Convert to target dtype using push/pull method
141 scaled_cpu = cle.pull(scaled).astype(target_dtype)
142 return cle.push(scaled_cpu)
143 else:
144 # Constant image, just convert dtype
145 result_cpu = cle.pull(result).astype(target_dtype)
146 return cle.push(result_cpu)
147 else:
148 # Direct conversion for compatible types
149 result_cpu = cle.pull(result).astype(target_dtype)
150 return cle.push(result_cpu)
153def _scale_and_convert_cupy(result, target_dtype):
154 """Scale CuPy results to target integer range and convert dtype."""
155 try:
156 import cupy as cp
157 except ImportError:
158 return result
160 if not hasattr(result, 'dtype'):
161 return result
163 # If result is floating point and target is integer, scale appropriately
164 if cp.issubdtype(result.dtype, cp.floating) and not cp.issubdtype(target_dtype, cp.floating):
165 # Clip to [0, 1] range and scale to integer range
166 clipped = cp.clip(result, 0, 1)
167 if target_dtype == cp.uint8:
168 return (clipped * 255).astype(target_dtype)
169 elif target_dtype == cp.uint16:
170 return (clipped * 65535).astype(target_dtype)
171 elif target_dtype == cp.uint32:
172 return (clipped * 4294967295).astype(target_dtype)
173 else:
174 # For other integer types, just convert without scaling
175 return result.astype(target_dtype)
177 # Direct conversion for same numeric type families
178 return result.astype(target_dtype)
# GPU frameworks imported lazily to prevent thread explosion
# These will be imported only when actually needed by functions
_gpu_frameworks_cache = {}


def _lazy_import_framework(module_name, display_name):
    """Import ``module_name`` on first request and cache the outcome.

    The result of ``optional_import`` is stored in ``_gpu_frameworks_cache``
    whatever it is, so a framework is only attempted once per process.

    Args:
        module_name: Importable module name, also used as the cache key.
        display_name: Human-readable framework name used in the debug log.

    Returns:
        Whatever ``optional_import`` returned for ``module_name``
        (presumably ``None`` when the framework is missing; confirm against
        ``openhcs.core.utils.optional_import``).
    """
    if module_name not in _gpu_frameworks_cache:
        module = optional_import(module_name)
        _gpu_frameworks_cache[module_name] = module
        if module is not None:
            # Lazy %-args: the message is only formatted if DEBUG is enabled.
            logger.debug(
                "🔧 Lazy imported %s in thread %s",
                display_name,
                threading.current_thread().name,
            )
    return _gpu_frameworks_cache[module_name]


def _get_cupy():
    """Lazy import CuPy only when needed."""
    return _lazy_import_framework("cupy", "CuPy")


def _get_torch():
    """Lazy import PyTorch only when needed."""
    return _lazy_import_framework("torch", "PyTorch")


def _get_tensorflow():
    """Lazy import TensorFlow only when needed."""
    return _lazy_import_framework("tensorflow", "TensorFlow")


def _get_jax():
    """Lazy import JAX only when needed."""
    return _lazy_import_framework("jax", "JAX")
# Thread-local storage for GPU streams and contexts
_thread_gpu_contexts = threading.local()


class ThreadGPUContext:
    """Holder for at most one CuPy stream and one PyTorch stream.

    Streams are created on first request and then reused, so repeated calls
    on the same instance never create additional streams.
    """

    def __init__(self):
        # Name of the thread that constructed this context; used in log lines.
        self._thread_name = threading.current_thread().name
        self._torch_stream = None
        self._cupy_stream = None

    def get_cupy_stream(self):
        """Return this context's CuPy stream, creating it on first use.

        Returns:
            The cached ``cp.cuda.Stream``, or ``None`` when CuPy is not
            available or exposes no ``cuda`` attribute.
        """
        if self._cupy_stream is not None:
            return self._cupy_stream

        cupy_module = _get_cupy()
        if cupy_module is None or not hasattr(cupy_module, 'cuda'):
            return self._cupy_stream

        self._cupy_stream = cupy_module.cuda.Stream()
        logger.debug(f"🔧 Created CuPy stream for thread {self._thread_name}")
        return self._cupy_stream

    def get_torch_stream(self):
        """Return this context's PyTorch stream, creating it on first use.

        Returns:
            The cached ``torch.cuda.Stream``, or ``None`` when PyTorch is not
            available or reports that CUDA is unavailable.
        """
        if self._torch_stream is not None:
            return self._torch_stream

        torch_module = _get_torch()
        cuda_ready = (
            torch_module is not None
            and hasattr(torch_module, 'cuda')
            and torch_module.cuda.is_available()
        )
        if cuda_ready:
            self._torch_stream = torch_module.cuda.Stream()
            logger.debug(f"🔧 Created PyTorch stream for thread {self._thread_name}")
        return self._torch_stream

    def cleanup(self):
        """Drop the references to any streams this context created."""
        for attribute, framework in (("_cupy_stream", "CuPy"), ("_torch_stream", "PyTorch")):
            if getattr(self, attribute) is not None:
                logger.debug(f"🔧 Cleaning up {framework} stream for thread {self._thread_name}")
                setattr(self, attribute, None)
def get_thread_gpu_context() -> ThreadGPUContext:
    """Get the unified GPU context for the current thread.

    The context is stored on module-level thread-local storage, so each
    thread lazily gets its own ``ThreadGPUContext`` on first call and the
    same instance on every later call.

    Returns:
        The calling thread's ``ThreadGPUContext``.
    """
    if hasattr(_thread_gpu_contexts, 'gpu_context'):
        return _thread_gpu_contexts.gpu_context

    _thread_gpu_contexts.gpu_context = ThreadGPUContext()

    def cleanup_on_thread_exit():
        # Reads thread-local storage, so it cleans up the context of whichever
        # thread invokes it.
        if hasattr(_thread_gpu_contexts, 'gpu_context'):
            _thread_gpu_contexts.gpu_context.cleanup()

    # Record the cleanup callback on the thread object.
    # NOTE(review): nothing visible in this module invokes `_cleanup_funcs`,
    # and it is not an attribute the standard `threading` module runs on
    # thread exit - confirm that some caller actually executes these
    # callbacks, otherwise cleanup() is never triggered.
    current_thread = threading.current_thread()
    if hasattr(current_thread, '_cleanup_funcs'):
        current_thread._cleanup_funcs.append(cleanup_on_thread_exit)
    else:
        current_thread._cleanup_funcs = [cleanup_on_thread_exit]

    return _thread_gpu_contexts.gpu_context
def memory_types(*, input_type: str, output_type: str, contract: Optional['ProcessingContract'] = None) -> Callable[[F], F]:
    """
    Build a decorator that stamps explicit memory types onto a function.

    This enforces Clause 106-A (Declared Memory Types): both the input and the
    output memory type must be given, and both must be supported.

    Args:
        input_type: The memory type for the function's input (e.g., "numpy", "cupy")
        output_type: The memory type for the function's output (e.g., "numpy", "cupy")
        contract: Optional processing contract declaration (defaults to PURE_3D)

    Returns:
        A decorator that sets the memory type attributes and returns the
        function itself (no wrapper is created)

    Raises:
        ValueError: If input_type or output_type is empty or not a supported memory type
    """
    declared = (("input_type", input_type), ("output_type", output_type))

    # 🔒 Clause 88 — No Inferred Capabilities
    # Validate memory types at decoration time, not runtime.
    # Both presence checks run before either support check.
    for label, value in declared:
        if not value:
            raise ValueError(
                f"Clause 106-A Violation: {label} must be explicitly declared. "
                "No default or inferred memory types are allowed."
            )

    # Validate that memory types are supported
    for label, value in declared:
        if value not in VALID_MEMORY_TYPES:
            raise ValueError(
                f"Clause 106-A Violation: {label} '{value}' is not supported. "
                f"Supported types are: {', '.join(sorted(VALID_MEMORY_TYPES))}"
            )

    def decorator(func: F) -> F:
        """
        Set the memory type attributes and processing contract on ``func``.

        Args:
            func: The function to decorate

        Returns:
            The same function object with the attributes set

        Raises:
            ValueError: If the function already has different memory type attributes
        """
        # 🔒 Clause 66 — Immutability
        # An existing declaration may be repeated but never changed.
        requested = (("input_memory_type", input_type), ("output_memory_type", output_type))
        for attribute, new_value in requested:
            if hasattr(func, attribute) and getattr(func, attribute) != new_value:
                raise ValueError(
                    f"Clause 66 Violation: Function '{func.__name__}' already has {attribute} "
                    f"'{getattr(func, attribute)}', cannot change to '{new_value}'."
                )

        # 🔒 Clause 106-A.2 — Canonical Memory Type Attributes
        func.input_memory_type = input_type
        func.output_memory_type = output_type

        # Set processing contract; the default is imported here, at
        # decoration time, rather than at module import.
        if contract is not None:
            func.__processing_contract__ = contract
        else:
            from openhcs.processing.backends.lib_registry.unified_registry import ProcessingContract
            func.__processing_contract__ = ProcessingContract.PURE_3D

        # Return the function unchanged (no wrapper)
        return func

    return decorator
def numpy(
    func: Optional[F] = None,
    *,
    input_type: str = "numpy",
    output_type: str = "numpy",
    contract: Optional['ProcessingContract'] = None
) -> Any:
    """
    Declare a function as operating on numpy arrays.

    Convenience form of memory_types with numpy defaults; the declared
    function is additionally wrapped by the numpy dtype-preserving wrapper.
    Usable both bare (``@numpy``) and with arguments (``@numpy(...)``).

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "numpy")
        output_type: The memory type for the function's output (default: "numpy")
        contract: Optional processing contract declaration (defaults to PURE_3D)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator_with_dtype_preservation(target: F) -> F:
        # Declare memory types and contract first, then layer the dtype
        # preservation wrapper on top of the declared function.
        declare = memory_types(input_type=input_type, output_type=output_type, contract=contract)
        declared = declare(target)
        return _create_numpy_dtype_preserving_wrapper(declared, declared.__name__)

    # Handle both @numpy and @numpy(input_type=..., output_type=...) forms
    if func is not None:
        return decorator_with_dtype_preservation(func)
    return decorator_with_dtype_preservation
def cupy(func: Optional[F] = None, *, input_type: str = "cupy", output_type: str = "cupy", oom_recovery: bool = True, contract: Optional['ProcessingContract'] = None) -> Any:
    """
    Declare a function as operating on cupy arrays.

    Calls to the decorated function run inside the calling thread's CuPy
    stream (one stream per thread context, created on first use and reused)
    whenever CuPy with a ``cuda`` attribute is importable.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "cupy")
        output_type: The memory type for the function's output (default: "cupy")
        oom_recovery: Enable automatic OOM recovery (default: True)
        contract: Optional processing contract declaration (defaults to PURE_3D)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(target: F) -> F:
        # Declare memory types and contract, then add dtype preservation.
        declared = memory_types(input_type=input_type, output_type=output_type, contract=contract)(target)
        inner = _create_cupy_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(inner)
        def wrapper(*args, **kwargs):
            # CuPy is resolved lazily, on each call.
            cp = _get_cupy()
            if cp is None or not hasattr(cp, 'cuda'):
                # CuPy not available, execute without stream
                return inner(*args, **kwargs)

            cupy_stream = get_thread_gpu_context().get_cupy_stream()

            def execute_with_stream():
                if cupy_stream is None:
                    # No stream could be created, execute without one
                    return inner(*args, **kwargs)
                with cupy_stream:
                    return inner(*args, **kwargs)

            if not oom_recovery:
                return execute_with_stream()
            return _execute_with_oom_recovery(execute_with_stream, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = inner.input_memory_type
        wrapper.output_memory_type = inner.output_memory_type
        return wrapper

    # Handle both @cupy and @cupy(input_type=..., output_type=...) forms
    if func is not None:
        return decorator(func)
    return decorator
def torch(
    func: Optional[F] = None,
    *,
    input_type: str = "torch",
    output_type: str = "torch",
    oom_recovery: bool = True,
    contract: Optional['ProcessingContract'] = None
) -> Any:
    """
    Declare a function as operating on torch tensors.

    Calls to the decorated function run inside the calling thread's PyTorch
    CUDA stream (one stream per thread context, created on first use and
    reused) whenever PyTorch is importable and reports CUDA as available.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "torch")
        output_type: The memory type for the function's output (default: "torch")
        oom_recovery: Enable automatic OOM recovery (default: True)
        contract: Optional processing contract declaration (defaults to PURE_3D)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(target: F) -> F:
        # Declare memory types and contract, then add dtype preservation.
        declared = memory_types(input_type=input_type, output_type=output_type, contract=contract)(target)
        inner = _create_torch_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(inner)
        def wrapper(*args, **kwargs):
            # PyTorch is resolved lazily, on each call.
            torch_module = _get_torch()
            cuda_ready = (
                torch_module is not None
                and hasattr(torch_module, 'cuda')
                and torch_module.cuda.is_available()
            )
            if not cuda_ready:
                # PyTorch not available or CUDA not available, execute without stream
                return inner(*args, **kwargs)

            torch_stream = get_thread_gpu_context().get_torch_stream()

            def execute_with_stream():
                if torch_stream is None:
                    # No stream could be created, execute without one
                    return inner(*args, **kwargs)
                with torch_module.cuda.stream(torch_stream):
                    return inner(*args, **kwargs)

            if not oom_recovery:
                return execute_with_stream()
            return _execute_with_oom_recovery(execute_with_stream, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = inner.input_memory_type
        wrapper.output_memory_type = inner.output_memory_type
        return wrapper

    # Handle both @torch and @torch(input_type=..., output_type=...) forms
    if func is not None:
        return decorator(func)
    return decorator
def tensorflow(
    func: Optional[F] = None,
    *,
    input_type: str = "tensorflow",
    output_type: str = "tensorflow",
    oom_recovery: bool = True,
    contract: Optional['ProcessingContract'] = None
) -> Any:
    """
    Declare a function as operating on tensorflow tensors.

    When TensorFlow is importable and lists at least one physical GPU, calls
    to the decorated function run inside a ``tf.device('/GPU:0')`` context.
    No explicit stream is managed for TensorFlow here.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "tensorflow")
        output_type: The memory type for the function's output (default: "tensorflow")
        oom_recovery: Enable automatic OOM recovery (default: True)
        contract: Optional processing contract declaration (defaults to PURE_3D)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(target: F) -> F:
        # Declare memory types and contract, then add dtype preservation.
        declared = memory_types(input_type=input_type, output_type=output_type, contract=contract)(target)
        inner = _create_tensorflow_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(inner)
        def wrapper(*args, **kwargs):
            # TensorFlow is resolved lazily, on each call.
            tf = _get_tensorflow()
            if tf is None or not tf.config.list_physical_devices('GPU'):
                # TensorFlow not available or GPU not available, execute without device context
                return inner(*args, **kwargs)

            def execute_with_device():
                # Always targets the first GPU.
                with tf.device('/GPU:0'):
                    return inner(*args, **kwargs)

            if not oom_recovery:
                return execute_with_device()
            return _execute_with_oom_recovery(execute_with_device, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = inner.input_memory_type
        wrapper.output_memory_type = inner.output_memory_type
        return wrapper

    # Handle both @tensorflow and @tensorflow(input_type=..., output_type=...) forms
    if func is not None:
        return decorator(func)
    return decorator
def jax(
    func: Optional[F] = None,
    *,
    input_type: str = "jax",
    output_type: str = "jax",
    oom_recovery: bool = True,
    contract: Optional['ProcessingContract'] = None
) -> Any:
    """
    Declare a function as operating on JAX arrays.

    When JAX is importable and reports at least one device whose platform is
    ``'gpu'``, calls to the decorated function run with the first such device
    set as the default device. No explicit stream is managed for JAX here.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "jax")
        output_type: The memory type for the function's output (default: "jax")
        oom_recovery: Enable automatic OOM recovery (default: True)
        contract: Optional processing contract declaration (defaults to PURE_3D)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(target: F) -> F:
        # Declare memory types and contract, then add dtype preservation.
        declared = memory_types(input_type=input_type, output_type=output_type, contract=contract)(target)
        inner = _create_jax_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(inner)
        def wrapper(*args, **kwargs):
            # JAX is resolved lazily, on each call.
            jax_module = _get_jax()
            if jax_module is None:
                # JAX not available, execute without device placement
                return inner(*args, **kwargs)

            gpu_devices = [device for device in jax_module.devices() if device.platform == 'gpu']
            if not gpu_devices:
                # No GPU devices available, execute without device placement
                return inner(*args, **kwargs)

            def execute_with_device():
                # Always targets the first GPU device reported.
                with jax_module.default_device(gpu_devices[0]):
                    return inner(*args, **kwargs)

            if not oom_recovery:
                return execute_with_device()
            return _execute_with_oom_recovery(execute_with_device, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = inner.input_memory_type
        wrapper.output_memory_type = inner.output_memory_type
        return wrapper

    # Handle both @jax and @jax(input_type=..., output_type=...) forms
    if func is not None:
        return decorator(func)
    return decorator
def pyclesperanto(
    func: Optional[F] = None,
    *,
    input_type: str = "pyclesperanto",
    output_type: str = "pyclesperanto",
    oom_recovery: bool = True,
    contract: Optional['ProcessingContract'] = None
) -> Any:
    """
    Declare a function as operating on pyclesperanto GPU arrays.

    The decorated function gets the pyclesperanto dtype-preserving wrapper
    and, when enabled, is executed through the OOM recovery helper. No stream
    or device context is set up for pyclesperanto.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "pyclesperanto")
        output_type: The memory type for the function's output (default: "pyclesperanto")
        oom_recovery: Enable automatic OOM recovery (default: True)
        contract: Optional processing contract declaration (defaults to PURE_3D)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(target: F) -> F:
        # Declare memory types and contract, then add dtype preservation.
        declared = memory_types(input_type=input_type, output_type=output_type, contract=contract)(target)
        inner = _create_pyclesperanto_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(inner)
        def wrapper(*args, **kwargs):
            if not oom_recovery:
                return inner(*args, **kwargs)
            return _execute_with_oom_recovery(lambda: inner(*args, **kwargs), input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = inner.input_memory_type
        wrapper.output_memory_type = inner.output_memory_type

        # Make wrapper pickleable by preserving original function identity
        wrapper.__module__ = getattr(inner, '__module__', wrapper.__module__)
        wrapper.__qualname__ = getattr(inner, '__qualname__', wrapper.__qualname__)

        # Store reference to original function for pickle support
        wrapper.__wrapped__ = inner
        return wrapper

    # Handle both @pyclesperanto and @pyclesperanto(input_type=..., output_type=...) forms
    if func is not None:
        return decorator(func)
    return decorator
760# ============================================================================
761# Dtype Preservation Wrapper Functions
762# ============================================================================
764def _create_numpy_dtype_preserving_wrapper(original_func, func_name):
765 """
766 Create a wrapper that preserves input data type and adds slice_by_slice parameter for NumPy functions.
768 Many scikit-image functions return float64 regardless of input type.
769 This wrapper ensures the output has the same dtype as the input and adds
770 a slice_by_slice parameter to avoid cross-slice contamination in 3D arrays.
771 """
772 import numpy as np
773 import inspect
774 from functools import wraps
776 @wraps(original_func)
777 def numpy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
778 # Set default dtype_conversion if not provided and DtypeConversion is available
779 if dtype_conversion is None and DtypeConversion is not None:
780 dtype_conversion = DtypeConversion.PRESERVE_INPUT
782 try:
783 # Store original dtype
784 original_dtype = image.dtype
786 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities
787 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
788 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type
790 # Detect memory type and use proper OpenHCS utilities
791 memory_type = _detect_memory_type(image)
792 gpu_id = 0 # Default GPU ID for slice processing
794 # Unstack 3D array into 2D slices
795 slices_2d = unstack_slices(image, memory_type, gpu_id)
797 # Process each slice and handle special outputs
798 main_outputs = []
799 special_outputs_list = []
801 for slice_2d in slices_2d:
802 slice_result = original_func(slice_2d, *args, **kwargs)
804 # Check if result is a tuple (indicating special outputs)
805 if isinstance(slice_result, tuple):
806 main_outputs.append(slice_result[0]) # First element is main output
807 special_outputs_list.append(slice_result[1:]) # Rest are special outputs
808 else:
809 main_outputs.append(slice_result) # Single output
811 # Stack main outputs back into 3D array
812 result = stack_slices(main_outputs, memory_type, gpu_id)
814 # If we have special outputs, combine them and return tuple
815 if special_outputs_list:
816 # Combine special outputs from all slices
817 combined_special_outputs = []
818 num_special_outputs = len(special_outputs_list[0])
820 for i in range(num_special_outputs):
821 # Collect the i-th special output from all slices
822 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list]
823 combined_special_outputs.append(special_output_values)
825 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...)
826 result = (result, *combined_special_outputs)
827 else:
828 # Call the original function normally
829 result = original_func(image, *args, **kwargs)
831 # Apply dtype conversion based on enum value
832 if hasattr(result, 'dtype') and dtype_conversion is not None:
833 if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
834 # Preserve input dtype
835 if result.dtype != original_dtype:
836 result = _scale_and_convert_numpy(result, original_dtype)
837 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
838 # Return NumPy's native output dtype
839 pass # No conversion needed
840 else:
841 # Force specific dtype
842 target_dtype = dtype_conversion.numpy_dtype
843 if target_dtype is not None:
844 result = _scale_and_convert_numpy(result, target_dtype)
846 return result
847 except Exception as e:
848 logger.error(f"Error in NumPy dtype/slice preserving wrapper for {func_name}: {e}")
849 # Return original result on error
850 return original_func(image, *args, **kwargs)
852 # Update function signature to include new parameters
853 try:
854 original_sig = inspect.signature(original_func)
855 new_params = list(original_sig.parameters.values())
857 # Check if slice_by_slice parameter already exists
858 param_names = [p.name for p in new_params]
859 # Add dtype_conversion parameter first (before slice_by_slice)
860 if 'dtype_conversion' not in param_names: 860 ↛ 869line 860 didn't jump to line 869 because the condition on line 860 was always true
861 dtype_param = inspect.Parameter(
862 'dtype_conversion',
863 inspect.Parameter.KEYWORD_ONLY,
864 default=DtypeConversion.PRESERVE_INPUT,
865 annotation=DtypeConversion
866 )
867 new_params.append(dtype_param)
869 if 'slice_by_slice' not in param_names: 869 ↛ 880line 869 didn't jump to line 880 because the condition on line 869 was always true
870 # Add slice_by_slice parameter as keyword-only (after dtype_conversion)
871 slice_param = inspect.Parameter(
872 'slice_by_slice',
873 inspect.Parameter.KEYWORD_ONLY,
874 default=False,
875 annotation=bool
876 )
877 new_params.append(slice_param)
879 # Create new signature and override the @wraps signature
880 new_sig = original_sig.replace(parameters=new_params)
881 numpy_dtype_and_slice_preserving_wrapper.__signature__ = new_sig
883 # Set type annotations manually for get_type_hints() compatibility
884 numpy_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy()
885 numpy_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion
886 numpy_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool
888 except Exception:
889 # If signature modification fails, continue without it
890 pass
892 # Update docstring to mention slice_by_slice parameter
893 original_doc = numpy_dtype_and_slice_preserving_wrapper.__doc__ or ""
894 additional_doc = """
896 Additional OpenHCS Parameters
897 -----------------------------
898 slice_by_slice : bool, optional (default: False)
899 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
900 If False, use original 3D behavior. Recommended for edge detection functions
901 on stitched microscopy data to prevent artifacts at field boundaries.
903 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
904 Controls output data type conversion:
906 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
907 - NATIVE_OUTPUT: Use NumPy's native output dtype
908 - UINT8: Force 8-bit unsigned integer (0-255 range)
909 - UINT16: Force 16-bit unsigned integer (microscopy standard)
910 - INT16: Force 16-bit signed integer
911 - INT32: Force 32-bit signed integer
912 - FLOAT32: Force 32-bit float (GPU performance)
913 - FLOAT64: Force 64-bit float (maximum precision)
914 """
915 numpy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc
917 return numpy_dtype_and_slice_preserving_wrapper
def _create_cupy_dtype_preserving_wrapper(original_func, func_name):
    """
    Wrap a CuPy function with ``dtype_conversion`` and ``slice_by_slice`` keyword arguments.

    The returned wrapper:

    * optionally calls ``original_func`` once per 2D slice of a 3D input and restacks the
      main outputs (extra tuple elements are collected into one list per output position);
    * converts a result that exposes ``dtype`` according to ``dtype_conversion`` using
      ``_scale_and_convert_cupy``;
    * advertises the two extra keyword-only parameters through ``__signature__``,
      ``__annotations__`` and an appended docstring section.

    Args:
        original_func: Callable that takes the image as its first positional argument.
        func_name: Name used in log messages only.

    Returns:
        The wrapping function.
    """
    import inspect
    from functools import wraps

    # Marks "no result produced yet". None cannot play this role because the
    # wrapped function may legitimately return None.
    no_result = object()

    @wraps(original_func)
    def cupy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        unconverted = no_result
        try:
            cupy = optional_import("cupy")
            if cupy is None:
                # CuPy is not importable: behave like the unwrapped function.
                return original_func(image, *args, **kwargs)

            original_dtype = image.dtype

            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                memory_type = _detect_memory_type(image)
                gpu_id = image.device.id if hasattr(image, 'device') else 0

                main_outputs = []
                special_outputs_list = []
                for slice_2d in unstack_slices(image, memory_type, gpu_id):
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs.
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                result = stack_slices(main_outputs, memory_type, gpu_id)

                if special_outputs_list:
                    # One list per special-output position, holding that output for every slice.
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(len(special_outputs_list[0]))
                    ]
                    result = (result, *combined_special_outputs)
            else:
                result = original_func(image, *args, **kwargs)

            # From here on only post-processing can fail; remember what we have so the
            # error path can return it instead of running the function a second time.
            unconverted = result

            # NOTE(review): tuple results have no ``dtype`` attribute, so a main output
            # returned together with special outputs is never converted - confirm intended.
            if hasattr(result, 'dtype'):
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = _scale_and_convert_cupy(result, original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Keep whatever dtype the function produced
                else:
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        result = _scale_and_convert_cupy(result, target_dtype)

            return result
        except Exception as e:
            logger.error(f"Error in CuPy dtype/slice preserving wrapper for {func_name}: {e}")
            if unconverted is not no_result:
                # The function itself succeeded; only the conversion failed.
                return unconverted
            # Fall back to the plain, unwrapped call.
            return original_func(image, *args, **kwargs)

    # Advertise the extra keyword-only parameters on the wrapper's public signature.
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        extra_params = []
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))

        # Keyword-only parameters must come before a trailing **kwargs, otherwise
        # Signature.replace() rejects the parameter list with ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        cupy_dtype_and_slice_preserving_wrapper.__signature__ = original_sig.replace(parameters=new_params)

        # Set annotations on a copy so the wrapped function's own dict is left untouched.
        annotations = getattr(original_func, '__annotations__', {}).copy()
        annotations['dtype_conversion'] = DtypeConversion
        annotations['slice_by_slice'] = bool
        cupy_dtype_and_slice_preserving_wrapper.__annotations__ = annotations
    except Exception as exc:
        # Best effort: the wrapper still works without the enriched signature.
        logger.debug("Could not extend signature of %s: %s", func_name, exc)

    # Append documentation for the added parameters.
    original_doc = cupy_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use CuPy's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    cupy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return cupy_dtype_and_slice_preserving_wrapper


def _create_torch_dtype_preserving_wrapper(original_func, func_name):
    """
    Wrap a PyTorch function with ``dtype_conversion`` and ``slice_by_slice`` keyword arguments.

    The returned wrapper:

    * optionally calls ``original_func`` once per 2D slice of a 3D input and restacks the
      main outputs (extra tuple elements are collected into one list per output position);
    * casts a result that exposes ``dtype`` and ``shape`` with ``Tensor.to`` according to
      ``dtype_conversion`` (a forced UINT16 is mapped to ``torch.int32``);
    * advertises the two extra keyword-only parameters through ``__signature__``,
      ``__annotations__`` and an appended docstring section.

    Args:
        original_func: Callable that takes the image as its first positional argument.
        func_name: Name used in log messages only.

    Returns:
        The wrapping function.
    """
    import inspect
    from functools import wraps

    # Marks "no result produced yet". None cannot play this role because the
    # wrapped function may legitimately return None.
    no_result = object()

    @wraps(original_func)
    def torch_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        unconverted = no_result
        try:
            torch = optional_import("torch")
            if torch is None:
                # PyTorch is not importable: behave like the unwrapped function.
                return original_func(image, *args, **kwargs)

            original_dtype = image.dtype if hasattr(image, 'dtype') else None

            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                memory_type = _detect_memory_type(image)
                # Only CUDA tensors carry a meaningful device index; everything else uses 0.
                gpu_id = image.device.index if hasattr(image, 'device') and image.device.type == 'cuda' else 0

                main_outputs = []
                special_outputs_list = []
                for slice_2d in unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id):
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs.
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id)

                if special_outputs_list:
                    # One list per special-output position, holding that output for every slice.
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(len(special_outputs_list[0]))
                    ]
                    result = (result, *combined_special_outputs)
            else:
                result = original_func(image, *args, **kwargs)

            # From here on only post-processing can fail; remember what we have so the
            # error path can return it instead of running the function a second time.
            unconverted = result

            # NOTE(review): tuple results have no ``dtype`` attribute, so a main output
            # returned together with special outputs is never converted - confirm intended.
            if hasattr(result, 'dtype') and hasattr(result, 'shape') and original_dtype is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = result.to(original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Keep whatever dtype the function produced
                else:
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        import numpy as np
                        numpy_to_torch = {
                            np.uint8: torch.uint8,
                            np.uint16: torch.int32,  # uint16 is mapped to int32, as the appended docstring states
                            np.int16: torch.int16,
                            np.int32: torch.int32,
                            np.float32: torch.float32,
                            np.float64: torch.float64,
                        }
                        torch_dtype = numpy_to_torch.get(target_dtype)
                        if torch_dtype is not None:
                            result = result.to(torch_dtype)

            return result

        except Exception as e:
            logger.error(f"Error in PyTorch dtype/slice preserving wrapper for {func_name}: {e}")
            if unconverted is not no_result:
                # The function itself succeeded; only the conversion failed.
                return unconverted
            # Fall back to the plain, unwrapped call.
            return original_func(image, *args, **kwargs)

    # Advertise the extra keyword-only parameters on the wrapper's public signature.
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        extra_params = []
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))

        # Keyword-only parameters must come before a trailing **kwargs, otherwise
        # Signature.replace() rejects the parameter list with ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        torch_dtype_and_slice_preserving_wrapper.__signature__ = original_sig.replace(parameters=new_params)

        # Set annotations on a copy so the wrapped function's own dict is left untouched.
        annotations = getattr(original_func, '__annotations__', {}).copy()
        annotations['dtype_conversion'] = DtypeConversion
        annotations['slice_by_slice'] = bool
        torch_dtype_and_slice_preserving_wrapper.__annotations__ = annotations
    except Exception as exc:
        # Best effort: the wrapper still works without the enriched signature.
        logger.debug("Could not extend signature of %s: %s", func_name, exc)

    # Append documentation for the added parameters.
    original_doc = torch_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use PyTorch's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (mapped to int32 in PyTorch)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    torch_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return torch_dtype_and_slice_preserving_wrapper


def _create_tensorflow_dtype_preserving_wrapper(original_func, func_name):
    """
    Wrap a TensorFlow function with ``dtype_conversion`` and ``slice_by_slice`` keyword arguments.

    The returned wrapper:

    * optionally calls ``original_func`` once per 2D slice of a 3D input and restacks the
      main outputs (extra tuple elements are collected into one list per output position);
    * casts a result that exposes ``dtype`` and ``shape`` with ``tf.cast`` according to
      ``dtype_conversion``;
    * advertises the two extra keyword-only parameters through ``__signature__``,
      ``__annotations__`` and an appended docstring section.

    Args:
        original_func: Callable that takes the image as its first positional argument.
        func_name: Name used in log messages only.

    Returns:
        The wrapping function.
    """
    import inspect
    from functools import wraps

    # Marks "no result produced yet". None cannot play this role because the
    # wrapped function may legitimately return None.
    no_result = object()

    @wraps(original_func)
    def tensorflow_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        unconverted = no_result
        try:
            tf = optional_import("tensorflow")
            if tf is None:
                # TensorFlow is not importable: behave like the unwrapped function.
                return original_func(image, *args, **kwargs)

            original_dtype = image.dtype if hasattr(image, 'dtype') else None

            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                memory_type = _detect_memory_type(image)
                gpu_id = 0  # TensorFlow manages GPU placement internally

                main_outputs = []
                special_outputs_list = []
                for slice_2d in unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id):
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs.
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id)

                if special_outputs_list:
                    # One list per special-output position, holding that output for every slice.
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(len(special_outputs_list[0]))
                    ]
                    result = (result, *combined_special_outputs)
            else:
                result = original_func(image, *args, **kwargs)

            # From here on only post-processing can fail; remember what we have so the
            # error path can return it instead of running the function a second time.
            unconverted = result

            # NOTE(review): tuple results have no ``dtype`` attribute, so a main output
            # returned together with special outputs is never converted - confirm intended.
            if hasattr(result, 'dtype') and hasattr(result, 'shape') and original_dtype is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = tf.cast(result, original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Keep whatever dtype the function produced
                else:
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        import numpy as np
                        numpy_to_tf = {
                            np.uint8: tf.uint8,
                            np.uint16: tf.uint16,
                            np.int16: tf.int16,
                            np.int32: tf.int32,
                            np.float32: tf.float32,
                            np.float64: tf.float64,
                        }
                        tf_dtype = numpy_to_tf.get(target_dtype)
                        if tf_dtype is not None:
                            result = tf.cast(result, tf_dtype)

            return result

        except Exception as e:
            logger.error(f"Error in TensorFlow dtype/slice preserving wrapper for {func_name}: {e}")
            if unconverted is not no_result:
                # The function itself succeeded; only the conversion failed.
                return unconverted
            # Fall back to the plain, unwrapped call.
            return original_func(image, *args, **kwargs)

    # Advertise the extra keyword-only parameters on the wrapper's public signature,
    # in the same order as the wrapper's own definition (dtype_conversion, slice_by_slice).
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        extra_params = []
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))

        # Keyword-only parameters must come before a trailing **kwargs, otherwise
        # Signature.replace() rejects the parameter list with ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        tensorflow_dtype_and_slice_preserving_wrapper.__signature__ = original_sig.replace(parameters=new_params)

        # Set annotations on a copy so the wrapped function's own dict is left untouched.
        annotations = getattr(original_func, '__annotations__', {}).copy()
        annotations['slice_by_slice'] = bool
        annotations['dtype_conversion'] = DtypeConversion
        tensorflow_dtype_and_slice_preserving_wrapper.__annotations__ = annotations
    except Exception as exc:
        # Best effort: the wrapper still works without the enriched signature.
        logger.debug("Could not extend signature of %s: %s", func_name, exc)

    # Append documentation for the added parameters.
    original_doc = tensorflow_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use TensorFlow's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    tensorflow_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return tensorflow_dtype_and_slice_preserving_wrapper


def _create_jax_dtype_preserving_wrapper(original_func, func_name):
    """
    Wrap a JAX function with ``dtype_conversion`` and ``slice_by_slice`` keyword arguments.

    The returned wrapper:

    * optionally calls ``original_func`` once per 2D slice of a 3D input and restacks the
      main outputs (extra tuple elements are collected into one list per output position);
    * casts a result that exposes ``dtype`` and ``shape`` with ``astype`` according to
      ``dtype_conversion``;
    * advertises the two extra keyword-only parameters through ``__signature__``,
      ``__annotations__`` and an appended docstring section.

    Args:
        original_func: Callable that takes the image as its first positional argument.
        func_name: Name used in log messages only.

    Returns:
        The wrapping function.
    """
    import inspect
    from functools import wraps

    # Marks "no result produced yet". None cannot play this role because the
    # wrapped function may legitimately return None.
    no_result = object()

    @wraps(original_func)
    def jax_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        unconverted = no_result
        try:
            jax = optional_import("jax")
            jnp = optional_import("jax.numpy") if jax is not None else None
            if jax is None or jnp is None:
                # JAX is not importable: behave like the unwrapped function.
                return original_func(image, *args, **kwargs)

            original_dtype = image.dtype if hasattr(image, 'dtype') else None

            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                memory_type = _detect_memory_type(image)
                gpu_id = 0  # JAX manages GPU placement internally

                main_outputs = []
                special_outputs_list = []
                for slice_2d in unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id):
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs.
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id)

                if special_outputs_list:
                    # One list per special-output position, holding that output for every slice.
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(len(special_outputs_list[0]))
                    ]
                    result = (result, *combined_special_outputs)
            else:
                result = original_func(image, *args, **kwargs)

            # From here on only post-processing can fail; remember what we have so the
            # error path can return it instead of running the function a second time.
            unconverted = result

            # NOTE(review): tuple results have no ``dtype`` attribute, so a main output
            # returned together with special outputs is never converted - confirm intended.
            if hasattr(result, 'dtype') and hasattr(result, 'shape') and original_dtype is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = result.astype(original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Keep whatever dtype the function produced
                else:
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        # The NumPy dtype is handed to astype() unchanged.
                        result = result.astype(target_dtype)

            return result

        except Exception as e:
            logger.error(f"Error in JAX dtype/slice preserving wrapper for {func_name}: {e}")
            if unconverted is not no_result:
                # The function itself succeeded; only the conversion failed.
                return unconverted
            # Fall back to the plain, unwrapped call.
            return original_func(image, *args, **kwargs)

    # Advertise the extra keyword-only parameters on the wrapper's public signature,
    # in the same order as the wrapper's own definition (dtype_conversion, slice_by_slice).
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        extra_params = []
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))

        # Keyword-only parameters must come before a trailing **kwargs, otherwise
        # Signature.replace() rejects the parameter list with ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        jax_dtype_and_slice_preserving_wrapper.__signature__ = original_sig.replace(parameters=new_params)

        # Set annotations on a copy so the wrapped function's own dict is left untouched.
        annotations = getattr(original_func, '__annotations__', {}).copy()
        annotations['slice_by_slice'] = bool
        annotations['dtype_conversion'] = DtypeConversion
        jax_dtype_and_slice_preserving_wrapper.__annotations__ = annotations
    except Exception as exc:
        # Best effort: the wrapper still works without the enriched signature.
        logger.debug("Could not extend signature of %s: %s", func_name, exc)

    # Append documentation for the added parameters.
    original_doc = jax_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use JAX's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    jax_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return jax_dtype_and_slice_preserving_wrapper


def _create_pyclesperanto_dtype_preserving_wrapper(original_func, func_name):
    """
    Wrap a pyclesperanto function with OpenHCS slice and dtype handling.

    The returned wrapper forwards to ``original_func`` and accepts two extra
    keyword-only options, ``slice_by_slice`` and ``dtype_conversion``, which
    are consumed by the wrapper and never passed on to ``original_func``.

    Wrapper behaviour:
    1. If ``optional_import("pyclesperanto")`` returns None, the call is
       forwarded to ``original_func`` unchanged.
    2. With ``slice_by_slice=True`` and an input whose ``ndim`` is 3, the
       input is unstacked, ``original_func`` is applied to every slice and the
       main outputs are stacked again. When slices return tuples, the extra
       elements are collected into one list per output position.
    3. A result whose ``ndim`` is 2 is expanded to a single-slice 3D array
       (best effort; the 2D result is kept if the expansion fails).
    4. A result exposing both ``dtype`` and ``shape`` is converted according
       to ``dtype_conversion``; any other result (tuples included) is
       returned as-is.
    5. If anything above raises, the error is logged and ``original_func`` is
       called again with the original arguments.

    The wrapper's ``__signature__``, ``__annotations__`` and ``__doc__`` are
    extended to advertise the two extra options.

    Args:
        original_func: Callable to wrap; its first positional argument is the
            image.
        func_name: Name used in the error log message.

    Returns:
        The wrapping function.
    """
    import inspect
    from functools import wraps

    @wraps(original_func)
    def pyclesperanto_dtype_and_slice_preserving_wrapper(image_3d, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # None stands for "use the default conversion mode".
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            cle = optional_import("pyclesperanto")
            if cle is None:
                return original_func(image_3d, *args, **kwargs)

            # Remembered so the result can be converted back afterwards.
            original_dtype = image_3d.dtype

            if slice_by_slice and hasattr(image_3d, 'ndim') and image_3d.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                memory_type = _detect_memory_type(image_3d)
                gpu_id = 0  # pyclesperanto manages GPU internally

                main_outputs = []
                special_outputs_list = []
                for slice_2d in unstack_slices(image_3d, memory_type, gpu_id):
                    result_slice = original_func(slice_2d, *args, **kwargs)

                    # A tuple means (main output, special output 1, ...).
                    if isinstance(result_slice, tuple):
                        main_outputs.append(result_slice[0])
                        special_outputs_list.append(result_slice[1:])
                    else:
                        main_outputs.append(result_slice)

                # Stack main outputs back into a 3D array.
                result = stack_slices(main_outputs, memory_type, gpu_id)

                if special_outputs_list:
                    # One list per special-output position, holding that
                    # output for every slice that returned a tuple. The
                    # count is taken from the first such slice.
                    num_special_outputs = len(special_outputs_list[0])
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(num_special_outputs)
                    ]
                    # (stacked_main_output, combined_special_output1, ...)
                    result = (result, *combined_special_outputs)
            else:
                # Normal processing of the whole input in one call.
                result = original_func(image_3d, *args, **kwargs)

            # NOTE(review): the extracted source lost its indentation, so
            # whether this check originally sat inside the else branch is not
            # recoverable; the two placements agree unless stack_slices can
            # return a 2D array - confirm.
            if hasattr(result, 'ndim') and result.ndim == 2:
                try:
                    # Concatenate with itself to get (2, Y, X), then keep the
                    # first slice to obtain (1, Y, X).
                    temp_3d = cle.concatenate_along_z(result, result)
                    result = temp_3d[0:1, :, :]
                except Exception:
                    # Deliberate best effort: keep the 2D result for
                    # backward compatibility.
                    pass

            # Only array-like results take part in dtype conversion.
            # NOTE(review): tuple results are returned without converting
            # their main output - confirm this is intended.
            if not (hasattr(result, 'dtype') and hasattr(result, 'shape')):
                return result

            if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                if result.dtype != original_dtype:
                    return _scale_and_convert_pyclesperanto(result, original_dtype)
                return result

            if dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                # Keep whatever dtype the wrapped function produced.
                return result

            # Every remaining mode forces a specific dtype.
            target_dtype = dtype_conversion.numpy_dtype
            if target_dtype is not None and result.dtype != target_dtype:
                return _scale_and_convert_pyclesperanto(result, target_dtype)
            return result

        except Exception as e:
            logger.error(f"Error in pyclesperanto dtype/slice preserving wrapper for {func_name}: {e}")
            # If anything goes wrong, fall back to the original function.
            return original_func(image_3d, *args, **kwargs)

    # Advertise the extra parameters through the signature and annotations.
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        extra_params = []
        if 'slice_by_slice' not in param_names:
            extra_params.append(
                inspect.Parameter(
                    'slice_by_slice',
                    inspect.Parameter.KEYWORD_ONLY,
                    default=False,
                    annotation=bool,
                )
            )
        if DtypeConversion is not None and 'dtype_conversion' not in param_names:
            extra_params.append(
                inspect.Parameter(
                    'dtype_conversion',
                    inspect.Parameter.KEYWORD_ONLY,
                    default=DtypeConversion.PRESERVE_INPUT,
                    annotation=DtypeConversion,
                )
            )

        # Keyword-only parameters must come before a trailing **kwargs.
        # Appending them after it makes Signature.replace() raise ValueError,
        # which would silently discard the whole signature update.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        new_sig = original_sig.replace(parameters=new_params)
        pyclesperanto_dtype_and_slice_preserving_wrapper.__signature__ = new_sig

        # Set type annotations manually for get_type_hints() compatibility.
        pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy()
        pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool
        if DtypeConversion is not None:
            pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion

    except Exception:
        # Deliberate best effort: the wrapper still works when the signature
        # cannot be introspected or rebuilt, it just exposes less metadata.
        pass

    # Extend the docstring so the extra parameters are documented.
    original_doc = pyclesperanto_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use pyclesperanto's native output (often float32)
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    pyclesperanto_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return pyclesperanto_dtype_and_slice_preserving_wrapper
def _scale_and_convert_pyclesperanto(result, target_dtype):
    """
    Scale and convert a pyclesperanto array to ``target_dtype``.

    This is a simplified version of the helper function from
    pyclesperanto_registry.py. The array is pulled to NumPy, converted there
    and pushed back with ``cle.push``.

    Conversion rules:
    - Floating result, non-floating target: values are clipped to [0, 1] and
      then scaled to the full range for uint8, uint16 and uint32 targets.
      Other non-floating targets are clipped but not scaled.
    - Every other combination: a plain ``astype`` cast.

    Args:
        result: Array to convert; must expose ``dtype`` and be accepted by
            ``cle.pull``.
        target_dtype: NumPy dtype (or dtype-like) to convert to.

    Returns:
        The converted array as returned by ``cle.push``. ``result`` itself is
        returned unchanged when pyclesperanto is unavailable or when any step
        of the conversion raises.
    """
    try:
        cle = optional_import("pyclesperanto")
        if cle is None:
            return result

        needs_scaling = (
            np.issubdtype(result.dtype, np.floating)
            and not np.issubdtype(target_dtype, np.floating)
        )

        # The conversion is done on the NumPy side in every case.
        result_np = cle.pull(result)

        if not needs_scaling:
            # Direct conversion for same numeric type families.
            return cle.push(result_np.astype(target_dtype))

        # NOTE(review): this treats float results as normalised to [0, 1];
        # anything outside that range saturates - confirm against callers.
        clipped = np.clip(result_np, 0, 1)
        if target_dtype == np.uint8:
            scaled = (clipped * 255).astype(target_dtype)
        elif target_dtype == np.uint16:
            scaled = (clipped * 65535).astype(target_dtype)
        elif target_dtype == np.uint32:
            # 4294967295 is not representable in float32 (it rounds up to
            # 4294967296), so full-scale values would overflow the uint32
            # cast. Multiply in float64, where the product is exact.
            scaled = (clipped.astype(np.float64) * 4294967295).astype(target_dtype)
        else:
            # Other non-float targets are converted without scaling.
            # NOTE(review): after the clip this leaves only 0 or 1 for
            # signed integer targets - confirm that is intended.
            scaled = clipped.astype(target_dtype)

        # Push back to GPU.
        return cle.push(scaled)

    except Exception:
        # Deliberate best effort: if conversion fails, hand back the
        # original result rather than raising.
        return result