Coverage for openhcs/core/memory/decorators.py: 30.2%
759 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Memory type declaration decorators for OpenHCS.
4This module provides decorators for explicitly declaring the memory interface
5of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting
6memory-type-aware dispatching and orchestration.
8These decorators annotate functions with input_memory_type and output_memory_type
9attributes and provide automatic thread-local CUDA stream management for GPU
10frameworks to enable true parallelization across multiple threads.
11"""
13import functools
14import logging
15import threading
16from typing import Any, Callable, Optional, TypeVar
18from openhcs.constants.constants import VALID_MEMORY_TYPES
19from openhcs.core.utils import optional_import
20from openhcs.core.memory.oom_recovery import _execute_with_oom_recovery
22logger = logging.getLogger(__name__)
24F = TypeVar('F', bound=Callable[..., Any])
26# Dtype conversion enum and utilities for consistent dtype handling across all frameworks
27from enum import Enum
28import numpy as np
class DtypeConversion(Enum):
    """Output dtype policy accepted by the memory-type function wrappers.

    Two members describe a policy (keep the input dtype, or keep whatever the
    framework returned); the remaining members each force one concrete dtype.
    """

    PRESERVE_INPUT = "preserve"  # Keep input dtype (default)
    NATIVE_OUTPUT = "native"  # Use framework's native output
    UINT8 = "uint8"  # Force uint8 (0-255 range)
    UINT16 = "uint16"  # Force uint16 (microscopy standard)
    INT16 = "int16"  # Force int16 (signed microscopy data)
    INT32 = "int32"  # Force int32 (large integer values)
    FLOAT32 = "float32"  # Force float32 (GPU performance)
    FLOAT64 = "float64"  # Force float64 (maximum precision)

    @property
    def numpy_dtype(self):
        """Return the numpy scalar type this member forces, or ``None``.

        ``PRESERVE_INPUT`` and ``NATIVE_OUTPUT`` do not name a concrete dtype,
        so they yield ``None``.
        """
        # Keyed by member value; every member has a unique value, so this is
        # equivalent to a lookup keyed by the member itself.
        forced_types = {
            "uint8": np.uint8,
            "uint16": np.uint16,
            "int16": np.int16,
            "int32": np.int32,
            "float32": np.float32,
            "float64": np.float64,
        }
        return forced_types.get(self.value)
def _scale_and_convert_numpy(result, target_dtype):
    """Convert ``result`` to ``target_dtype``, rescaling floats for integer targets.

    Floating-point data headed for one of the supported integer dtypes
    (uint8/16/32, int8/16/32) is min-max normalised to [0, 1] and stretched
    over the full value range of the target dtype. Everything else is a plain
    ``astype`` cast.

    Args:
        result: Array-like with ``dtype``/``astype``; any object without a
            ``dtype`` attribute is returned untouched.
        target_dtype: numpy scalar type or ``np.dtype`` to convert to.

    Returns:
        The converted array, or ``result`` itself when it has no ``dtype``.
    """
    if not hasattr(result, 'dtype'):
        return result

    scalable_targets = (np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32)
    float_to_int = (
        np.issubdtype(result.dtype, np.floating)
        and target_dtype in scalable_targets
    )
    if not float_to_int:
        # Direct conversion for compatible types
        return result.astype(target_dtype)

    # min()/max() raise on empty input; there is nothing to scale anyway.
    if getattr(result, 'size', None) == 0:
        return result.astype(target_dtype)

    result_min = result.min()
    result_max = result.max()
    if not result_max > result_min:
        # Constant image (or NaN extrema): no range to normalise over.
        return result.astype(target_dtype)

    # Normalise in the input precision, then widen to float64 before scaling:
    # float32 cannot represent 2**32 - 1, so scaling there pushes the largest
    # value one past the uint32/int32 maximum and the cast overflows.
    normalized = ((result - result_min) / (result_max - result_min)).astype(np.float64)

    # iinfo yields the exact bounds for every supported target (including
    # int8, which previously fell through unscaled).
    limits = np.iinfo(target_dtype)
    scaled = normalized * float(limits.max - limits.min) + float(limits.min)
    return scaled.astype(target_dtype)
def _scale_and_convert_pyclesperanto(result, target_dtype):
    """Convert a pyclesperanto array to ``target_dtype``, rescaling floats for integer targets.

    Floating-point data headed for one of the supported integer dtypes
    (uint8/16/32, int8/16/32) is min-max normalised to [0, 1] on the device and
    stretched over the full value range of the target dtype. The dtype cast
    itself is done on the host via pull/astype/push.

    Args:
        result: Device array with a ``dtype``; any object without a ``dtype``
            attribute is returned untouched.
        target_dtype: numpy scalar type or ``np.dtype`` to convert to.

    Returns:
        A newly pushed array of ``target_dtype``, or ``result`` unchanged when
        pyclesperanto cannot be imported or ``result`` has no ``dtype``.
    """
    try:
        import pyclesperanto as cle
    except ImportError:
        return result

    if not hasattr(result, 'dtype'):
        return result

    def _cast_on_host(image):
        # Convert to target dtype using push/pull method
        return cle.push(cle.pull(image).astype(target_dtype))

    scalable_targets = (np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32)
    float_to_int = (
        np.issubdtype(result.dtype, np.floating)
        and target_dtype in scalable_targets
    )
    if not float_to_int:
        # Direct conversion for compatible types
        return _cast_on_host(result)

    result_min = float(cle.minimum_of_all_pixels(result))
    result_max = float(cle.maximum_of_all_pixels(result))
    if not result_max > result_min:
        # Constant image, just convert dtype
        return _cast_on_host(result)

    limits = np.iinfo(target_dtype)

    # Shift with add_image_and_scalar (image + scalar). The previously used
    # subtract_image_from_scalar is documented as scalar - image, which
    # inverted the normalisation and the signed offsets.
    shifted = cle.add_image_and_scalar(result, scalar=-result_min)
    normalized = cle.multiply_image_and_scalar(
        shifted, scalar=1.0 / (result_max - result_min)
    )
    scaled = cle.multiply_image_and_scalar(
        normalized, scalar=float(limits.max - limits.min)
    )
    if limits.min != 0:
        scaled = cle.add_image_and_scalar(scaled, scalar=float(limits.min))

    # Clamp in float64 on the host so single-precision rounding on the device
    # cannot land one past the integer maximum and overflow during the cast.
    host_values = np.asarray(cle.pull(scaled), dtype=np.float64)
    host_values = np.clip(host_values, float(limits.min), float(limits.max))
    return cle.push(host_values.astype(target_dtype))
def _scale_and_convert_cupy(result, target_dtype):
    """Cast a CuPy array to ``target_dtype``, scaling clipped floats for unsigned targets.

    When ``result`` is floating point and ``target_dtype`` is not, the data is
    clipped to [0, 1]; for uint8/uint16/uint32 targets the clipped data is
    multiplied by the dtype maximum before casting. Every other combination is
    a plain ``astype`` of the unclipped input.

    Returns ``result`` unchanged when CuPy cannot be imported or ``result``
    has no ``dtype`` attribute.
    """
    try:
        import cupy as cp
    except ImportError:
        return result

    if not hasattr(result, 'dtype'):
        return result

    float_to_non_float = (
        cp.issubdtype(result.dtype, cp.floating)
        and not cp.issubdtype(target_dtype, cp.floating)
    )
    if float_to_non_float:
        # Clip to [0, 1] range and scale to integer range
        unit_range = cp.clip(result, 0, 1)
        unsigned_full_scale = (
            (cp.uint8, 255),
            (cp.uint16, 65535),
            (cp.uint32, 4294967295),
        )
        for unsigned_type, full_scale in unsigned_full_scale:
            if target_dtype == unsigned_type:
                return (unit_range * full_scale).astype(target_dtype)
        # Other non-float targets fall through: converted without scaling.

    return result.astype(target_dtype)
179# GPU frameworks imported lazily to prevent thread explosion
180# These will be imported only when actually needed by functions
181_gpu_frameworks_cache = {}
def _get_cupy():
    """Return the CuPy module (or whatever ``optional_import`` yields), importing it once.

    The outcome of the first import attempt is stored in
    ``_gpu_frameworks_cache`` and reused by every later call.
    """
    try:
        return _gpu_frameworks_cache['cupy']
    except KeyError:
        pass

    module = optional_import("cupy")
    _gpu_frameworks_cache['cupy'] = module
    if module is not None:
        logger.debug(f"🔧 Lazy imported CuPy in thread {threading.current_thread().name}")
    return module
def _get_torch():
    """Return the PyTorch module (or whatever ``optional_import`` yields), importing it once.

    The outcome of the first import attempt is stored in
    ``_gpu_frameworks_cache`` and reused by every later call.
    """
    try:
        return _gpu_frameworks_cache['torch']
    except KeyError:
        pass

    module = optional_import("torch")
    _gpu_frameworks_cache['torch'] = module
    if module is not None:
        logger.debug(f"🔧 Lazy imported PyTorch in thread {threading.current_thread().name}")
    return module
def _get_tensorflow():
    """Return the TensorFlow module (or whatever ``optional_import`` yields), importing it once.

    The outcome of the first import attempt is stored in
    ``_gpu_frameworks_cache`` and reused by every later call.
    """
    try:
        return _gpu_frameworks_cache['tensorflow']
    except KeyError:
        pass

    module = optional_import("tensorflow")
    _gpu_frameworks_cache['tensorflow'] = module
    if module is not None:
        logger.debug(f"🔧 Lazy imported TensorFlow in thread {threading.current_thread().name}")
    return module
def _get_jax():
    """Return the JAX module (or whatever ``optional_import`` yields), importing it once.

    The outcome of the first import attempt is stored in
    ``_gpu_frameworks_cache`` and reused by every later call.
    """
    try:
        return _gpu_frameworks_cache['jax']
    except KeyError:
        pass

    module = optional_import("jax")
    _gpu_frameworks_cache['jax'] = module
    if module is not None:
        logger.debug(f"🔧 Lazy imported JAX in thread {threading.current_thread().name}")
    return module
215# Thread-local storage for GPU streams and contexts
216_thread_gpu_contexts = threading.local()
class ThreadGPUContext:
    """Holder for at most one CuPy stream and one PyTorch stream.

    Streams are created lazily on first request and then reused. The name of
    the thread that constructed the instance is recorded for log messages.
    """

    def __init__(self):
        self._thread_name = threading.current_thread().name
        self._cupy_stream = None
        self._torch_stream = None

    def get_cupy_stream(self):
        """Return this context's CuPy stream, creating it on first use.

        Returns ``None`` when CuPy is unavailable or exposes no ``cuda``
        attribute.
        """
        if self._cupy_stream is not None:
            return self._cupy_stream

        cupy_module = _get_cupy()
        if cupy_module is None or not hasattr(cupy_module, 'cuda'):
            return self._cupy_stream

        self._cupy_stream = cupy_module.cuda.Stream()
        logger.debug(f"🔧 Created CuPy stream for thread {self._thread_name}")
        return self._cupy_stream

    def get_torch_stream(self):
        """Return this context's PyTorch stream, creating it on first use.

        Returns ``None`` when PyTorch is unavailable, exposes no ``cuda``
        attribute, or reports that CUDA is not available.
        """
        if self._torch_stream is not None:
            return self._torch_stream

        torch_module = _get_torch()
        if (
            torch_module is None
            or not hasattr(torch_module, 'cuda')
            or not torch_module.cuda.is_available()
        ):
            return self._torch_stream

        self._torch_stream = torch_module.cuda.Stream()
        logger.debug(f"🔧 Created PyTorch stream for thread {self._thread_name}")
        return self._torch_stream

    def cleanup(self):
        """Drop the references to any streams held by this context."""
        holds_cupy_stream = self._cupy_stream is not None
        if holds_cupy_stream:
            logger.debug(f"🔧 Cleaning up CuPy stream for thread {self._thread_name}")
            self._cupy_stream = None

        holds_torch_stream = self._torch_stream is not None
        if holds_torch_stream:
            logger.debug(f"🔧 Cleaning up PyTorch stream for thread {self._thread_name}")
            self._torch_stream = None
def get_thread_gpu_context() -> ThreadGPUContext:
    """Return the calling thread's ``ThreadGPUContext``, creating it on first use.

    On creation a cleanup hook is appended to the current thread object's
    ``_cleanup_funcs`` list (the list is created if absent).

    NOTE(review): nothing in this function runs those hooks, and the standard
    ``threading.Thread`` does not consume ``_cleanup_funcs`` on its own —
    confirm which component is expected to call them at thread teardown.

    Returns:
        The context stored in this thread's slot of ``_thread_gpu_contexts``.
    """
    if hasattr(_thread_gpu_contexts, 'gpu_context'):
        return _thread_gpu_contexts.gpu_context

    context = ThreadGPUContext()
    _thread_gpu_contexts.gpu_context = context

    def cleanup_on_thread_exit():
        # Bind the context object itself rather than re-reading thread-local
        # storage: the hook may be run from a different thread, where the
        # thread-local lookup would find another context (or none at all).
        context.cleanup()

    # Register cleanup for when thread exits
    current_thread = threading.current_thread()
    if hasattr(current_thread, '_cleanup_funcs'):
        current_thread._cleanup_funcs.append(cleanup_on_thread_exit)
    else:
        current_thread._cleanup_funcs = [cleanup_on_thread_exit]

    return context
def memory_types(*, input_type: str, output_type: str) -> Callable[[F], F]:
    """
    Build a decorator that stamps explicit memory types onto a function.

    Both memory types are validated immediately, when the decorator is built
    (Clause 106-A, Declared Memory Types), so an invalid declaration fails at
    decoration time rather than when the decorated function runs.

    Args:
        input_type: The memory type for the function's input (e.g., "numpy", "cupy")
        output_type: The memory type for the function's output (e.g., "numpy", "cupy")

    Returns:
        A decorator that sets ``input_memory_type`` and ``output_memory_type``
        on the function and returns that same function object.

    Raises:
        ValueError: If input_type or output_type is empty or is not a member
            of VALID_MEMORY_TYPES
    """
    # 🔒 Clause 88 — No Inferred Capabilities
    # Emptiness is checked for both values before membership is checked for either.
    input_missing = not input_type
    if input_missing:
        raise ValueError(
            "Clause 106-A Violation: input_type must be explicitly declared. "
            "No default or inferred memory types are allowed."
        )

    output_missing = not output_type
    if output_missing:
        raise ValueError(
            "Clause 106-A Violation: output_type must be explicitly declared. "
            "No default or inferred memory types are allowed."
        )

    input_supported = input_type in VALID_MEMORY_TYPES
    if not input_supported:
        raise ValueError(
            f"Clause 106-A Violation: input_type '{input_type}' is not supported. "
            f"Supported types are: {', '.join(sorted(VALID_MEMORY_TYPES))}"
        )

    output_supported = output_type in VALID_MEMORY_TYPES
    if not output_supported:
        raise ValueError(
            f"Clause 106-A Violation: output_type '{output_type}' is not supported. "
            f"Supported types are: {', '.join(sorted(VALID_MEMORY_TYPES))}"
        )

    def decorator(func: F) -> F:
        """
        Attach the declared memory types to ``func`` and return it unwrapped.

        Args:
            func: The function to decorate

        Returns:
            The same function object, now carrying the memory type attributes

        Raises:
            ValueError: If the function already declares a different memory type
        """
        # 🔒 Clause 66 — Immutability
        # A missing attribute defaults to the requested value, so only an
        # existing, conflicting declaration trips the check.
        if getattr(func, 'input_memory_type', input_type) != input_type:
            raise ValueError(
                f"Clause 66 Violation: Function '{func.__name__}' already has input_memory_type "
                f"'{func.input_memory_type}', cannot change to '{input_type}'."
            )

        if getattr(func, 'output_memory_type', output_type) != output_type:
            raise ValueError(
                f"Clause 66 Violation: Function '{func.__name__}' already has output_memory_type "
                f"'{func.output_memory_type}', cannot change to '{output_type}'."
            )

        # 🔒 Clause 106-A.2 — Canonical Memory Type Attributes
        func.input_memory_type = input_type
        func.output_memory_type = output_type

        # No wrapper is introduced; the caller gets the original object back.
        return func

    return decorator
def numpy(
    func: Optional[F] = None,
    *,
    input_type: str = "numpy",
    output_type: str = "numpy"
) -> Any:
    """
    Declare a function as numpy-based and wrap it for dtype preservation.

    Usable bare (``@numpy``) or with keyword arguments
    (``@numpy(input_type=..., output_type=...)``).

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "numpy")
        output_type: The memory type for the function's output (default: "numpy")

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        awaiting the function

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator_with_dtype_preservation(func: F) -> F:
        # Stamp the memory type attributes first; memory_types validates the
        # declaration and hands back the function it was given.
        declared = memory_types(input_type=input_type, output_type=output_type)(func)

        # Then layer the dtype-preserving wrapper on top.
        return _create_numpy_dtype_preserving_wrapper(declared, declared.__name__)

    # Keyword-only usage arrives with func unset: hand back the decorator.
    if func is None:
        return decorator_with_dtype_preservation

    return decorator_with_dtype_preservation(func)
