Coverage for openhcs/core/memory/decorators.py: 28.9%

763 statements  

coverage.py v7.10.7, created at 2025-10-01 18:33 +0000

1""" 

2Memory type declaration decorators for OpenHCS. 

3 

4This module provides decorators for explicitly declaring the memory interface 

5of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting 

6memory-type-aware dispatching and orchestration. 

7 

8These decorators annotate functions with input_memory_type and output_memory_type 

9attributes and provide automatic thread-local CUDA stream management for GPU 

10frameworks to enable true parallelization across multiple threads. 

11""" 

12 

13import functools 

14import logging 

15import threading 

16from typing import Any, Callable, Optional, TypeVar 

17 

18from openhcs.constants.constants import VALID_MEMORY_TYPES 

19from openhcs.core.utils import optional_import 

20from openhcs.core.memory.oom_recovery import _execute_with_oom_recovery 

21 

22# Default processing contract (ProcessingContract.PURE_3D) is imported lazily inside the decorator (inlined single-use import per RST principle) 

23 

24logger = logging.getLogger(__name__) 

25 

26F = TypeVar('F', bound=Callable[..., Any]) 

27 

28# Dtype conversion enum and utilities for consistent dtype handling across all frameworks 

29from enum import Enum 

30import numpy as np 

31 

32class DtypeConversion(Enum): 

33 """Data type conversion modes for all memory type functions.""" 

34 

35 PRESERVE_INPUT = "preserve" # Keep input dtype (default) 

36 NATIVE_OUTPUT = "native" # Use framework's native output 

37 UINT8 = "uint8" # Force uint8 (0-255 range) 

38 UINT16 = "uint16" # Force uint16 (microscopy standard) 

39 INT16 = "int16" # Force int16 (signed microscopy data) 

40 INT32 = "int32" # Force int32 (large integer values) 

41 FLOAT32 = "float32" # Force float32 (GPU performance) 

42 FLOAT64 = "float64" # Force float64 (maximum precision) 

43 

44 @property 

45 def numpy_dtype(self): 

46 """Get the corresponding numpy dtype.""" 

47 dtype_map = { 

48 self.UINT8: np.uint8, 

49 self.UINT16: np.uint16, 

50 self.INT16: np.int16, 

51 self.INT32: np.int32, 

52 self.FLOAT32: np.float32, 

53 self.FLOAT64: np.float64, 

54 } 

55 return dtype_map.get(self, None) 

56 

57 
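A minimal usage sketch (illustrative, not part of the module source): the numpy_dtype property maps each forced-dtype mode to its NumPy type, while the two non-forcing modes map to None.

from openhcs.core.memory.decorators import DtypeConversion
import numpy as np

assert DtypeConversion.UINT16.numpy_dtype is np.uint16
assert DtypeConversion.FLOAT32.numpy_dtype is np.float32
# PRESERVE_INPUT and NATIVE_OUTPUT carry no fixed target dtype, so they map to None.
assert DtypeConversion.PRESERVE_INPUT.numpy_dtype is None
assert DtypeConversion.NATIVE_OUTPUT.numpy_dtype is None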

58def _scale_and_convert_numpy(result, target_dtype): 

59 """Scale numpy results to target integer range and convert dtype.""" 

60 if not hasattr(result, 'dtype'): 

61 return result 

62 

63 # Check if result is floating point and target is integer 

64 result_is_float = np.issubdtype(result.dtype, np.floating) 

65 target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32] 

66 

67 if result_is_float and target_is_int: 

68 # Scale floating point results to integer range 

69 result_min = result.min() 

70 result_max = result.max() 

71 

72 if result_max > result_min: # Avoid division by zero 

73 # Normalize to [0, 1] range 

74 normalized = (result - result_min) / (result_max - result_min) 

75 

76 # Scale to target dtype range 

77 if target_dtype == np.uint8: 

78 scaled = normalized * 255.0 

79 elif target_dtype == np.uint16: 

80 scaled = normalized * 65535.0 

81 elif target_dtype == np.uint32: 

82 scaled = normalized * 4294967295.0 

83 elif target_dtype == np.int16: 

84 scaled = normalized * 65535.0 - 32768.0 

85 elif target_dtype == np.int32: 

86 scaled = normalized * 4294967295.0 - 2147483648.0 

87 else: 

88 scaled = normalized 

89 

90 return scaled.astype(target_dtype) 

91 else: 

92 # Constant image, just convert dtype 

93 return result.astype(target_dtype) 

94 else: 

95 # Direct conversion for compatible types 

96 return result.astype(target_dtype) 

97 

98 
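A worked sketch of the scaling path (illustrative, not part of the module source): a float result spanning [0.2, 0.8] is min-max normalized and stretched across the full uint16 range before the dtype cast.

import numpy as np

result = np.array([0.2, 0.5, 0.8], dtype=np.float64)
converted = _scale_and_convert_numpy(result, np.uint16)
# 0.2 -> 0, 0.5 -> 32767 (after truncation), 0.8 -> 65535
print(converted.dtype)  # uint16
print(converted)        # [    0 32767 65535]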

99def _scale_and_convert_pyclesperanto(result, target_dtype): 

100 """Scale pyclesperanto results to target integer range and convert dtype.""" 

101 try: 

102 import pyclesperanto as cle 

103 except ImportError: 

104 return result 

105 

106 if not hasattr(result, 'dtype'): 

107 return result 

108 

109 # Check if result is floating point and target is integer 

110 result_is_float = np.issubdtype(result.dtype, np.floating) 

111 target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32] 

112 

113 if result_is_float and target_is_int: 

114 # Get min/max of result for proper scaling 

115 result_min = float(cle.minimum_of_all_pixels(result)) 

116 result_max = float(cle.maximum_of_all_pixels(result)) 

117 

118 if result_max > result_min: # Avoid division by zero 

119 # Normalize to [0, 1] range 

120 normalized = cle.add_image_and_scalar(result, scalar=-result_min) # computes result - result_min (subtract_image_from_scalar would give result_min - result) 

121 range_val = result_max - result_min 

122 normalized = cle.multiply_image_and_scalar(normalized, scalar=1.0/range_val) 

123 

124 # Scale to target dtype range 

125 if target_dtype == np.uint8: 

126 scaled = cle.multiply_image_and_scalar(normalized, scalar=255.0) 

127 elif target_dtype == np.uint16: 

128 scaled = cle.multiply_image_and_scalar(normalized, scalar=65535.0) 

129 elif target_dtype == np.uint32: 

130 scaled = cle.multiply_image_and_scalar(normalized, scalar=4294967295.0) 

131 elif target_dtype == np.int16: 

132 scaled = cle.multiply_image_and_scalar(normalized, scalar=65535.0) 

133 scaled = cle.add_image_and_scalar(scaled, scalar=-32768.0) # shift into int16 range (scaled - 32768) 

134 elif target_dtype == np.int32: 

135 scaled = cle.multiply_image_and_scalar(normalized, scalar=4294967295.0) 

136 scaled = cle.add_image_and_scalar(scaled, scalar=-2147483648.0) # shift into int32 range (scaled - 2147483648) 

137 else: 

138 scaled = normalized 

139 

140 # Convert to target dtype using push/pull method 

141 scaled_cpu = cle.pull(scaled).astype(target_dtype) 

142 return cle.push(scaled_cpu) 

143 else: 

144 # Constant image, just convert dtype 

145 result_cpu = cle.pull(result).astype(target_dtype) 

146 return cle.push(result_cpu) 

147 else: 

148 # Direct conversion for compatible types 

149 result_cpu = cle.pull(result).astype(target_dtype) 

150 return cle.push(result_cpu) 

151 

152 

153def _scale_and_convert_cupy(result, target_dtype): 

154 """Scale CuPy results to target integer range and convert dtype.""" 

155 try: 

156 import cupy as cp 

157 except ImportError: 

158 return result 

159 

160 if not hasattr(result, 'dtype'): 

161 return result 

162 

163 # If result is floating point and target is integer, scale appropriately 

164 if cp.issubdtype(result.dtype, cp.floating) and not cp.issubdtype(target_dtype, cp.floating): 

165 # Clip to [0, 1] range and scale to integer range 

166 clipped = cp.clip(result, 0, 1) 

167 if target_dtype == cp.uint8: 

168 return (clipped * 255).astype(target_dtype) 

169 elif target_dtype == cp.uint16: 

170 return (clipped * 65535).astype(target_dtype) 

171 elif target_dtype == cp.uint32: 

172 return (clipped * 4294967295).astype(target_dtype) 

173 else: 

174 # For other integer types, just convert without scaling 

175 return result.astype(target_dtype) 

176 

177 # Direct conversion for same numeric type families 

178 return result.astype(target_dtype) 

179 

180 

181# GPU frameworks imported lazily to prevent thread explosion 

182# These will be imported only when actually needed by functions 

183_gpu_frameworks_cache = {} 

184 

185def _get_cupy(): 

186 """Lazy import CuPy only when needed.""" 

187 if 'cupy' not in _gpu_frameworks_cache: 

188 _gpu_frameworks_cache['cupy'] = optional_import("cupy") 

189 if _gpu_frameworks_cache['cupy'] is not None: 

190 logger.debug(f"🔧 Lazy imported CuPy in thread {threading.current_thread().name}") 

191 return _gpu_frameworks_cache['cupy'] 

192 

193def _get_torch(): 

194 """Lazy import PyTorch only when needed.""" 

195 if 'torch' not in _gpu_frameworks_cache: 

196 _gpu_frameworks_cache['torch'] = optional_import("torch") 

197 if _gpu_frameworks_cache['torch'] is not None: 

198 logger.debug(f"🔧 Lazy imported PyTorch in thread {threading.current_thread().name}") 

199 return _gpu_frameworks_cache['torch'] 

200 

201def _get_tensorflow(): 

202 """Lazy import TensorFlow only when needed.""" 

203 if 'tensorflow' not in _gpu_frameworks_cache: 

204 _gpu_frameworks_cache['tensorflow'] = optional_import("tensorflow") 

205 if _gpu_frameworks_cache['tensorflow'] is not None: 

206 logger.debug(f"🔧 Lazy imported TensorFlow in thread {threading.current_thread().name}") 

207 return _gpu_frameworks_cache['tensorflow'] 

208 

209def _get_jax(): 

210 """Lazy import JAX only when needed.""" 

211 if 'jax' not in _gpu_frameworks_cache: 

212 _gpu_frameworks_cache['jax'] = optional_import("jax") 

213 if _gpu_frameworks_cache['jax'] is not None: 

214 logger.debug(f"🔧 Lazy imported JAX in thread {threading.current_thread().name}") 

215 return _gpu_frameworks_cache['jax'] 

216 
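A sketch of the lazy-import contract (illustrative): each helper calls optional_import at most once per process, caches the module object (or None), and never raises when a framework is absent.

cp = _get_cupy()          # first call populates _gpu_frameworks_cache['cupy']
assert cp is _get_cupy()  # later calls return the cached object (a module, or None)
if cp is None:
    print("CuPy not installed; CuPy-decorated functions run without a CUDA stream")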

217# Thread-local storage for GPU streams and contexts 

218_thread_gpu_contexts = threading.local() 

219 

220class ThreadGPUContext: 

221 """Unified thread-local GPU context manager to prevent stream leaks.""" 

222 

223 def __init__(self): 

224 self._cupy_stream = None 

225 self._torch_stream = None 

226 self._thread_name = threading.current_thread().name 

227 

228 def get_cupy_stream(self): 

229 """Get or create the single CuPy stream for this thread.""" 

230 if self._cupy_stream is None: 

231 cp = _get_cupy() 

232 if cp is not None and hasattr(cp, 'cuda'): 

233 self._cupy_stream = cp.cuda.Stream() 

234 logger.debug(f"🔧 Created CuPy stream for thread {self._thread_name}") 

235 return self._cupy_stream 

236 

237 def get_torch_stream(self): 

238 """Get or create the single PyTorch stream for this thread.""" 

239 if self._torch_stream is None: 

240 torch = _get_torch() 

241 if torch is not None and hasattr(torch, 'cuda') and torch.cuda.is_available(): 

242 self._torch_stream = torch.cuda.Stream() 

243 logger.debug(f"🔧 Created PyTorch stream for thread {self._thread_name}") 

244 return self._torch_stream 

245 

246 def cleanup(self): 

247 """Clean up streams when thread exits.""" 

248 if self._cupy_stream is not None: 

249 logger.debug(f"🔧 Cleaning up CuPy stream for thread {self._thread_name}") 

250 self._cupy_stream = None 

251 

252 if self._torch_stream is not None: 

253 logger.debug(f"🔧 Cleaning up PyTorch stream for thread {self._thread_name}") 

254 self._torch_stream = None 

255 

256def get_thread_gpu_context() -> ThreadGPUContext: 

257 """Get the unified GPU context for the current thread.""" 

258 if not hasattr(_thread_gpu_contexts, 'gpu_context'): 

259 _thread_gpu_contexts.gpu_context = ThreadGPUContext() 

260 

261 # Register cleanup for when thread exits 

262 import weakref 

263 def cleanup_on_thread_exit(): 

264 if hasattr(_thread_gpu_contexts, 'gpu_context'): 

265 _thread_gpu_contexts.gpu_context.cleanup() 

266 

267 # Register the callback on the thread object so it can be invoked when the thread exits 

268 current_thread = threading.current_thread() 

269 if hasattr(current_thread, '_cleanup_funcs'): 

270 current_thread._cleanup_funcs.append(cleanup_on_thread_exit) 

271 else: 

272 current_thread._cleanup_funcs = [cleanup_on_thread_exit] 

273 

274 return _thread_gpu_contexts.gpu_context 

275 

276 
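A sketch of per-thread stream reuse (illustrative; the stream objects are None on machines without CuPy/CUDA):

import threading

def worker():
    ctx = get_thread_gpu_context()
    assert ctx is get_thread_gpu_context()                  # one context per thread
    assert ctx.get_cupy_stream() is ctx.get_cupy_stream()   # stream created once, then reused

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()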

277def memory_types(*, input_type: str, output_type: str, contract: Optional['ProcessingContract'] = None) -> Callable[[F], F]: 

278 """ 

279 Decorator that explicitly declares the memory types for a function's input and output. 

280 

281 This decorator enforces Clause 106-A (Declared Memory Types) by requiring explicit 

282 memory type declarations for both input and output. 

283 

284 Args: 

285 input_type: The memory type for the function's input (e.g., "numpy", "cupy") 

286 output_type: The memory type for the function's output (e.g., "numpy", "cupy") 

287 contract: Optional processing contract declaration (defaults to PURE_3D) 

288 

289 Returns: 

290 A decorator function that sets the memory type attributes 

291 

292 Raises: 

293 ValueError: If input_type or output_type is not a supported memory type 

294 """ 

295 # 🔒 Clause 88 — No Inferred Capabilities 

296 # Validate memory types at decoration time, not runtime 

297 if not input_type: [coverage: 297 ↛ 298, condition never true] 

298 raise ValueError( 

299 "Clause 106-A Violation: input_type must be explicitly declared. " 

300 "No default or inferred memory types are allowed." 

301 ) 

302 

303 if not output_type: [coverage: 303 ↛ 304, condition never true] 

304 raise ValueError( 

305 "Clause 106-A Violation: output_type must be explicitly declared. " 

306 "No default or inferred memory types are allowed." 

307 ) 

308 

309 # Validate that memory types are supported 

310 if input_type not in VALID_MEMORY_TYPES: [coverage: 310 ↛ 311, condition never true] 

311 raise ValueError( 

312 f"Clause 106-A Violation: input_type '{input_type}' is not supported. " 

313 f"Supported types are: {', '.join(sorted(VALID_MEMORY_TYPES))}" 

314 ) 

315 

316 if output_type not in VALID_MEMORY_TYPES: [coverage: 316 ↛ 317, condition never true] 

317 raise ValueError( 

318 f"Clause 106-A Violation: output_type '{output_type}' is not supported. " 

319 f"Supported types are: {', '.join(sorted(VALID_MEMORY_TYPES))}" 

320 ) 

321 

322 def decorator(func: F) -> F: 

323 """ 

324 Decorator function that sets memory type attributes on the function. 

325 

326 Args: 

327 func: The function to decorate 

328 

329 Returns: 

330 The decorated function with memory type attributes set 

331 

332 Raises: 

333 ValueError: If the function already has different memory type attributes 

334 """ 

335 # 🔒 Clause 66 — Immutability 

336 # Check if memory type attributes already exist 

337 if hasattr(func, 'input_memory_type') and func.input_memory_type != input_type: [coverage: 337 ↛ 338, condition never true] 

338 raise ValueError( 

339 f"Clause 66 Violation: Function '{func.__name__}' already has input_memory_type " 

340 f"'{func.input_memory_type}', cannot change to '{input_type}'." 

341 ) 

342 

343 if hasattr(func, 'output_memory_type') and func.output_memory_type != output_type: [coverage: 343 ↛ 344, condition never true] 

344 raise ValueError( 

345 f"Clause 66 Violation: Function '{func.__name__}' already has output_memory_type " 

346 f"'{func.output_memory_type}', cannot change to '{output_type}'." 

347 ) 

348 

349 # Set memory type attributes using canonical names 

350 # 🔒 Clause 106-A.2 — Canonical Memory Type Attributes 

351 func.input_memory_type = input_type 

352 func.output_memory_type = output_type 

353 

354 # Set processing contract with fail-loud behavior (inlined per RST principle) 

355 if contract is None: [coverage: 355 ↛ 359, condition always true] 

356 from openhcs.processing.backends.lib_registry.unified_registry import ProcessingContract 

357 func.__processing_contract__ = ProcessingContract.PURE_3D 

358 else: 

359 func.__processing_contract__ = contract 

360 

361 # Return the function unchanged (no wrapper) 

362 return func 

363 

364 return decorator 

365 

366 
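A minimal usage sketch (illustrative, not part of the module source): the decorator attaches the declared types plus a default PURE_3D contract and returns the function unchanged; an unsupported type fails loudly at decoration time.

@memory_types(input_type="numpy", output_type="numpy")
def identity(stack):
    return stack

assert identity.input_memory_type == "numpy"
assert identity.output_memory_type == "numpy"

# memory_types(input_type="not_a_backend", output_type="numpy")
# -> ValueError citing Clause 106-A, raised while the decorator is being built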

367def numpy( 

368 func: Optional[F] = None, 

369 *, 

370 input_type: str = "numpy", 

371 output_type: str = "numpy", 

372 contract: Optional['ProcessingContract'] = None 

373) -> Any: 

374 """ 

375 Decorator that declares a function as operating on numpy arrays. 

376 

377 This is a convenience wrapper around memory_types with numpy defaults. 

378 

379 Args: 

380 func: The function to decorate (optional) 

381 input_type: The memory type for the function's input (default: "numpy") 

382 output_type: The memory type for the function's output (default: "numpy") 

383 contract: Optional processing contract declaration (defaults to PURE_3D) 

384 

385 Returns: 

386 The decorated function with memory type attributes set 

387 

388 Raises: 

389 ValueError: If input_type or output_type is not a supported memory type 

390 """ 

391 def decorator_with_dtype_preservation(func: F) -> F: 

392 # Set memory type attributes and contract 

393 memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract) 

394 func = memory_decorator(func) 

395 

396 # Apply dtype preservation wrapper 

397 func = _create_numpy_dtype_preserving_wrapper(func, func.__name__) 

398 

399 return func 

400 

401 # Handle both @numpy and @numpy(input_type=..., output_type=...) forms 

402 if func is None: [coverage: 402 ↛ 403, condition never true] 

403 return decorator_with_dtype_preservation 

404 

405 return decorator_with_dtype_preservation(func) 

406 

407 
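A usage sketch (illustrative, assuming scikit-image is installed): the wrapper injects the keyword-only dtype_conversion and slice_by_slice parameters, so a uint16 stack stays uint16 even though skimage returns float64.

import numpy as np
from skimage import filters

@numpy   # 'numpy' here is the OpenHCS decorator defined above, not the NumPy package
def blur(stack, sigma: float = 1.0):
    return filters.gaussian(stack, sigma=sigma)  # skimage returns float64

stack = np.random.randint(0, 65535, (3, 64, 64), dtype=np.uint16)
out = blur(stack)                                                # dtype preserved: uint16
out_f32 = blur(stack, dtype_conversion=DtypeConversion.FLOAT32)  # forced float32
out_2d = blur(stack, slice_by_slice=True)                        # each z-slice processed separately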

408def cupy(func: Optional[F] = None, *, input_type: str = "cupy", output_type: str = "cupy", oom_recovery: bool = True, contract: Optional['ProcessingContract'] = None) -> Any: 

409 """ 

410 Decorator that declares a function as operating on cupy arrays. 

411 

412 This decorator provides automatic thread-local CUDA stream management for 

413 true parallelization across multiple threads. Each thread gets its own 

414 persistent CUDA stream that is reused for all CuPy operations. 

415 

416 Args: 

417 func: The function to decorate (optional) 

418 input_type: The memory type for the function's input (default: "cupy") 

419 output_type: The memory type for the function's output (default: "cupy") 

420 oom_recovery: Enable automatic OOM recovery (default: True) 

421 contract: Optional processing contract declaration (defaults to PURE_3D) 

422 

423 Returns: 

424 The decorated function with memory type attributes and stream management 

425 

426 Raises: 

427 ValueError: If input_type or output_type is not a supported memory type 

428 """ 

429 def decorator(func: F) -> F: 

430 # Set memory type attributes and contract 

431 memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract) 

432 func = memory_decorator(func) 

433 

434 # Apply dtype preservation wrapper 

435 func = _create_cupy_dtype_preserving_wrapper(func, func.__name__) 

436 

437 # Add CUDA stream wrapper if CuPy is available (lazy import) 

438 @functools.wraps(func) 

439 def wrapper(*args, **kwargs): 

440 cp = _get_cupy() 

441 if cp is not None and hasattr(cp, 'cuda'): 

442 # Get unified thread context and CuPy stream 

443 gpu_context = get_thread_gpu_context() 

444 cupy_stream = gpu_context.get_cupy_stream() 

445 

446 def execute_with_stream(): 

447 if cupy_stream is not None: 

448 # Execute function in stream context 

449 with cupy_stream: 

450 return func(*args, **kwargs) 

451 else: 

452 # No CUDA available, execute without stream 

453 return func(*args, **kwargs) 

454 

455 # Execute with OOM recovery if enabled 

456 if oom_recovery: 

457 return _execute_with_oom_recovery(execute_with_stream, input_type) 

458 else: 

459 return execute_with_stream() 

460 else: 

461 # CuPy not available, execute without stream 

462 return func(*args, **kwargs) 

463 

464 # Preserve memory type attributes 

465 wrapper.input_memory_type = func.input_memory_type 

466 wrapper.output_memory_type = func.output_memory_type 

467 

468 return wrapper 

469 

470 # Handle both @cupy and @cupy(input_type=..., output_type=...) forms 

471 if func is None: [coverage: 471 ↛ 472, condition never true] 

472 return decorator 

473 

474 return decorator(func) 

475 

476 
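A GPU-side sketch (illustrative, assuming CuPy and a CUDA device are available): calls run inside the thread's persistent CuPy stream, and OOM recovery can be switched off per function.

import cupy as cp

@cupy(oom_recovery=False)   # 'cupy' here is the OpenHCS decorator defined above
def scale(stack, factor: float = 2.0):
    return stack * factor

gpu_stack = cp.random.random((3, 64, 64)).astype(cp.float32)
result = scale(gpu_stack)            # executed inside the thread-local CUDA stream
assert result.dtype == cp.float32    # dtype preserved by default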

477def torch( 

478 func: Optional[F] = None, 

479 *, 

480 input_type: str = "torch", 

481 output_type: str = "torch", 

482 oom_recovery: bool = True, 

483 contract: Optional['ProcessingContract'] = None 

484) -> Any: 

485 """ 

486 Decorator that declares a function as operating on torch tensors. 

487 

488 This decorator provides automatic thread-local CUDA stream management for 

489 true parallelization across multiple threads. Each thread gets its own 

490 persistent CUDA stream that is reused for all PyTorch operations. 

491 

492 Args: 

493 func: The function to decorate (optional) 

494 input_type: The memory type for the function's input (default: "torch") 

495 output_type: The memory type for the function's output (default: "torch") 

496 oom_recovery: Enable automatic OOM recovery (default: True) 

497 contract: Optional processing contract declaration (defaults to PURE_3D) 

498 

499 Returns: 

500 The decorated function with memory type attributes and stream management 

501 

502 Raises: 

503 ValueError: If input_type or output_type is not a supported memory type 

504 """ 

505 def decorator(func: F) -> F: 

506 # Set memory type attributes and contract 

507 memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract) 

508 func = memory_decorator(func) 

509 

510 # Apply dtype preservation wrapper 

511 func = _create_torch_dtype_preserving_wrapper(func, func.__name__) 

512 

513 # Add CUDA stream wrapper if PyTorch is available and CUDA is available (lazy import) 

514 @functools.wraps(func) 

515 def wrapper(*args, **kwargs): 

516 torch = _get_torch() 

517 if torch is not None and hasattr(torch, 'cuda') and torch.cuda.is_available(): 

518 # Get unified thread context and PyTorch stream 

519 gpu_context = get_thread_gpu_context() 

520 torch_stream = gpu_context.get_torch_stream() 

521 

522 def execute_with_stream(): 

523 if torch_stream is not None: 

524 # Execute function in stream context 

525 with torch.cuda.stream(torch_stream): 

526 return func(*args, **kwargs) 

527 else: 

528 # No CUDA available, execute without stream 

529 return func(*args, **kwargs) 

530 

531 # Execute with OOM recovery if enabled 

532 if oom_recovery: 

533 return _execute_with_oom_recovery(execute_with_stream, input_type) 

534 else: 

535 return execute_with_stream() 

536 else: 

537 # PyTorch not available or CUDA not available, execute without stream 

538 return func(*args, **kwargs) 

539 

540 # Preserve memory type attributes 

541 wrapper.input_memory_type = func.input_memory_type 

542 wrapper.output_memory_type = func.output_memory_type 

543 

544 return wrapper 

545 

546 # Handle both @torch and @torch(input_type=..., output_type=...) forms 

547 if func is None: [coverage: 547 ↛ 548, condition never true] 

548 return decorator 

549 

550 return decorator(func) 

551 

552 

553def tensorflow( 

554 func: Optional[F] = None, 

555 *, 

556 input_type: str = "tensorflow", 

557 output_type: str = "tensorflow", 

558 oom_recovery: bool = True, 

559 contract: Optional['ProcessingContract'] = None 

560) -> Any: 

561 """ 

562 Decorator that declares a function as operating on tensorflow tensors. 

563 

564 This decorator provides automatic thread-local GPU device context management 

565 for parallelization across multiple threads. TensorFlow manages CUDA streams 

566 internally, so we use device contexts for thread isolation. 

567 

568 Args: 

569 func: The function to decorate (optional) 

570 input_type: The memory type for the function's input (default: "tensorflow") 

571 output_type: The memory type for the function's output (default: "tensorflow") 

572 oom_recovery: Enable automatic OOM recovery (default: True) 

573 contract: Optional processing contract declaration (defaults to PURE_3D) 

574 

575 Returns: 

576 The decorated function with memory type attributes and device management 

577 

578 Raises: 

579 ValueError: If input_type or output_type is not a supported memory type 

580 """ 

581 def decorator(func: F) -> F: 

582 # Set memory type attributes and contract 

583 memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract) 

584 func = memory_decorator(func) 

585 

586 # Apply dtype preservation wrapper 

587 func = _create_tensorflow_dtype_preserving_wrapper(func, func.__name__) 

588 

589 # Add device context wrapper if TensorFlow is available and GPU is available (lazy import) 

590 @functools.wraps(func) 

591 def wrapper(*args, **kwargs): 

592 tf = _get_tensorflow() 

593 if tf is not None and tf.config.list_physical_devices('GPU'): 

594 def execute_with_device(): 

595 # Use GPU device context for thread isolation 

596 # TensorFlow manages internal CUDA streams automatically 

597 with tf.device('/GPU:0'): 

598 return func(*args, **kwargs) 

599 

600 # Execute with OOM recovery if enabled 

601 if oom_recovery: 

602 return _execute_with_oom_recovery(execute_with_device, input_type) 

603 else: 

604 return execute_with_device() 

605 else: 

606 # TensorFlow not available or GPU not available, execute without device context 

607 return func(*args, **kwargs) 

608 

609 # Preserve memory type attributes 

610 wrapper.input_memory_type = func.input_memory_type 

611 wrapper.output_memory_type = func.output_memory_type 

612 

613 return wrapper 

614 

615 # Handle both @tensorflow and @tensorflow(input_type=..., output_type=...) forms 

616 if func is None: [coverage: 616 ↛ 617, condition never true] 

617 return decorator 

618 

619 return decorator(func) 

620 

621 

622def jax( 

623 func: Optional[F] = None, 

624 *, 

625 input_type: str = "jax", 

626 output_type: str = "jax", 

627 oom_recovery: bool = True, 

628 contract: Optional['ProcessingContract'] = None 

629) -> Any: 

630 """ 

631 Decorator that declares a function as operating on JAX arrays. 

632 

633 This decorator provides automatic thread-local GPU device placement for 

634 parallelization across multiple threads. JAX/XLA manages CUDA streams 

635 internally, so we use device placement for thread isolation. 

636 

637 Args: 

638 func: The function to decorate (optional) 

639 input_type: The memory type for the function's input (default: "jax") 

640 output_type: The memory type for the function's output (default: "jax") 

641 oom_recovery: Enable automatic OOM recovery (default: True) 

642 contract: Optional processing contract declaration (defaults to PURE_3D) 

643 

644 Returns: 

645 The decorated function with memory type attributes and device management 

646 

647 Raises: 

648 ValueError: If input_type or output_type is not a supported memory type 

649 """ 

650 def decorator(func: F) -> F: 

651 # Set memory type attributes and contract 

652 memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract) 

653 func = memory_decorator(func) 

654 

655 # Apply dtype preservation wrapper 

656 func = _create_jax_dtype_preserving_wrapper(func, func.__name__) 

657 

658 # Add device placement wrapper if JAX is available and GPU is available (lazy import) 

659 @functools.wraps(func) 

660 def wrapper(*args, **kwargs): 

661 jax_module = _get_jax() 

662 if jax_module is not None: 

663 devices = jax_module.devices() 

664 gpu_devices = [d for d in devices if d.platform == 'gpu'] 

665 

666 if gpu_devices: 

667 def execute_with_device(): 

668 # Use GPU device placement for thread isolation 

669 # JAX/XLA manages internal CUDA streams automatically 

670 with jax_module.default_device(gpu_devices[0]): 

671 return func(*args, **kwargs) 

672 

673 # Execute with OOM recovery if enabled 

674 if oom_recovery: 

675 return _execute_with_oom_recovery(execute_with_device, input_type) 

676 else: 

677 return execute_with_device() 

678 else: 

679 # No GPU devices available, execute without device placement 

680 return func(*args, **kwargs) 

681 else: 

682 # JAX not available, execute without device placement 

683 return func(*args, **kwargs) 

684 

685 # Preserve memory type attributes 

686 wrapper.input_memory_type = func.input_memory_type 

687 wrapper.output_memory_type = func.output_memory_type 

688 

689 return wrapper 

690 

691 # Handle both @jax and @jax(input_type=..., output_type=...) forms 

692 if func is None: [coverage: 692 ↛ 693, condition never true] 

693 return decorator 

694 

695 return decorator(func) 

696 

697 

698def pyclesperanto( 

699 func: Optional[F] = None, 

700 *, 

701 input_type: str = "pyclesperanto", 

702 output_type: str = "pyclesperanto", 

703 oom_recovery: bool = True, 

704 contract: Optional['ProcessingContract'] = None 

705) -> Any: 

706 """ 

707 Decorator that declares a function as operating on pyclesperanto GPU arrays. 

708 

709 This decorator provides automatic OOM recovery for pyclesperanto functions. 

710 

711 Args: 

712 func: The function to decorate (optional) 

713 input_type: The memory type for the function's input (default: "pyclesperanto") 

714 output_type: The memory type for the function's output (default: "pyclesperanto") 

715 oom_recovery: Enable automatic OOM recovery (default: True) 

716 contract: Optional processing contract declaration (defaults to PURE_3D) 

717 

718 Returns: 

719 The decorated function with memory type attributes and OOM recovery 

720 

721 Raises: 

722 ValueError: If input_type or output_type is not a supported memory type 

723 """ 

724 def decorator(func: F) -> F: 

725 # Set memory type attributes and contract 

726 memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract) 

727 func = memory_decorator(func) 

728 

729 # Apply dtype preservation wrapper 

730 func = _create_pyclesperanto_dtype_preserving_wrapper(func, func.__name__) 

731 

732 # Add OOM recovery wrapper 

733 @functools.wraps(func) 

734 def wrapper(*args, **kwargs): 

735 if oom_recovery: 

736 return _execute_with_oom_recovery(lambda: func(*args, **kwargs), input_type) 

737 else: 

738 return func(*args, **kwargs) 

739 

740 # Preserve memory type attributes 

741 wrapper.input_memory_type = func.input_memory_type 

742 wrapper.output_memory_type = func.output_memory_type 

743 

744 # Make wrapper pickleable by preserving original function identity 

745 wrapper.__module__ = getattr(func, '__module__', wrapper.__module__) 

746 wrapper.__qualname__ = getattr(func, '__qualname__', wrapper.__qualname__) 

747 

748 # Store reference to original function for pickle support 

749 wrapper.__wrapped__ = func 

750 

751 return wrapper 

752 

753 # Handle both @pyclesperanto and @pyclesperanto(input_type=..., output_type=...) forms 

754 if func is None: [coverage: 754 ↛ 755, condition never true] 

755 return decorator 

756 

757 return decorator(func) 

758 

759 
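A sketch of the attributes the pyclesperanto wrapper exposes (illustrative; threshold_otsu is used only as an example kernel and requires a working OpenCL device):

@pyclesperanto
def binarize(stack):
    import pyclesperanto as cle
    return cle.threshold_otsu(stack)

assert binarize.input_memory_type == "pyclesperanto"
assert binarize.output_memory_type == "pyclesperanto"
assert binarize.__wrapped__ is not None   # original callable retained for pickling/introspection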

760# ============================================================================ 

761# Dtype Preservation Wrapper Functions 

762# ============================================================================ 

763 

764def _create_numpy_dtype_preserving_wrapper(original_func, func_name): 

765 """ 

766 Create a wrapper that preserves input data type and adds slice_by_slice parameter for NumPy functions. 

767 

768 Many scikit-image functions return float64 regardless of input type. 

769 This wrapper ensures the output has the same dtype as the input and adds 

770 a slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

771 """ 

772 import numpy as np 

773 import inspect 

774 from functools import wraps 

775 

776 @wraps(original_func) 

777 def numpy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

778 # Set default dtype_conversion if not provided and DtypeConversion is available 

779 if dtype_conversion is None and DtypeConversion is not None: 

780 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

781 

782 try: 

783 # Store original dtype 

784 original_dtype = image.dtype 

785 

786 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities 

787 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

788 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

789 

790 # Detect memory type and use proper OpenHCS utilities 

791 memory_type = _detect_memory_type(image) 

792 gpu_id = 0 # Default GPU ID for slice processing 

793 

794 # Unstack 3D array into 2D slices 

795 slices_2d = unstack_slices(image, memory_type, gpu_id) 

796 

797 # Process each slice and handle special outputs 

798 main_outputs = [] 

799 special_outputs_list = [] 

800 

801 for slice_2d in slices_2d: 

802 slice_result = original_func(slice_2d, *args, **kwargs) 

803 

804 # Check if result is a tuple (indicating special outputs) 

805 if isinstance(slice_result, tuple): 

806 main_outputs.append(slice_result[0]) # First element is main output 

807 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

808 else: 

809 main_outputs.append(slice_result) # Single output 

810 

811 # Stack main outputs back into 3D array 

812 result = stack_slices(main_outputs, memory_type, gpu_id) 

813 

814 # If we have special outputs, combine them and return tuple 

815 if special_outputs_list: 

816 # Combine special outputs from all slices 

817 combined_special_outputs = [] 

818 num_special_outputs = len(special_outputs_list[0]) 

819 

820 for i in range(num_special_outputs): 

821 # Collect the i-th special output from all slices 

822 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

823 combined_special_outputs.append(special_output_values) 

824 

825 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

826 result = (result, *combined_special_outputs) 

827 else: 

828 # Call the original function normally 

829 result = original_func(image, *args, **kwargs) 

830 

831 # Apply dtype conversion based on enum value 

832 if hasattr(result, 'dtype') and dtype_conversion is not None: 

833 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

834 # Preserve input dtype 

835 if result.dtype != original_dtype: 

836 result = _scale_and_convert_numpy(result, original_dtype) 

837 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

838 # Return NumPy's native output dtype 

839 pass # No conversion needed 

840 else: 

841 # Force specific dtype 

842 target_dtype = dtype_conversion.numpy_dtype 

843 if target_dtype is not None: 

844 result = _scale_and_convert_numpy(result, target_dtype) 

845 

846 return result 

847 except Exception as e: 

848 logger.error(f"Error in NumPy dtype/slice preserving wrapper for {func_name}: {e}") 

849 # Return original result on error 

850 return original_func(image, *args, **kwargs) 

851 

852 # Update function signature to include new parameters 

853 try: 

854 original_sig = inspect.signature(original_func) 

855 new_params = list(original_sig.parameters.values()) 

856 

857 # Check if slice_by_slice parameter already exists 

858 param_names = [p.name for p in new_params] 

859 # Add dtype_conversion parameter first (before slice_by_slice) 

860 if 'dtype_conversion' not in param_names: [coverage: 860 ↛ 869, condition always true] 

861 dtype_param = inspect.Parameter( 

862 'dtype_conversion', 

863 inspect.Parameter.KEYWORD_ONLY, 

864 default=DtypeConversion.PRESERVE_INPUT, 

865 annotation=DtypeConversion 

866 ) 

867 new_params.append(dtype_param) 

868 

869 if 'slice_by_slice' not in param_names: [coverage: 869 ↛ 880, condition always true] 

870 # Add slice_by_slice parameter as keyword-only (after dtype_conversion) 

871 slice_param = inspect.Parameter( 

872 'slice_by_slice', 

873 inspect.Parameter.KEYWORD_ONLY, 

874 default=False, 

875 annotation=bool 

876 ) 

877 new_params.append(slice_param) 

878 

879 # Create new signature and override the @wraps signature 

880 new_sig = original_sig.replace(parameters=new_params) 

881 numpy_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

882 

883 # Set type annotations manually for get_type_hints() compatibility 

884 numpy_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

885 numpy_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

886 numpy_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

887 

888 except Exception: 

889 # If signature modification fails, continue without it 

890 pass 

891 

892 # Update docstring to mention slice_by_slice parameter 

893 original_doc = numpy_dtype_and_slice_preserving_wrapper.__doc__ or "" 

894 additional_doc = """ 

895 

896 Additional OpenHCS Parameters 

897 ----------------------------- 

898 slice_by_slice : bool, optional (default: False) 

899 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

900 If False, use original 3D behavior. Recommended for edge detection functions 

901 on stitched microscopy data to prevent artifacts at field boundaries. 

902 

903 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

904 Controls output data type conversion: 

905 

906 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

907 - NATIVE_OUTPUT: Use NumPy's native output dtype 

908 - UINT8: Force 8-bit unsigned integer (0-255 range) 

909 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

910 - INT16: Force 16-bit signed integer 

911 - INT32: Force 32-bit signed integer 

912 - FLOAT32: Force 32-bit float (GPU performance) 

913 - FLOAT64: Force 64-bit float (maximum precision) 

914 """ 

915 numpy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

916 

917 return numpy_dtype_and_slice_preserving_wrapper 

918 

919 
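A sketch of the signature injection (illustrative): after wrapping, inspect.signature reports the two extra keyword-only parameters, which is what get_type_hints-based tooling and UI form generation rely on.

import inspect

def sobel(image):
    ...

wrapped = _create_numpy_dtype_preserving_wrapper(sobel, "sobel")
params = inspect.signature(wrapped).parameters
assert "dtype_conversion" in params and "slice_by_slice" in params
assert params["dtype_conversion"].default is DtypeConversion.PRESERVE_INPUT
assert params["slice_by_slice"].default is False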

920def _create_cupy_dtype_preserving_wrapper(original_func, func_name): 

921 """ 

922 Create a wrapper that preserves input data type and adds slice_by_slice parameter for CuPy functions. 

923 

924 This uses the same pattern as the NumPy (scikit-image) wrapper for consistency. CuPy functions generally preserve 

925 dtypes better than scikit-image, but this wrapper ensures consistent behavior and adds 

926 slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

927 """ 

928 import inspect 

929 from functools import wraps 

930 

931 @wraps(original_func) 

932 def cupy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

933 # Set default dtype_conversion if not provided and DtypeConversion is available 

934 if dtype_conversion is None and DtypeConversion is not None: 

935 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

936 

937 try: 

938 cupy = optional_import("cupy") 

939 if cupy is None: 

940 return original_func(image, *args, **kwargs) 

941 

942 # Store original dtype 

943 original_dtype = image.dtype 

944 

945 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities 

946 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

947 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

948 

949 # Detect memory type and use proper OpenHCS utilities 

950 memory_type = _detect_memory_type(image) 

951 gpu_id = image.device.id if hasattr(image, 'device') else 0 

952 

953 # Unstack 3D array into 2D slices 

954 slices_2d = unstack_slices(image, memory_type, gpu_id) 

955 

956 # Process each slice and handle special outputs 

957 main_outputs = [] 

958 special_outputs_list = [] 

959 

960 for slice_2d in slices_2d: 

961 slice_result = original_func(slice_2d, *args, **kwargs) 

962 

963 # Check if result is a tuple (indicating special outputs) 

964 if isinstance(slice_result, tuple): 

965 main_outputs.append(slice_result[0]) # First element is main output 

966 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

967 else: 

968 main_outputs.append(slice_result) # Single output 

969 

970 # Stack main outputs back into 3D array 

971 result = stack_slices(main_outputs, memory_type, gpu_id) 

972 

973 # If we have special outputs, combine them and return tuple 

974 if special_outputs_list: 

975 # Combine special outputs from all slices 

976 combined_special_outputs = [] 

977 num_special_outputs = len(special_outputs_list[0]) 

978 

979 for i in range(num_special_outputs): 

980 # Collect the i-th special output from all slices 

981 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

982 combined_special_outputs.append(special_output_values) 

983 

984 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

985 result = (result, *combined_special_outputs) 

986 else: 

987 # Call the original function normally 

988 result = original_func(image, *args, **kwargs) 

989 

990 # Apply dtype conversion based on enum value 

991 if hasattr(result, 'dtype') and dtype_conversion is not None: 

992 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

993 # Preserve input dtype 

994 if result.dtype != original_dtype: 

995 result = _scale_and_convert_cupy(result, original_dtype) 

996 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

997 # Return CuPy's native output dtype 

998 pass # No conversion needed 

999 else: 

1000 # Force specific dtype 

1001 target_dtype = dtype_conversion.numpy_dtype 

1002 if target_dtype is not None: 

1003 result = _scale_and_convert_cupy(result, target_dtype) 

1004 

1005 return result 

1006 except Exception as e: 

1007 logger.error(f"Error in CuPy dtype/slice preserving wrapper for {func_name}: {e}") 

1008 # Return original result on error 

1009 return original_func(image, *args, **kwargs) 

1010 

1011 # Update function signature to include new parameters 

1012 try: 

1013 original_sig = inspect.signature(original_func) 

1014 new_params = list(original_sig.parameters.values()) 

1015 

1016 # Check if slice_by_slice parameter already exists 

1017 param_names = [p.name for p in new_params] 

1018 # Add dtype_conversion parameter first (before slice_by_slice) 

1019 if 'dtype_conversion' not in param_names: [coverage: 1019 ↛ 1028, condition always true] 

1020 dtype_param = inspect.Parameter( 

1021 'dtype_conversion', 

1022 inspect.Parameter.KEYWORD_ONLY, 

1023 default=DtypeConversion.PRESERVE_INPUT, 

1024 annotation=DtypeConversion 

1025 ) 

1026 new_params.append(dtype_param) 

1027 

1028 if 'slice_by_slice' not in param_names: [coverage: 1028 ↛ 1039, condition always true] 

1029 # Add slice_by_slice parameter as keyword-only (after dtype_conversion) 

1030 slice_param = inspect.Parameter( 

1031 'slice_by_slice', 

1032 inspect.Parameter.KEYWORD_ONLY, 

1033 default=False, 

1034 annotation=bool 

1035 ) 

1036 new_params.append(slice_param) 

1037 

1038 # Create new signature and override the @wraps signature 

1039 new_sig = original_sig.replace(parameters=new_params) 

1040 cupy_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1041 

1042 # Set type annotations manually for get_type_hints() compatibility 

1043 cupy_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1044 cupy_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1045 cupy_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1046 

1047 except Exception: 

1048 # If signature modification fails, continue without it 

1049 pass 

1050 

1051 # Update docstring to mention slice_by_slice parameter 

1052 original_doc = cupy_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1053 additional_doc = """ 

1054 

1055 Additional OpenHCS Parameters 

1056 ----------------------------- 

1057 slice_by_slice : bool, optional (default: False) 

1058 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1059 If False, use original 3D behavior. Recommended for edge detection functions 

1060 on stitched microscopy data to prevent artifacts at field boundaries. 

1061 

1062 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1063 Controls output data type conversion: 

1064 

1065 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1066 - NATIVE_OUTPUT: Use CuPy's native output dtype 

1067 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1068 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1069 - INT16: Force 16-bit signed integer 

1070 - INT32: Force 32-bit signed integer 

1071 - FLOAT32: Force 32-bit float (GPU performance) 

1072 - FLOAT64: Force 64-bit float (maximum precision) 

1073 """ 

1074 cupy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1075 

1076 return cupy_dtype_and_slice_preserving_wrapper 

1077 

1078 

1079def _create_torch_dtype_preserving_wrapper(original_func, func_name): 

1080 """ 

1081 Create a wrapper that preserves input data type and adds slice_by_slice parameter for PyTorch functions. 

1082 

1083 This follows the same pattern as existing dtype preservation wrappers for consistency. 

1084 PyTorch functions generally preserve dtypes well, but this wrapper ensures consistent behavior 

1085 and adds slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

1086 """ 

1087 import inspect 

1088 from functools import wraps 

1089 

1090 @wraps(original_func) 

1091 def torch_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1092 # Set default dtype_conversion if not provided 

1093 if dtype_conversion is None: 

1094 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1095 

1096 try: 

1097 torch = optional_import("torch") 

1098 if torch is None: 

1099 return original_func(image, *args, **kwargs) 

1100 

1101 # Store original dtype 

1102 original_dtype = image.dtype if hasattr(image, 'dtype') else None 

1103 

1104 # Handle slice_by_slice processing for 3D arrays 

1105 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

1106 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1107 

1108 # Detect memory type and use proper OpenHCS utilities 

1109 memory_type = _detect_memory_type(image) 

1110 gpu_id = image.device.index if hasattr(image, 'device') and image.device.type == 'cuda' else 0 

1111 

1112 # Unstack 3D array into 2D slices 

1113 slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id) 

1114 

1115 # Process each slice and handle special outputs 

1116 main_outputs = [] 

1117 special_outputs_list = [] 

1118 

1119 for slice_2d in slices_2d: 

1120 slice_result = original_func(slice_2d, *args, **kwargs) 

1121 

1122 # Check if result is a tuple (indicating special outputs) 

1123 if isinstance(slice_result, tuple): 

1124 main_outputs.append(slice_result[0]) # First element is main output 

1125 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

1126 else: 

1127 main_outputs.append(slice_result) # Single output 

1128 

1129 # Stack main outputs back into 3D array 

1130 result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id) 

1131 

1132 # If we have special outputs, combine them and return tuple 

1133 if special_outputs_list: 

1134 # Combine special outputs from all slices 

1135 combined_special_outputs = [] 

1136 num_special_outputs = len(special_outputs_list[0]) 

1137 

1138 for i in range(num_special_outputs): 

1139 # Collect the i-th special output from all slices 

1140 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1141 combined_special_outputs.append(special_output_values) 

1142 

1143 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1144 result = (result, *combined_special_outputs) 

1145 else: 

1146 # Process normally 

1147 result = original_func(image, *args, **kwargs) 

1148 

1149 # Apply dtype conversion if result is a tensor and we have dtype conversion info 

1150 if (hasattr(result, 'dtype') and hasattr(result, 'shape') and 

1151 original_dtype is not None and dtype_conversion is not None): 

1152 

1153 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1154 # Preserve input dtype 

1155 if result.dtype != original_dtype: 

1156 result = result.to(original_dtype) 

1157 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1158 # Return PyTorch's native output dtype 

1159 pass # No conversion needed 

1160 else: 

1161 # Force specific dtype 

1162 target_dtype = dtype_conversion.numpy_dtype 

1163 if target_dtype is not None: 

1164 # Map numpy dtypes to torch dtypes 

1165 import numpy as np 

1166 numpy_to_torch = { 

1167 np.uint8: torch.uint8, 

1168 np.uint16: torch.int32, # PyTorch doesn't have uint16, use int32 

1169 np.int16: torch.int16, 

1170 np.int32: torch.int32, 

1171 np.float32: torch.float32, 

1172 np.float64: torch.float64, 

1173 } 

1174 torch_dtype = numpy_to_torch.get(target_dtype) 

1175 if torch_dtype is not None: 

1176 result = result.to(torch_dtype) 

1177 

1178 return result 

1179 

1180 except Exception as e: 

1181 logger.error(f"Error in PyTorch dtype/slice preserving wrapper for {func_name}: {e}") 

1182 # Return original result on error 

1183 return original_func(image, *args, **kwargs) 

1184 

1185 # Update function signature to include new parameters 

1186 try: 

1187 original_sig = inspect.signature(original_func) 

1188 new_params = list(original_sig.parameters.values()) 

1189 

1190 # Add dtype_conversion parameter first (before slice_by_slice) 

1191 param_names = [p.name for p in new_params] 

1192 if 'dtype_conversion' not in param_names: [coverage: 1192 ↛ 1202, condition always true] 

1193 dtype_param = inspect.Parameter( 

1194 'dtype_conversion', 

1195 inspect.Parameter.KEYWORD_ONLY, 

1196 default=DtypeConversion.PRESERVE_INPUT, 

1197 annotation=DtypeConversion 

1198 ) 

1199 new_params.append(dtype_param) 

1200 

1201 # Add slice_by_slice parameter after dtype_conversion 

1202 if 'slice_by_slice' not in param_names: 

1203 slice_param = inspect.Parameter( 

1204 'slice_by_slice', 

1205 inspect.Parameter.KEYWORD_ONLY, 

1206 default=False, 

1207 annotation=bool 

1208 ) 

1209 new_params.append(slice_param) 

1210 

1211 new_sig = original_sig.replace(parameters=new_params) 

1212 torch_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1213 

1214 # Set type annotations manually for get_type_hints() compatibility 

1215 torch_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1216 if DtypeConversion is not None: [coverage: 1216 ↛ 1218, condition always true] 

1217 torch_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1218 torch_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1219 

1220 except Exception: 

1221 # If signature modification fails, continue without it 

1222 pass 

1223 

1224 # Update docstring to mention new parameters 

1225 original_doc = torch_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1226 additional_doc = """ 

1227 

1228 Additional OpenHCS Parameters 

1229 ----------------------------- 

1230 slice_by_slice : bool, optional (default: False) 

1231 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1232 If False, use original 3D behavior. Recommended for edge detection functions 

1233 on stitched microscopy data to prevent artifacts at field boundaries. 

1234 

1235 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1236 Controls output data type conversion: 

1237 

1238 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1239 - NATIVE_OUTPUT: Use PyTorch's native output dtype 

1240 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1241 - UINT16: Force 16-bit unsigned integer (mapped to int32 in PyTorch) 

1242 - INT16: Force 16-bit signed integer 

1243 - INT32: Force 32-bit signed integer 

1244 - FLOAT32: Force 32-bit float (GPU performance) 

1245 - FLOAT64: Force 64-bit float (maximum precision) 

1246 """ 

1247 torch_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1248 

1249 return torch_dtype_and_slice_preserving_wrapper 

1250 

1251 
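A sketch of the PyTorch dtype mapping (illustrative, assuming torch is installed; CPU is enough): forcing UINT16 yields an int32 tensor because PyTorch has no unsigned 16-bit dtype.

import torch as _torch   # aliased to avoid shadowing the 'torch' decorator above

def negate(t):
    return -t.float()

wrapped = _create_torch_dtype_preserving_wrapper(negate, "negate")
t = _torch.ones((2, 4, 4), dtype=_torch.int16)
out = wrapped(t, dtype_conversion=DtypeConversion.UINT16)
print(out.dtype)   # torch.int32: np.uint16 is mapped to torch.int32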

1252def _create_tensorflow_dtype_preserving_wrapper(original_func, func_name): 

1253 """ 

1254 Create a wrapper that preserves input data type and adds slice_by_slice parameter for TensorFlow functions. 

1255 

1256 This follows the same pattern as existing dtype preservation wrappers for consistency. 

1257 TensorFlow functions generally preserve dtypes well, but this wrapper ensures consistent behavior 

1258 and adds slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

1259 """ 

1260 import inspect 

1261 from functools import wraps 

1262 

1263 @wraps(original_func) 

1264 def tensorflow_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1265 # Set default dtype_conversion if not provided 

1266 if dtype_conversion is None: 

1267 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1268 

1269 try: 

1270 tf = optional_import("tensorflow") 

1271 if tf is None: 

1272 return original_func(image, *args, **kwargs) 

1273 

1274 # Store original dtype 

1275 original_dtype = image.dtype if hasattr(image, 'dtype') else None 

1276 

1277 # Handle slice_by_slice processing for 3D arrays 

1278 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

1279 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1280 

1281 # Detect memory type and use proper OpenHCS utilities 

1282 memory_type = _detect_memory_type(image) 

1283 gpu_id = 0 # TensorFlow manages GPU placement internally 

1284 

1285 # Unstack 3D array into 2D slices 

1286 slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id) 

1287 

1288 # Process each slice and handle special outputs 

1289 main_outputs = [] 

1290 special_outputs_list = [] 

1291 

1292 for slice_2d in slices_2d: 

1293 slice_result = original_func(slice_2d, *args, **kwargs) 

1294 

1295 # Check if result is a tuple (indicating special outputs) 

1296 if isinstance(slice_result, tuple): 

1297 main_outputs.append(slice_result[0]) # First element is main output 

1298 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

1299 else: 

1300 main_outputs.append(slice_result) # Single output 

1301 

1302 # Stack main outputs back into 3D array 

1303 result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id) 

1304 

1305 # If we have special outputs, combine them and return tuple 

1306 if special_outputs_list: 

1307 # Combine special outputs from all slices 

1308 combined_special_outputs = [] 

1309 num_special_outputs = len(special_outputs_list[0]) 

1310 

1311 for i in range(num_special_outputs): 

1312 # Collect the i-th special output from all slices 

1313 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1314 combined_special_outputs.append(special_output_values) 

1315 

1316 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1317 result = (result, *combined_special_outputs) 

1318 else: 

1319 # Process normally 

1320 result = original_func(image, *args, **kwargs) 

1321 

1322 # Apply dtype conversion if result is a tensor and we have dtype conversion info 

1323 if (hasattr(result, 'dtype') and hasattr(result, 'shape') and 

1324 original_dtype is not None and dtype_conversion is not None): 

1325 

1326 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1327 # Preserve input dtype 

1328 if result.dtype != original_dtype: 

1329 result = tf.cast(result, original_dtype) 

1330 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1331 # Return TensorFlow's native output dtype 

1332 pass # No conversion needed 

1333 else: 

1334 # Force specific dtype 

1335 target_dtype = dtype_conversion.numpy_dtype 

1336 if target_dtype is not None: 

1337 # Convert numpy dtype to tensorflow dtype 

1338 import numpy as np 

1339 numpy_to_tf = { 

1340 np.uint8: tf.uint8, 

1341 np.uint16: tf.uint16, 

1342 np.int16: tf.int16, 

1343 np.int32: tf.int32, 

1344 np.float32: tf.float32, 

1345 np.float64: tf.float64, 

1346 } 

1347 tf_dtype = numpy_to_tf.get(target_dtype) 

1348 if tf_dtype is not None: 

1349 result = tf.cast(result, tf_dtype) 

1350 

1351 return result 

1352 

1353 except Exception as e: 

1354 logger.error(f"Error in TensorFlow dtype/slice preserving wrapper for {func_name}: {e}") 

1355                # Fall back to calling the unwrapped original function on error 

1356 return original_func(image, *args, **kwargs) 

1357 

1358 # Update function signature to include new parameters 

1359 try: 

1360 original_sig = inspect.signature(original_func) 

1361 new_params = list(original_sig.parameters.values()) 

1362 

1363 # Add slice_by_slice parameter if not already present 

1364 param_names = [p.name for p in new_params] 

1365        if 'slice_by_slice' not in param_names: 1365 ↛ 1375 (line 1365 didn't jump to line 1375 because the condition on line 1365 was always true)

1366 slice_param = inspect.Parameter( 

1367 'slice_by_slice', 

1368 inspect.Parameter.KEYWORD_ONLY, 

1369 default=False, 

1370 annotation=bool 

1371 ) 

1372 new_params.append(slice_param) 

1373 

1374 # Add dtype_conversion parameter if DtypeConversion is available 

1375        if DtypeConversion is not None and 'dtype_conversion' not in param_names: 1375 ↛ 1384 (line 1375 didn't jump to line 1384 because the condition on line 1375 was always true)

1376 dtype_param = inspect.Parameter( 

1377 'dtype_conversion', 

1378 inspect.Parameter.KEYWORD_ONLY, 

1379 default=DtypeConversion.PRESERVE_INPUT, 

1380 annotation=DtypeConversion 

1381 ) 

1382 new_params.append(dtype_param) 

1383 

1384 new_sig = original_sig.replace(parameters=new_params) 

1385 tensorflow_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1386 

1387 # Set type annotations manually for get_type_hints() compatibility 

1388 tensorflow_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1389 tensorflow_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1390        if DtypeConversion is not None: 1390 ↛ 1398 (line 1390 didn't jump to line 1398 because the condition on line 1390 was always true)

1391 tensorflow_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1392 

1393 except Exception: 

1394 # If signature modification fails, continue without it 

1395 pass 

1396 

1397 # Update docstring to mention new parameters 

1398 original_doc = tensorflow_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1399 additional_doc = """ 

1400 

1401 Additional OpenHCS Parameters 

1402 ----------------------------- 

1403 slice_by_slice : bool, optional (default: False) 

1404 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1405 If False, use original 3D behavior. Recommended for edge detection functions 

1406 on stitched microscopy data to prevent artifacts at field boundaries. 

1407 

1408 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1409 Controls output data type conversion: 

1410 

1411 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1412 - NATIVE_OUTPUT: Use TensorFlow's native output dtype 

1413 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1414 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1415 - INT16: Force 16-bit signed integer 

1416 - INT32: Force 32-bit signed integer 

1417 - FLOAT32: Force 32-bit float (GPU performance) 

1418 - FLOAT64: Force 64-bit float (maximum precision) 

1419 """ 

1420 tensorflow_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1421 

1422 return tensorflow_dtype_and_slice_preserving_wrapper 

1423 

1424 
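The slice_by_slice branch above (and the matching branches in the JAX and pyclesperanto wrappers below) follows one pattern: unstack the 3D array along Z, apply the function to each 2D slice, restack the main outputs, and gather any extra per-slice returns. A minimal numpy-only sketch of that pattern, using np.stack in place of the OpenHCS unstack_slices/stack_slices utilities and a hypothetical per-slice function, purely for illustration:

import numpy as np

def apply_slice_by_slice(func, stack_3d, *args, **kwargs):
    # Mirror of the wrappers' behavior: run func on each Z slice, restack the
    # main outputs, and collect any extra (special) outputs per slice.
    main_outputs, special_outputs = [], []
    for slice_2d in stack_3d:                       # iterate over Z
        out = func(slice_2d, *args, **kwargs)
        if isinstance(out, tuple):                  # (main, special1, special2, ...)
            main_outputs.append(out[0])
            special_outputs.append(out[1:])
        else:
            main_outputs.append(out)
    result = np.stack(main_outputs, axis=0)         # back to (Z, Y, X)
    if special_outputs:
        combined = tuple([s[i] for s in special_outputs]
                         for i in range(len(special_outputs[0])))
        return (result, *combined)
    return result

# Hypothetical per-slice function returning (image, mean_intensity):
stacked, means = apply_slice_by_slice(lambda s: (s * 2, float(s.mean())),
                                      np.ones((3, 4, 4), dtype=np.uint16))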

1425def _create_jax_dtype_preserving_wrapper(original_func, func_name): 

1426 """ 

1427 Create a wrapper that preserves input data type and adds slice_by_slice parameter for JAX functions. 

1428 

1429 This follows the same pattern as existing dtype preservation wrappers for consistency. 

1430 JAX functions generally preserve dtypes well, but this wrapper ensures consistent behavior 

1431 and adds slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

1432 """ 

1433 import inspect 

1434 from functools import wraps 

1435 

1436 @wraps(original_func) 

1437 def jax_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1438 # Set default dtype_conversion if not provided 

1439 if dtype_conversion is None: 

1440 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1441 

1442 try: 

1443 jax = optional_import("jax") 

1444 jnp = optional_import("jax.numpy") if jax is not None else None 

1445 if jax is None or jnp is None: 

1446 return original_func(image, *args, **kwargs) 

1447 

1448 # Store original dtype 

1449 original_dtype = image.dtype if hasattr(image, 'dtype') else None 

1450 

1451 # Handle slice_by_slice processing for 3D arrays 

1452 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

1453 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1454 

1455 # Detect memory type and use proper OpenHCS utilities 

1456 memory_type = _detect_memory_type(image) 

1457 gpu_id = 0 # JAX manages GPU placement internally 

1458 

1459 # Unstack 3D array into 2D slices 

1460 slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id) 

1461 

1462 # Process each slice and handle special outputs 

1463 main_outputs = [] 

1464 special_outputs_list = [] 

1465 

1466 for slice_2d in slices_2d: 

1467 slice_result = original_func(slice_2d, *args, **kwargs) 

1468 

1469 # Check if result is a tuple (indicating special outputs) 

1470 if isinstance(slice_result, tuple): 

1471 main_outputs.append(slice_result[0]) # First element is main output 

1472 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

1473 else: 

1474 main_outputs.append(slice_result) # Single output 

1475 

1476 # Stack main outputs back into 3D array 

1477 result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id) 

1478 

1479 # If we have special outputs, combine them and return tuple 

1480 if special_outputs_list: 

1481 # Combine special outputs from all slices 

1482 combined_special_outputs = [] 

1483 num_special_outputs = len(special_outputs_list[0]) 

1484 

1485 for i in range(num_special_outputs): 

1486 # Collect the i-th special output from all slices 

1487 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1488 combined_special_outputs.append(special_output_values) 

1489 

1490 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1491 result = (result, *combined_special_outputs) 

1492 else: 

1493 # Process normally 

1494 result = original_func(image, *args, **kwargs) 

1495 

1496 # Apply dtype conversion if result is an array and we have dtype conversion info 

1497 if (hasattr(result, 'dtype') and hasattr(result, 'shape') and 

1498 original_dtype is not None and dtype_conversion is not None): 

1499 

1500 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1501 # Preserve input dtype 

1502 if result.dtype != original_dtype: 

1503 result = result.astype(original_dtype) 

1504 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1505 # Return JAX's native output dtype 

1506 pass # No conversion needed 

1507 else: 

1508 # Force specific dtype 

1509 target_dtype = dtype_conversion.numpy_dtype 

1510 if target_dtype is not None: 

1511 # JAX uses numpy-compatible dtypes 

1512 result = result.astype(target_dtype) 

1513 

1514 return result 

1515 

1516 except Exception as e: 

1517 logger.error(f"Error in JAX dtype/slice preserving wrapper for {func_name}: {e}") 

1518                # Fall back to calling the unwrapped original function on error 

1519 return original_func(image, *args, **kwargs) 

1520 

1521 # Update function signature to include new parameters 

1522 try: 

1523 original_sig = inspect.signature(original_func) 

1524 new_params = list(original_sig.parameters.values()) 

1525 

1526 # Add slice_by_slice parameter if not already present 

1527 param_names = [p.name for p in new_params] 

1528 if 'slice_by_slice' not in param_names: 

1529 slice_param = inspect.Parameter( 

1530 'slice_by_slice', 

1531 inspect.Parameter.KEYWORD_ONLY, 

1532 default=False, 

1533 annotation=bool 

1534 ) 

1535 new_params.append(slice_param) 

1536 

1537 # Add dtype_conversion parameter if DtypeConversion is available 

1538        if DtypeConversion is not None and 'dtype_conversion' not in param_names: 1538 ↛ 1547 (line 1538 didn't jump to line 1547 because the condition on line 1538 was always true)

1539 dtype_param = inspect.Parameter( 

1540 'dtype_conversion', 

1541 inspect.Parameter.KEYWORD_ONLY, 

1542 default=DtypeConversion.PRESERVE_INPUT, 

1543 annotation=DtypeConversion 

1544 ) 

1545 new_params.append(dtype_param) 

1546 

1547 new_sig = original_sig.replace(parameters=new_params) 

1548 jax_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1549 

1550 # Set type annotations manually for get_type_hints() compatibility 

1551 jax_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1552 jax_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1553        if DtypeConversion is not None: 1553 ↛ 1561 (line 1553 didn't jump to line 1561 because the condition on line 1553 was always true)

1554 jax_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1555 

1556 except Exception: 

1557 # If signature modification fails, continue without it 

1558 pass 

1559 

1560 # Update docstring to mention new parameters 

1561 original_doc = jax_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1562 additional_doc = """ 

1563 

1564 Additional OpenHCS Parameters 

1565 ----------------------------- 

1566 slice_by_slice : bool, optional (default: False) 

1567 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1568 If False, use original 3D behavior. Recommended for edge detection functions 

1569 on stitched microscopy data to prevent artifacts at field boundaries. 

1570 

1571 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1572 Controls output data type conversion: 

1573 

1574 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1575 - NATIVE_OUTPUT: Use JAX's native output dtype 

1576 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1577 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1578 - INT16: Force 16-bit signed integer 

1579 - INT32: Force 32-bit signed integer 

1580 - FLOAT32: Force 32-bit float (GPU performance) 

1581 - FLOAT64: Force 64-bit float (maximum precision) 

1582 """ 

1583 jax_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1584 

1585 return jax_dtype_and_slice_preserving_wrapper 

1586 

1587 
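The try/except block that patches __signature__ and __annotations__ is how each factory makes slice_by_slice and dtype_conversion visible to inspect.signature() and get_type_hints(), even though the wrapper only declares them as keyword-only arguments. A standalone sketch of that technique, using a hypothetical blur function and helper that are not part of OpenHCS:

import inspect

def blur(image, sigma: float = 1.0):
    """Hypothetical processing function, used only for illustration."""
    return image

def add_keyword_only_param(func, name, default, annotation):
    # Append one keyword-only parameter to the function's visible signature.
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    if name not in sig.parameters:
        params.append(inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY,
                                        default=default, annotation=annotation))
    func.__signature__ = sig.replace(parameters=params)
    func.__annotations__[name] = annotation   # keeps get_type_hints() consistent
    return func

blur = add_keyword_only_param(blur, 'slice_by_slice', False, bool)
print(inspect.signature(blur))
# (image, sigma: float = 1.0, *, slice_by_slice: bool = False)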

1588def _create_pyclesperanto_dtype_preserving_wrapper(original_func, func_name): 

1589 """ 

1590 Create a wrapper that ensures array-in/array-out compliance and dtype preservation for pyclesperanto functions. 

1591 

1592 All OpenHCS functions must: 

1593 1. Take 3D pyclesperanto array as first argument 

1594 2. Return 3D pyclesperanto array as first output 

1595 3. Additional outputs (values, coordinates) as 2nd, 3rd, etc. returns 

1596 4. Preserve input dtype when appropriate 

1597 """ 

1598 import inspect 

1599 from functools import wraps 

1600 

1601 @wraps(original_func) 

1602 def pyclesperanto_dtype_and_slice_preserving_wrapper(image_3d, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1603 # Set default dtype_conversion if not provided 

1604 if dtype_conversion is None: 

1605 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1606 

1607 try: 

1608 cle = optional_import("pyclesperanto") 

1609 if cle is None: 

1610 return original_func(image_3d, *args, **kwargs) 

1611 

1612 # Store original dtype for preservation 

1613 original_dtype = image_3d.dtype 

1614 

1615 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities 

1616 if slice_by_slice and hasattr(image_3d, 'ndim') and image_3d.ndim == 3: 

1617 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1618 

1619 # Detect memory type and use proper OpenHCS utilities 

1620 memory_type = _detect_memory_type(image_3d) 

1621 gpu_id = 0 # pyclesperanto manages GPU internally 

1622 

1623 # Process each slice and handle special outputs 

1624 slices = unstack_slices(image_3d, memory_type, gpu_id) 

1625 main_outputs = [] 

1626 special_outputs_list = [] 

1627 

1628 for slice_2d in slices: 

1629 # Apply function to 2D slice 

1630 result_slice = original_func(slice_2d, *args, **kwargs) 

1631 

1632 # Check if result is a tuple (indicating special outputs) 

1633 if isinstance(result_slice, tuple): 

1634 main_outputs.append(result_slice[0]) # First element is main output 

1635 special_outputs_list.append(result_slice[1:]) # Rest are special outputs 

1636 else: 

1637 main_outputs.append(result_slice) # Single output 

1638 

1639 # Stack main outputs back into 3D array 

1640 result = stack_slices(main_outputs, memory_type, gpu_id) 

1641 

1642 # If we have special outputs, combine them and return tuple 

1643 if special_outputs_list: 

1644 # Combine special outputs from all slices 

1645 combined_special_outputs = [] 

1646 num_special_outputs = len(special_outputs_list[0]) 

1647 

1648 for i in range(num_special_outputs): 

1649 # Collect the i-th special output from all slices 

1650 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1651 combined_special_outputs.append(special_output_values) 

1652 

1653 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1654 result = (result, *combined_special_outputs) 

1655 else: 

1656 # Normal 3D processing 

1657 result = original_func(image_3d, *args, **kwargs) 

1658 

1659 # Check if result is 2D and needs expansion to 3D 

1660 if hasattr(result, 'ndim') and result.ndim == 2: 

1661 # Expand 2D result to 3D single slice 

1662 try: 

1663 # Concatenate with itself to create 3D, then take first slice 

1664 temp_3d = cle.concatenate_along_z(result, result) # Creates (2, Y, X) 

1665 result = temp_3d[0:1, :, :] # Take first slice to get (1, Y, X) 

1666 except Exception: 

1667 # If expansion fails, return original 2D result 

1668 # This maintains backward compatibility 

1669 pass 

1670 

1671 # Apply dtype conversion based on enum value 

1672 if hasattr(result, 'dtype') and hasattr(result, 'shape') and dtype_conversion is not None: 

1673 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1674 # Preserve input dtype 

1675 if result.dtype != original_dtype: 

1676 return _scale_and_convert_pyclesperanto(result, original_dtype) 

1677 return result 

1678 

1679 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1680 # Return pyclesperanto's native output dtype 

1681 return result 

1682 

1683 else: 

1684 # Force specific dtype 

1685 target_dtype = dtype_conversion.numpy_dtype 

1686 if target_dtype is not None and result.dtype != target_dtype: 

1687 return _scale_and_convert_pyclesperanto(result, target_dtype) 

1688 return result 

1689 else: 

1690 # Non-array result, return as-is 

1691 return result 

1692 

1693 except Exception as e: 

1694 logger.error(f"Error in pyclesperanto dtype/slice preserving wrapper for {func_name}: {e}") 

1695 # If anything goes wrong, fall back to original function 

1696 return original_func(image_3d, *args, **kwargs) 

1697 

1698 # Update function signature to include new parameters 

1699 try: 

1700 original_sig = inspect.signature(original_func) 

1701 new_params = list(original_sig.parameters.values()) 

1702 

1703 # Add slice_by_slice parameter if not already present 

1704 param_names = [p.name for p in new_params] 

1705        if 'slice_by_slice' not in param_names: 1705 ↛ 1715 (line 1705 didn't jump to line 1715 because the condition on line 1705 was always true)

1706 slice_param = inspect.Parameter( 

1707 'slice_by_slice', 

1708 inspect.Parameter.KEYWORD_ONLY, 

1709 default=False, 

1710 annotation=bool 

1711 ) 

1712 new_params.append(slice_param) 

1713 

1714 # Add dtype_conversion parameter if DtypeConversion is available 

1715        if DtypeConversion is not None and 'dtype_conversion' not in param_names: 1715 ↛ 1724 (line 1715 didn't jump to line 1724 because the condition on line 1715 was always true)

1716 dtype_param = inspect.Parameter( 

1717 'dtype_conversion', 

1718 inspect.Parameter.KEYWORD_ONLY, 

1719 default=DtypeConversion.PRESERVE_INPUT, 

1720 annotation=DtypeConversion 

1721 ) 

1722 new_params.append(dtype_param) 

1723 

1724 new_sig = original_sig.replace(parameters=new_params) 

1725 pyclesperanto_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1726 

1727 # Set type annotations manually for get_type_hints() compatibility 

1728 pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1729 pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1730        if DtypeConversion is not None: 1730 ↛ 1738 (line 1730 didn't jump to line 1738 because the condition on line 1730 was always true)

1731 pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1732 

1733 except Exception: 

1734 # If signature modification fails, continue without it 

1735 pass 

1736 

1737 # Update docstring to mention additional parameters 

1738 original_doc = pyclesperanto_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1739 additional_doc = """ 

1740 

1741 Additional OpenHCS Parameters 

1742 ----------------------------- 

1743 slice_by_slice : bool, optional (default: False) 

1744 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1745 If False, use original 3D behavior. Recommended for edge detection functions 

1746 on stitched microscopy data to prevent artifacts at field boundaries. 

1747 

1748 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1749 Controls output data type conversion: 

1750 

1751 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1752 - NATIVE_OUTPUT: Use pyclesperanto's native output (often float32) 

1753 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1754 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1755 - INT16: Force 16-bit signed integer 

1756 - INT32: Force 32-bit signed integer 

1757 - FLOAT32: Force 32-bit float (GPU performance) 

1758 - FLOAT64: Force 64-bit float (maximum precision) 

1759 """ 

1760 pyclesperanto_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1761 

1762 return pyclesperanto_dtype_and_slice_preserving_wrapper 

1763 

1764 
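A hedged usage sketch of the factory above, wrapping a hypothetical passthrough function from inside this module; the commented-out call shows the two keyword-only parameters the wrapper injects (when pyclesperanto cannot be imported, the wrapper simply delegates to the original function):

def passthrough(image_3d):
    """Hypothetical pyclesperanto-style function, used only for illustration."""
    return image_3d

wrapped = _create_pyclesperanto_dtype_preserving_wrapper(passthrough, "passthrough")

# The wrapper now accepts the injected keyword-only parameters, e.g.:
# wrapped(stack, slice_by_slice=True, dtype_conversion=DtypeConversion.UINT16)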

1765def _scale_and_convert_pyclesperanto(result, target_dtype): 

1766 """ 

1767 Scale and convert pyclesperanto array to target dtype. 

1768 This is a simplified version of the helper function from pyclesperanto_registry.py 

1769 """ 

1770 try: 

1771 cle = optional_import("pyclesperanto") 

1772 if cle is None: 

1773 return result 

1774 

1775 import numpy as np 

1776 

1777 # If result is floating point and target is integer, scale appropriately 

1778 if np.issubdtype(result.dtype, np.floating) and not np.issubdtype(target_dtype, np.floating): 

1779 # Convert to numpy for scaling, then back to pyclesperanto 

1780 result_np = cle.pull(result) 

1781 

1782 # Clip to [0, 1] range and scale to integer range 

1783 clipped = np.clip(result_np, 0, 1) 

1784 if target_dtype == np.uint8: 

1785 scaled = (clipped * 255).astype(target_dtype) 

1786 elif target_dtype == np.uint16: 

1787 scaled = (clipped * 65535).astype(target_dtype) 

1788 elif target_dtype == np.uint32: 

1789 scaled = (clipped * 4294967295).astype(target_dtype) 

1790 else: 

1791 # For other integer types, just convert without scaling 

1792 scaled = clipped.astype(target_dtype) 

1793 

1794 # Push back to GPU 

1795 return cle.push(scaled) 

1796 else: 

1797            # Otherwise convert directly on the host without rescaling 

1798 result_np = cle.pull(result) 

1799 converted = result_np.astype(target_dtype) 

1800 return cle.push(converted) 

1801 

1802 except Exception: 

1803 # If conversion fails, return original result 

1804 return result
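The float-to-integer path above clips to [0, 1] and stretches to the target range on the host before pushing back to the GPU. A numpy-only sketch of that host-side step, slightly generalized by using np.iinfo instead of the hard-coded 255/65535/4294967295 constants and by scaling any integer target rather than only the unsigned ones:

import numpy as np

def scale_float_to_int(result_np, target_dtype):
    # Clip a float image to [0, 1], then stretch to the full integer range;
    # other dtype combinations convert directly without rescaling.
    if np.issubdtype(result_np.dtype, np.floating) and not np.issubdtype(target_dtype, np.floating):
        clipped = np.clip(result_np, 0, 1)
        max_val = np.iinfo(target_dtype).max        # 255, 65535, 4294967295, ...
        return (clipped * max_val).astype(target_dtype)
    return result_np.astype(target_dtype)

edges = np.random.rand(1, 4, 4).astype(np.float32)  # e.g. a normalized filter response
as_uint16 = scale_float_to_int(edges, np.uint16)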