def cupy(func: Optional[F] = None, *, input_type: str = "cupy", output_type: str = "cupy", oom_recovery: bool = True) -> Any:
    """
    Declare a function as cupy-based and run it inside the thread's CuPy stream.

    The decorated function gets its memory type attributes, a dtype-preserving
    wrapper, and an outer wrapper that executes each call within the CuPy
    stream held by the calling thread's GPU context. When CuPy is unavailable
    the function is simply called directly.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "cupy")
        output_type: The memory type for the function's output (default: "cupy")
        oom_recovery: Enable automatic OOM recovery (default: True)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        awaiting the function

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(func: F) -> F:
        # Memory type attributes first, dtype preservation second.
        declared = memory_types(input_type=input_type, output_type=output_type)(func)
        preserved = _create_cupy_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(preserved)
        def wrapper(*args, **kwargs):
            cp = _get_cupy()  # lazy import
            if cp is None or not hasattr(cp, 'cuda'):
                # CuPy not available, execute without stream
                return preserved(*args, **kwargs)

            cupy_stream = get_thread_gpu_context().get_cupy_stream()

            def execute_with_stream():
                if cupy_stream is None:
                    # No stream could be created, execute without one
                    return preserved(*args, **kwargs)
                with cupy_stream:
                    return preserved(*args, **kwargs)

            if not oom_recovery:
                return execute_with_stream()
            return _execute_with_oom_recovery(execute_with_stream, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = preserved.input_memory_type
        wrapper.output_memory_type = preserved.output_memory_type

        return wrapper

    # Keyword-only usage arrives with func unset: hand back the decorator.
    if func is None:
        return decorator

    return decorator(func)
def torch(
    func: Optional[F] = None,
    *,
    input_type: str = "torch",
    output_type: str = "torch",
    oom_recovery: bool = True
) -> Any:
    """
    Declare a function as torch-based and run it inside the thread's CUDA stream.

    The decorated function gets its memory type attributes, a dtype-preserving
    wrapper, and an outer wrapper that executes each call within the PyTorch
    stream held by the calling thread's GPU context. When PyTorch or CUDA is
    unavailable the function is simply called directly.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "torch")
        output_type: The memory type for the function's output (default: "torch")
        oom_recovery: Enable automatic OOM recovery (default: True)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        awaiting the function

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(func: F) -> F:
        # Memory type attributes first, dtype preservation second.
        declared = memory_types(input_type=input_type, output_type=output_type)(func)
        preserved = _create_torch_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(preserved)
        def wrapper(*args, **kwargs):
            torch_module = _get_torch()  # lazy import
            cuda_usable = (
                torch_module is not None
                and hasattr(torch_module, 'cuda')
                and torch_module.cuda.is_available()
            )
            if not cuda_usable:
                # PyTorch not available or CUDA not available, execute without stream
                return preserved(*args, **kwargs)

            torch_stream = get_thread_gpu_context().get_torch_stream()

            def execute_with_stream():
                if torch_stream is None:
                    # No stream could be created, execute without one
                    return preserved(*args, **kwargs)
                with torch_module.cuda.stream(torch_stream):
                    return preserved(*args, **kwargs)

            if not oom_recovery:
                return execute_with_stream()
            return _execute_with_oom_recovery(execute_with_stream, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = preserved.input_memory_type
        wrapper.output_memory_type = preserved.output_memory_type

        return wrapper

    # Keyword-only usage arrives with func unset: hand back the decorator.
    if func is None:
        return decorator

    return decorator(func)
def tensorflow(
    func: Optional[F] = None,
    *,
    input_type: str = "tensorflow",
    output_type: str = "tensorflow",
    oom_recovery: bool = True
) -> Any:
    """
    Declare a function as tensorflow-based and run it under a GPU device scope.

    The decorated function gets its memory type attributes, a dtype-preserving
    wrapper, and an outer wrapper that executes each call inside
    ``tf.device('/GPU:0')`` whenever TensorFlow reports a physical GPU.
    Otherwise the function is simply called directly.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "tensorflow")
        output_type: The memory type for the function's output (default: "tensorflow")
        oom_recovery: Enable automatic OOM recovery (default: True)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        awaiting the function

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(func: F) -> F:
        # Memory type attributes first, dtype preservation second.
        declared = memory_types(input_type=input_type, output_type=output_type)(func)
        preserved = _create_tensorflow_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(preserved)
        def wrapper(*args, **kwargs):
            tf = _get_tensorflow()  # lazy import
            if tf is None or not tf.config.list_physical_devices('GPU'):
                # TensorFlow not available or GPU not available, execute without device context
                return preserved(*args, **kwargs)

            def execute_with_device():
                # TensorFlow manages internal CUDA streams automatically,
                # so a device scope is used instead of an explicit stream.
                with tf.device('/GPU:0'):
                    return preserved(*args, **kwargs)

            if not oom_recovery:
                return execute_with_device()
            return _execute_with_oom_recovery(execute_with_device, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = preserved.input_memory_type
        wrapper.output_memory_type = preserved.output_memory_type

        return wrapper

    # Keyword-only usage arrives with func unset: hand back the decorator.
    if func is None:
        return decorator

    return decorator(func)
def jax(
    func: Optional[F] = None,
    *,
    input_type: str = "jax",
    output_type: str = "jax",
    oom_recovery: bool = True
) -> Any:
    """
    Declare a function as JAX-based and run it with a GPU as default device.

    The decorated function gets its memory type attributes, a dtype-preserving
    wrapper, and an outer wrapper that executes each call with the first GPU
    device reported by JAX set as the default device. When JAX is unavailable
    or reports no GPU device, the function is simply called directly.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "jax")
        output_type: The memory type for the function's output (default: "jax")
        oom_recovery: Enable automatic OOM recovery (default: True)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        awaiting the function

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(func: F) -> F:
        # Memory type attributes first, dtype preservation second.
        declared = memory_types(input_type=input_type, output_type=output_type)(func)
        preserved = _create_jax_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(preserved)
        def wrapper(*args, **kwargs):
            jax_module = _get_jax()  # lazy import
            if jax_module is None:
                # JAX not available, execute without device placement
                return preserved(*args, **kwargs)

            gpu_devices = [
                device for device in jax_module.devices() if device.platform == 'gpu'
            ]
            if not gpu_devices:
                # No GPU devices available, execute without device placement
                return preserved(*args, **kwargs)

            def execute_with_device():
                # JAX/XLA manages internal CUDA streams automatically,
                # so device placement is used instead of an explicit stream.
                with jax_module.default_device(gpu_devices[0]):
                    return preserved(*args, **kwargs)

            if not oom_recovery:
                return execute_with_device()
            return _execute_with_oom_recovery(execute_with_device, input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = preserved.input_memory_type
        wrapper.output_memory_type = preserved.output_memory_type

        return wrapper

    # Keyword-only usage arrives with func unset: hand back the decorator.
    if func is None:
        return decorator

    return decorator(func)
def pyclesperanto(
    func: Optional[F] = None,
    *,
    input_type: str = "pyclesperanto",
    output_type: str = "pyclesperanto",
    oom_recovery: bool = True
) -> Any:
    """
    Declare a function as pyclesperanto-based, with optional OOM recovery.

    The decorated function gets its memory type attributes, a dtype-preserving
    wrapper, and an outer wrapper that routes each call through
    ``_execute_with_oom_recovery`` when ``oom_recovery`` is enabled.

    Args:
        func: The function to decorate (optional)
        input_type: The memory type for the function's input (default: "pyclesperanto")
        output_type: The memory type for the function's output (default: "pyclesperanto")
        oom_recovery: Enable automatic OOM recovery (default: True)

    Returns:
        The wrapped function when ``func`` is given, otherwise a decorator
        awaiting the function

    Raises:
        ValueError: If input_type or output_type is not a supported memory type
    """
    def decorator(func: F) -> F:
        # Memory type attributes first, dtype preservation second.
        declared = memory_types(input_type=input_type, output_type=output_type)(func)
        preserved = _create_pyclesperanto_dtype_preserving_wrapper(declared, declared.__name__)

        @functools.wraps(preserved)
        def wrapper(*args, **kwargs):
            if not oom_recovery:
                return preserved(*args, **kwargs)
            return _execute_with_oom_recovery(lambda: preserved(*args, **kwargs), input_type)

        # Preserve memory type attributes
        wrapper.input_memory_type = preserved.input_memory_type
        wrapper.output_memory_type = preserved.output_memory_type

        # Make wrapper pickleable by preserving original function identity
        wrapper.__module__ = getattr(preserved, '__module__', wrapper.__module__)
        wrapper.__qualname__ = getattr(preserved, '__qualname__', wrapper.__qualname__)

        # Store reference to the wrapped function for pickle support
        wrapper.__wrapped__ = preserved

        return wrapper

    # Keyword-only usage arrives with func unset: hand back the decorator.
    if func is None:
        return decorator

    return decorator(func)
739# ============================================================================
740# Dtype Preservation Wrapper Functions
741# ============================================================================
743def _create_numpy_dtype_preserving_wrapper(original_func, func_name):
744 """
745 Create a wrapper that preserves input data type and adds slice_by_slice parameter for NumPy functions.
747 Many scikit-image functions return float64 regardless of input type.
748 This wrapper ensures the output has the same dtype as the input and adds
749 a slice_by_slice parameter to avoid cross-slice contamination in 3D arrays.
750 """
751 import numpy as np
752 import inspect
753 from functools import wraps
755 @wraps(original_func)
756 def numpy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
757 # Set default dtype_conversion if not provided and DtypeConversion is available
758 if dtype_conversion is None and DtypeConversion is not None: 758 ↛ 761line 758 didn't jump to line 761 because the condition on line 758 was always true
759 dtype_conversion = DtypeConversion.PRESERVE_INPUT
761 try:
762 # Store original dtype
763 original_dtype = image.dtype
765 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities
766 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 766 ↛ 767line 766 didn't jump to line 767 because the condition on line 766 was never true
767 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type
769 # Detect memory type and use proper OpenHCS utilities
770 memory_type = _detect_memory_type(image)
771 gpu_id = 0 # Default GPU ID for slice processing
773 # Unstack 3D array into 2D slices
774 slices_2d = unstack_slices(image, memory_type, gpu_id)
776 # Process each slice and handle special outputs
777 main_outputs = []
778 special_outputs_list = []
780 for slice_2d in slices_2d:
781 slice_result = original_func(slice_2d, *args, **kwargs)
783 # Check if result is a tuple (indicating special outputs)
784 if isinstance(slice_result, tuple):
785 main_outputs.append(slice_result[0]) # First element is main output
786 special_outputs_list.append(slice_result[1:]) # Rest are special outputs
787 else:
788 main_outputs.append(slice_result) # Single output
790 # Stack main outputs back into 3D array
791 result = stack_slices(main_outputs, memory_type, gpu_id)
793 # If we have special outputs, combine them and return tuple
794 if special_outputs_list:
795 # Combine special outputs from all slices
796 combined_special_outputs = []
797 num_special_outputs = len(special_outputs_list[0])
799 for i in range(num_special_outputs):
800 # Collect the i-th special output from all slices
801 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list]
802 combined_special_outputs.append(special_output_values)
804 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...)
805 result = (result, *combined_special_outputs)
806 else:
807 # Call the original function normally
808 result = original_func(image, *args, **kwargs)
810 # Apply dtype conversion based on enum value
811 if hasattr(result, 'dtype') and dtype_conversion is not None:
812 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 812 ↛ 816line 812 didn't jump to line 816 because the condition on line 812 was always true
813 # Preserve input dtype
814 if result.dtype != original_dtype: 814 ↛ 815line 814 didn't jump to line 815 because the condition on line 814 was never true
815 result = _scale_and_convert_numpy(result, original_dtype)
816 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
817 # Return NumPy's native output dtype
818 pass # No conversion needed
819 else:
820 # Force specific dtype
821 target_dtype = dtype_conversion.numpy_dtype
822 if target_dtype is not None:
823 result = _scale_and_convert_numpy(result, target_dtype)
825 return result
826 except Exception as e:
827 logger.error(f"Error in NumPy dtype/slice preserving wrapper for {func_name}: {e}")
828 # Return original result on error
829 return original_func(image, *args, **kwargs)
831 # Update function signature to include new parameters
832 try:
833 original_sig = inspect.signature(original_func)
834 new_params = list(original_sig.parameters.values())
836 # Check if slice_by_slice parameter already exists
837 param_names = [p.name for p in new_params]
838 # Add dtype_conversion parameter first (before slice_by_slice)
839 if 'dtype_conversion' not in param_names: 839 ↛ 848line 839 didn't jump to line 848 because the condition on line 839 was always true
840 dtype_param = inspect.Parameter(
841 'dtype_conversion',
842 inspect.Parameter.KEYWORD_ONLY,
843 default=DtypeConversion.PRESERVE_INPUT,
844 annotation=DtypeConversion
845 )
846 new_params.append(dtype_param)
848 if 'slice_by_slice' not in param_names: 848 ↛ 859line 848 didn't jump to line 859 because the condition on line 848 was always true
849 # Add slice_by_slice parameter as keyword-only (after dtype_conversion)
850 slice_param = inspect.Parameter(
851 'slice_by_slice',
852 inspect.Parameter.KEYWORD_ONLY,
853 default=False,
854 annotation=bool
855 )
856 new_params.append(slice_param)
858 # Create new signature and override the @wraps signature
859 new_sig = original_sig.replace(parameters=new_params)
860 numpy_dtype_and_slice_preserving_wrapper.__signature__ = new_sig
862 # Set type annotations manually for get_type_hints() compatibility
863 numpy_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy()
864 numpy_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion
865 numpy_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool
867 except Exception:
868 # If signature modification fails, continue without it
869 pass
871 # Update docstring to mention slice_by_slice parameter
872 original_doc = numpy_dtype_and_slice_preserving_wrapper.__doc__ or ""
873 additional_doc = """
875 Additional OpenHCS Parameters
876 -----------------------------
877 slice_by_slice : bool, optional (default: False)
878 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
879 If False, use original 3D behavior. Recommended for edge detection functions
880 on stitched microscopy data to prevent artifacts at field boundaries.
882 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
883 Controls output data type conversion:
885 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
886 - NATIVE_OUTPUT: Use NumPy's native output dtype
887 - UINT8: Force 8-bit unsigned integer (0-255 range)
888 - UINT16: Force 16-bit unsigned integer (microscopy standard)
889 - INT16: Force 16-bit signed integer
890 - INT32: Force 32-bit signed integer
891 - FLOAT32: Force 32-bit float (GPU performance)
892 - FLOAT64: Force 64-bit float (maximum precision)
893 """
894 numpy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc
896 return numpy_dtype_and_slice_preserving_wrapper
def _create_cupy_dtype_preserving_wrapper(original_func, func_name):
    """
    Create a wrapper that preserves input data type and adds slice_by_slice parameter for CuPy functions.

    This uses the SAME pattern as scikit-image for consistency. The wrapper adds two
    keyword-only parameters to ``original_func``:

    * ``dtype_conversion`` -- a ``DtypeConversion`` member (``None`` is treated as
      ``DtypeConversion.PRESERVE_INPUT``) selecting how the output dtype is handled.
    * ``slice_by_slice`` -- when True and the input has ``ndim == 3``, the input is
      unstacked, ``original_func`` is applied to every 2D slice, and the main outputs
      are stacked back together.

    Args:
        original_func: Callable whose first positional argument is the image.
        func_name: Name used in the error log message only.

    Returns:
        The wrapped callable. Its ``__signature__``, ``__annotations__`` and
        ``__doc__`` are extended to advertise the two extra parameters.
    """
    import inspect
    from functools import wraps

    @wraps(original_func)
    def cupy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # None means "use the default policy"
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            cupy = optional_import("cupy")
            if cupy is None:
                # CuPy is not installed: behave exactly like the undecorated function
                return original_func(image, *args, **kwargs)

            # Store original dtype
            original_dtype = image.dtype

            # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities
            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                # Detect memory type and use proper OpenHCS utilities
                memory_type = _detect_memory_type(image)
                gpu_id = image.device.id if hasattr(image, 'device') else 0

                # Unstack 3D array into 2D slices
                slices_2d = unstack_slices(image, memory_type, gpu_id)

                # Process each slice, separating the main output from special outputs
                main_outputs = []
                special_outputs_list = []
                for slice_2d in slices_2d:
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                # Stack main outputs back into 3D array
                result = stack_slices(main_outputs, memory_type, gpu_id)

                if special_outputs_list:
                    # The i-th special output becomes a list with one entry per slice
                    num_special_outputs = len(special_outputs_list[0])
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(num_special_outputs)
                    ]
                    # (stacked_main_output, combined_special_output1, combined_special_output2, ...)
                    result = (result, *combined_special_outputs)
            else:
                # Call the original function normally
                result = original_func(image, *args, **kwargs)

            # Apply dtype conversion based on enum value.
            # NOTE: tuple results (main output + special outputs) have no ``dtype``
            # attribute and are therefore returned unconverted.
            if hasattr(result, 'dtype'):
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = _scale_and_convert_cupy(result, original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Return CuPy's native output dtype
                else:
                    # Force specific dtype
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        result = _scale_and_convert_cupy(result, target_dtype)

            return result
        except Exception as e:
            logger.error(f"Error in CuPy dtype/slice preserving wrapper for {func_name}: {e}")
            # Return original result on error
            return original_func(image, *args, **kwargs)

    # Update function signature to include new parameters
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        # dtype_conversion first, then slice_by_slice
        extra_params = []
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))

        # Keyword-only parameters must come before **kwargs; appending after it
        # makes Signature.replace() raise ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        # Create new signature and override the @wraps signature
        new_sig = original_sig.replace(parameters=new_params)
        cupy_dtype_and_slice_preserving_wrapper.__signature__ = new_sig

        # Set type annotations manually for get_type_hints() compatibility
        cupy_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy()
        cupy_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion
        cupy_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool

    except Exception:
        # Signature modification is best-effort; keep going but leave a trace
        logger.debug("Could not extend signature of %s", func_name, exc_info=True)

    # Update docstring to mention the additional parameters
    original_doc = cupy_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use CuPy's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    cupy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return cupy_dtype_and_slice_preserving_wrapper
def _create_torch_dtype_preserving_wrapper(original_func, func_name):
    """
    Create a wrapper that preserves input data type and adds slice_by_slice parameter for PyTorch functions.

    This follows the same pattern as the other dtype preservation wrappers. The wrapper
    adds two keyword-only parameters to ``original_func``:

    * ``dtype_conversion`` -- a ``DtypeConversion`` member (``None`` is treated as
      ``DtypeConversion.PRESERVE_INPUT``). Conversion uses ``Tensor.to``; a forced
      ``UINT16`` is mapped to ``torch.int32``.
    * ``slice_by_slice`` -- when True and the input has ``ndim == 3``, the input is
      unstacked, ``original_func`` is applied to every 2D slice, and the main outputs
      are stacked back together.

    Args:
        original_func: Callable whose first positional argument is the image.
        func_name: Name used in the error log message only.

    Returns:
        The wrapped callable. Its ``__signature__``, ``__annotations__`` and
        ``__doc__`` are extended to advertise the two extra parameters.
    """
    import inspect
    from functools import wraps

    @wraps(original_func)
    def torch_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # None means "use the default policy"
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            torch = optional_import("torch")
            if torch is None:
                # PyTorch is not installed: behave exactly like the undecorated function
                return original_func(image, *args, **kwargs)

            # Store original dtype (None when the input carries no dtype)
            original_dtype = image.dtype if hasattr(image, 'dtype') else None

            # Handle slice_by_slice processing for 3D arrays
            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                # Detect memory type and use proper OpenHCS utilities
                memory_type = _detect_memory_type(image)
                gpu_id = image.device.index if hasattr(image, 'device') and image.device.type == 'cuda' else 0

                # Unstack 3D array into 2D slices
                slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id)

                # Process each slice, separating the main output from special outputs
                main_outputs = []
                special_outputs_list = []
                for slice_2d in slices_2d:
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                # Stack main outputs back into 3D array
                result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id)

                if special_outputs_list:
                    # The i-th special output becomes a list with one entry per slice
                    num_special_outputs = len(special_outputs_list[0])
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(num_special_outputs)
                    ]
                    # (stacked_main_output, combined_special_output1, combined_special_output2, ...)
                    result = (result, *combined_special_outputs)
            else:
                # Process normally
                result = original_func(image, *args, **kwargs)

            # Apply dtype conversion if result is a tensor and the input dtype is known.
            # NOTE: tuple results (main output + special outputs) are returned unconverted.
            if hasattr(result, 'dtype') and hasattr(result, 'shape') and original_dtype is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = result.to(original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Return PyTorch's native output dtype
                else:
                    # Force specific dtype
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        # Map numpy dtypes to torch dtypes
                        numpy_to_torch = {
                            np.uint8: torch.uint8,
                            np.uint16: torch.int32,  # PyTorch doesn't have uint16, use int32
                            np.int16: torch.int16,
                            np.int32: torch.int32,
                            np.float32: torch.float32,
                            np.float64: torch.float64,
                        }
                        torch_dtype = numpy_to_torch.get(target_dtype)
                        if torch_dtype is not None:
                            result = result.to(torch_dtype)

            return result

        except Exception as e:
            logger.error(f"Error in PyTorch dtype/slice preserving wrapper for {func_name}: {e}")
            # Return original result on error
            return original_func(image, *args, **kwargs)

    # Update function signature to include new parameters
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        # dtype_conversion first, then slice_by_slice
        extra_params = []
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))

        # Keyword-only parameters must come before **kwargs; appending after it
        # makes Signature.replace() raise ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        new_sig = original_sig.replace(parameters=new_params)
        torch_dtype_and_slice_preserving_wrapper.__signature__ = new_sig

        # Set type annotations manually for get_type_hints() compatibility
        torch_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy()
        torch_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion
        torch_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool

    except Exception:
        # Signature modification is best-effort; keep going but leave a trace
        logger.debug("Could not extend signature of %s", func_name, exc_info=True)

    # Update docstring to mention new parameters
    original_doc = torch_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use PyTorch's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (mapped to int32 in PyTorch)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    torch_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return torch_dtype_and_slice_preserving_wrapper
def _create_tensorflow_dtype_preserving_wrapper(original_func, func_name):
    """
    Create a wrapper that preserves input data type and adds slice_by_slice parameter for TensorFlow functions.

    This follows the same pattern as the other dtype preservation wrappers. The wrapper
    adds two keyword-only parameters to ``original_func``:

    * ``slice_by_slice`` -- when True and the input has ``ndim == 3``, the input is
      unstacked, ``original_func`` is applied to every 2D slice, and the main outputs
      are stacked back together.
    * ``dtype_conversion`` -- a ``DtypeConversion`` member (``None`` is treated as
      ``DtypeConversion.PRESERVE_INPUT``). Conversion uses ``tf.cast``.

    Args:
        original_func: Callable whose first positional argument is the image.
        func_name: Name used in the error log message only.

    Returns:
        The wrapped callable. Its ``__signature__``, ``__annotations__`` and
        ``__doc__`` are extended to advertise the two extra parameters.
    """
    import inspect
    from functools import wraps

    @wraps(original_func)
    def tensorflow_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # None means "use the default policy"
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            tf = optional_import("tensorflow")
            if tf is None:
                # TensorFlow is not installed: behave exactly like the undecorated function
                return original_func(image, *args, **kwargs)

            # Store original dtype (None when the input carries no dtype)
            original_dtype = image.dtype if hasattr(image, 'dtype') else None

            # Handle slice_by_slice processing for 3D arrays
            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                # Detect memory type and use proper OpenHCS utilities
                memory_type = _detect_memory_type(image)
                gpu_id = 0  # TensorFlow manages GPU placement internally

                # Unstack 3D array into 2D slices
                slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id)

                # Process each slice, separating the main output from special outputs
                main_outputs = []
                special_outputs_list = []
                for slice_2d in slices_2d:
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                # Stack main outputs back into 3D array
                result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id)

                if special_outputs_list:
                    # The i-th special output becomes a list with one entry per slice
                    num_special_outputs = len(special_outputs_list[0])
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(num_special_outputs)
                    ]
                    # (stacked_main_output, combined_special_output1, combined_special_output2, ...)
                    result = (result, *combined_special_outputs)
            else:
                # Process normally
                result = original_func(image, *args, **kwargs)

            # Apply dtype conversion if result is a tensor and the input dtype is known.
            # NOTE: tuple results (main output + special outputs) are returned unconverted.
            if hasattr(result, 'dtype') and hasattr(result, 'shape') and original_dtype is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = tf.cast(result, original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Return TensorFlow's native output dtype
                else:
                    # Force specific dtype
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        # Convert numpy dtype to tensorflow dtype
                        numpy_to_tf = {
                            np.uint8: tf.uint8,
                            np.uint16: tf.uint16,
                            np.int16: tf.int16,
                            np.int32: tf.int32,
                            np.float32: tf.float32,
                            np.float64: tf.float64,
                        }
                        tf_dtype = numpy_to_tf.get(target_dtype)
                        if tf_dtype is not None:
                            result = tf.cast(result, tf_dtype)

            return result

        except Exception as e:
            logger.error(f"Error in TensorFlow dtype/slice preserving wrapper for {func_name}: {e}")
            # Return original result on error
            return original_func(image, *args, **kwargs)

    # Update function signature to include new parameters
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        # slice_by_slice first, then dtype_conversion
        extra_params = []
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))

        # Keyword-only parameters must come before **kwargs; appending after it
        # makes Signature.replace() raise ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        new_sig = original_sig.replace(parameters=new_params)
        tensorflow_dtype_and_slice_preserving_wrapper.__signature__ = new_sig

        # Set type annotations manually for get_type_hints() compatibility
        tensorflow_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy()
        tensorflow_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool
        tensorflow_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion

    except Exception:
        # Signature modification is best-effort; keep going but leave a trace
        logger.debug("Could not extend signature of %s", func_name, exc_info=True)

    # Update docstring to mention new parameters
    original_doc = tensorflow_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use TensorFlow's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    tensorflow_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return tensorflow_dtype_and_slice_preserving_wrapper
def _create_jax_dtype_preserving_wrapper(original_func, func_name):
    """
    Create a wrapper that preserves input data type and adds slice_by_slice parameter for JAX functions.

    This follows the same pattern as the other dtype preservation wrappers. The wrapper
    adds two keyword-only parameters to ``original_func``:

    * ``slice_by_slice`` -- when True and the input has ``ndim == 3``, the input is
      unstacked, ``original_func`` is applied to every 2D slice, and the main outputs
      are stacked back together.
    * ``dtype_conversion`` -- a ``DtypeConversion`` member (``None`` is treated as
      ``DtypeConversion.PRESERVE_INPUT``). Conversion uses ``astype``.

    Args:
        original_func: Callable whose first positional argument is the image.
        func_name: Name used in the error log message only.

    Returns:
        The wrapped callable. Its ``__signature__``, ``__annotations__`` and
        ``__doc__`` are extended to advertise the two extra parameters.
    """
    import inspect
    from functools import wraps

    @wraps(original_func)
    def jax_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # None means "use the default policy"
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            jax = optional_import("jax")
            jnp = optional_import("jax.numpy") if jax is not None else None
            if jax is None or jnp is None:
                # JAX is not installed: behave exactly like the undecorated function
                return original_func(image, *args, **kwargs)

            # Store original dtype (None when the input carries no dtype)
            original_dtype = image.dtype if hasattr(image, 'dtype') else None

            # Handle slice_by_slice processing for 3D arrays
            if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                # Detect memory type and use proper OpenHCS utilities
                memory_type = _detect_memory_type(image)
                gpu_id = 0  # JAX manages GPU placement internally

                # Unstack 3D array into 2D slices
                slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id)

                # Process each slice, separating the main output from special outputs
                main_outputs = []
                special_outputs_list = []
                for slice_2d in slices_2d:
                    slice_result = original_func(slice_2d, *args, **kwargs)
                    if isinstance(slice_result, tuple):
                        # First element is the main output, the rest are special outputs
                        main_outputs.append(slice_result[0])
                        special_outputs_list.append(slice_result[1:])
                    else:
                        main_outputs.append(slice_result)

                # Stack main outputs back into 3D array
                result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id)

                if special_outputs_list:
                    # The i-th special output becomes a list with one entry per slice
                    num_special_outputs = len(special_outputs_list[0])
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(num_special_outputs)
                    ]
                    # (stacked_main_output, combined_special_output1, combined_special_output2, ...)
                    result = (result, *combined_special_outputs)
            else:
                # Process normally
                result = original_func(image, *args, **kwargs)

            # Apply dtype conversion if result is an array and the input dtype is known.
            # NOTE: tuple results (main output + special outputs) are returned unconverted.
            if hasattr(result, 'dtype') and hasattr(result, 'shape') and original_dtype is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    if result.dtype != original_dtype:
                        result = result.astype(original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    pass  # Return JAX's native output dtype
                else:
                    # Force specific dtype
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        # JAX uses numpy-compatible dtypes
                        result = result.astype(target_dtype)

            return result

        except Exception as e:
            logger.error(f"Error in JAX dtype/slice preserving wrapper for {func_name}: {e}")
            # Return original result on error
            return original_func(image, *args, **kwargs)

    # Update function signature to include new parameters
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        # slice_by_slice first, then dtype_conversion
        extra_params = []
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool
            ))
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion
            ))

        # Keyword-only parameters must come before **kwargs; appending after it
        # makes Signature.replace() raise ValueError.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        new_sig = original_sig.replace(parameters=new_params)
        jax_dtype_and_slice_preserving_wrapper.__signature__ = new_sig

        # Set type annotations manually for get_type_hints() compatibility
        jax_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy()
        jax_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool
        jax_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion

    except Exception:
        # Signature modification is best-effort; keep going but leave a trace
        logger.debug("Could not extend signature of %s", func_name, exc_info=True)

    # Update docstring to mention new parameters
    original_doc = jax_dtype_and_slice_preserving_wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use JAX's native output dtype
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    jax_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc

    return jax_dtype_and_slice_preserving_wrapper
def _create_pyclesperanto_dtype_preserving_wrapper(original_func, func_name):
    """
    Create a wrapper that ensures array-in/array-out compliance and dtype preservation for pyclesperanto functions.

    All OpenHCS functions must:
    1. Take 3D pyclesperanto array as first argument
    2. Return 3D pyclesperanto array as first output
    3. Additional outputs (values, coordinates) as 2nd, 3rd, etc. returns
    4. Preserve input dtype when appropriate

    The returned wrapper accepts two extra keyword-only arguments,
    ``slice_by_slice`` and ``dtype_conversion``, which are consumed by the
    wrapper and never forwarded to ``original_func``.

    Args:
        original_func: The function to wrap. It is called with the image as
            first positional argument followed by the caller's ``*args`` and
            ``**kwargs``.
        func_name: Name used in the log message emitted when the wrapper
            falls back to calling ``original_func`` directly.

    Returns:
        The wrapping callable. Its ``__signature__``, ``__annotations__`` and
        ``__doc__`` are extended (best effort) to advertise the two extra
        parameters.
    """
    import inspect
    from functools import wraps

    def _convert_output_dtype(output, input_dtype, conversion):
        """Apply ``conversion`` to one output; objects without dtype/shape pass through."""
        if not (hasattr(output, 'dtype') and hasattr(output, 'shape')):
            # Non-array result, return as-is
            return output

        if conversion == DtypeConversion.NATIVE_OUTPUT:
            # Return pyclesperanto's native output dtype
            return output

        if conversion == DtypeConversion.PRESERVE_INPUT:
            if output.dtype != input_dtype:
                return _scale_and_convert_pyclesperanto(output, input_dtype)
            return output

        # Force specific dtype
        target_dtype = conversion.numpy_dtype
        if target_dtype is not None and output.dtype != target_dtype:
            return _scale_and_convert_pyclesperanto(output, target_dtype)
        return output

    @wraps(original_func)
    def pyclesperanto_dtype_and_slice_preserving_wrapper(image_3d, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # Set default dtype_conversion if not provided
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            cle = optional_import("pyclesperanto")
            if cle is None:
                # Without pyclesperanto there is nothing to convert with.
                return original_func(image_3d, *args, **kwargs)

            # Store original dtype for preservation
            original_dtype = image_3d.dtype

            # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities
            if slice_by_slice and hasattr(image_3d, 'ndim') and image_3d.ndim == 3:
                from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type

                # Detect memory type and use proper OpenHCS utilities
                memory_type = _detect_memory_type(image_3d)
                gpu_id = 0  # pyclesperanto manages GPU internally

                main_outputs = []
                special_outputs_list = []
                for slice_2d in unstack_slices(image_3d, memory_type, gpu_id):
                    result_slice = original_func(slice_2d, *args, **kwargs)

                    # A tuple means (main_output, special_output_1, ...)
                    if isinstance(result_slice, tuple):
                        main_outputs.append(result_slice[0])
                        special_outputs_list.append(result_slice[1:])
                    else:
                        main_outputs.append(result_slice)

                # Stack main outputs back into 3D array
                result = stack_slices(main_outputs, memory_type, gpu_id)

                if special_outputs_list:
                    # Regroup per-slice special outputs into one list per
                    # special output, ordered by slice.
                    num_special_outputs = len(special_outputs_list[0])
                    combined_special_outputs = [
                        [slice_outputs[i] for slice_outputs in special_outputs_list]
                        for i in range(num_special_outputs)
                    ]
                    # (stacked_main_output, combined_special_output1, ...)
                    result = (result, *combined_special_outputs)
            else:
                # Normal 3D processing
                result = original_func(image_3d, *args, **kwargs)

            # Check if result is 2D and needs expansion to 3D
            if hasattr(result, 'ndim') and result.ndim == 2:
                try:
                    # Concatenate with itself to create 3D, then take first slice
                    temp_3d = cle.concatenate_along_z(result, result)  # Creates (2, Y, X)
                    result = temp_3d[0:1, :, :]  # Take first slice to get (1, Y, X)
                except Exception:
                    # If expansion fails, keep the 2D result
                    # (maintains backward compatibility).
                    pass

            # Tuple results carry the image as first element; convert that
            # element so special outputs do not disable dtype handling.
            if isinstance(result, tuple) and result:
                main_output = _convert_output_dtype(result[0], original_dtype, dtype_conversion)
                if main_output is result[0]:
                    return result
                return (main_output, *result[1:])

            return _convert_output_dtype(result, original_dtype, dtype_conversion)

        except Exception as e:
            logger.error(f"Error in pyclesperanto dtype/slice preserving wrapper for {func_name}: {e}")
            # If anything goes wrong, fall back to original function
            return original_func(image_3d, *args, **kwargs)

    wrapper = pyclesperanto_dtype_and_slice_preserving_wrapper

    # Update function signature to include new parameters (best effort)
    try:
        original_sig = inspect.signature(original_func)
        new_params = list(original_sig.parameters.values())
        param_names = {p.name for p in new_params}

        extra_params = []
        if 'slice_by_slice' not in param_names:
            extra_params.append(inspect.Parameter(
                'slice_by_slice',
                inspect.Parameter.KEYWORD_ONLY,
                default=False,
                annotation=bool,
            ))
        if 'dtype_conversion' not in param_names:
            extra_params.append(inspect.Parameter(
                'dtype_conversion',
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=DtypeConversion,
            ))

        # Keyword-only parameters must precede **kwargs, otherwise
        # Signature.replace() rejects the parameter list.
        if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
            new_params[-1:-1] = extra_params
        else:
            new_params.extend(extra_params)

        wrapper.__signature__ = original_sig.replace(parameters=new_params)
    except Exception:
        # If signature modification fails, continue without it
        logger.debug("Could not extend signature of %s", func_name, exc_info=True)

    # Set type annotations manually for get_type_hints() compatibility.
    # Work on a copy so the original function's annotations stay untouched.
    try:
        annotations = dict(getattr(original_func, '__annotations__', None) or {})
        annotations['slice_by_slice'] = bool
        annotations['dtype_conversion'] = DtypeConversion
        wrapper.__annotations__ = annotations
    except Exception:
        logger.debug("Could not extend annotations of %s", func_name, exc_info=True)

    # Update docstring to mention additional parameters
    original_doc = wrapper.__doc__ or ""
    additional_doc = """

    Additional OpenHCS Parameters
    -----------------------------
    slice_by_slice : bool, optional (default: False)
        If True, process 3D arrays slice-by-slice to avoid cross-slice contamination.
        If False, use original 3D behavior. Recommended for edge detection functions
        on stitched microscopy data to prevent artifacts at field boundaries.

    dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT)
        Controls output data type conversion:

        - PRESERVE_INPUT: Keep input dtype (uint16 → uint16)
        - NATIVE_OUTPUT: Use pyclesperanto's native output (often float32)
        - UINT8: Force 8-bit unsigned integer (0-255 range)
        - UINT16: Force 16-bit unsigned integer (microscopy standard)
        - INT16: Force 16-bit signed integer
        - INT32: Force 32-bit signed integer
        - FLOAT32: Force 32-bit float (GPU performance)
        - FLOAT64: Force 64-bit float (maximum precision)
    """
    wrapper.__doc__ = original_doc + additional_doc

    return wrapper
def _scale_and_convert_pyclesperanto(result, target_dtype):
    """
    Scale and convert pyclesperanto array to target dtype.
    This is a simplified version of the helper function from pyclesperanto_registry.py

    The array is pulled with ``cle.pull``, converted with NumPy and handed
    back through ``cle.push``.

    Conversion rules:
    - Floating-point input, non-floating target: values are clipped to
      [0, 1]; for uint8 / uint16 / uint32 they are then scaled to the full
      range of the target type, any other target is cast without scaling.
    - Every other combination: plain ``astype`` to the target dtype.

    The function is best effort: when pyclesperanto is unavailable, or when
    any step raises, ``result`` is returned unchanged (failures are logged).

    Args:
        result: Array to convert; must expose ``dtype`` and be accepted by
            ``cle.pull``.
        target_dtype: NumPy dtype (or dtype-like) to convert to.

    Returns:
        The converted array as returned by ``cle.push``, or ``result`` itself
        when no conversion could be performed.
    """
    try:
        cle = optional_import("pyclesperanto")
        if cle is None:
            return result

        import numpy as np

        source_is_float = np.issubdtype(result.dtype, np.floating)
        target_is_float = np.issubdtype(target_dtype, np.floating)

        # Convert on the NumPy side, then push back
        result_np = cle.pull(result)

        if source_is_float and not target_is_float:
            # NOTE(review): assumes float results are normalised to [0, 1] —
            # confirm against the functions routed through this helper.
            clipped = np.clip(result_np, 0, 1)
            if target_dtype == np.uint8:
                converted = (clipped * 255).astype(target_dtype)
            elif target_dtype == np.uint16:
                converted = (clipped * 65535).astype(target_dtype)
            elif target_dtype == np.uint32:
                # 4294967295 is not representable in float32 (it rounds up to
                # 2**32), so scale in float64 to keep 1.0 inside uint32 range.
                converted = (clipped.astype(np.float64) * 4294967295).astype(target_dtype)
            else:
                # For other integer types, just convert without scaling.
                # NOTE(review): after the [0, 1] clip this leaves only 0/1 for
                # signed integer targets — confirm this is intended.
                converted = clipped.astype(target_dtype)
        else:
            # Direct conversion for same numeric type families
            converted = result_np.astype(target_dtype)

        return cle.push(converted)

    except Exception as exc:
        # Best effort: keep the unconverted result, but do not hide the failure.
        logger.warning(
            "pyclesperanto dtype conversion to %s failed; returning unconverted result: %s",
            target_dtype,
            exc,
        )
        return result