Coverage for openhcs/core/memory/decorators.py: 30.2%

759 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Memory type declaration decorators for OpenHCS. 

3 

4This module provides decorators for explicitly declaring the memory interface 

5of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting 

6memory-type-aware dispatching and orchestration. 

7 

8These decorators annotate functions with input_memory_type and output_memory_type 

9attributes and provide automatic thread-local CUDA stream management for GPU 

10frameworks to enable true parallelization across multiple threads. 

11""" 

12 

13import functools 

14import logging 

15import threading 

16from typing import Any, Callable, Optional, TypeVar 

17 

18from openhcs.constants.constants import VALID_MEMORY_TYPES 

19from openhcs.core.utils import optional_import 

20from openhcs.core.memory.oom_recovery import _execute_with_oom_recovery 

21 

22logger = logging.getLogger(__name__) 

23 

24F = TypeVar('F', bound=Callable[..., Any]) 

25 

26# Dtype conversion enum and utilities for consistent dtype handling across all frameworks 

27from enum import Enum 

28import numpy as np 

29 

30class DtypeConversion(Enum): 

31 """Data type conversion modes for all memory type functions.""" 

32 

33 PRESERVE_INPUT = "preserve" # Keep input dtype (default) 

34 NATIVE_OUTPUT = "native" # Use framework's native output 

35 UINT8 = "uint8" # Force uint8 (0-255 range) 

36 UINT16 = "uint16" # Force uint16 (microscopy standard) 

37 INT16 = "int16" # Force int16 (signed microscopy data) 

38 INT32 = "int32" # Force int32 (large integer values) 

39 FLOAT32 = "float32" # Force float32 (GPU performance) 

40 FLOAT64 = "float64" # Force float64 (maximum precision) 

41 

42 @property 

43 def numpy_dtype(self): 

44 """Get the corresponding numpy dtype.""" 

45 dtype_map = { 

46 self.UINT8: np.uint8, 

47 self.UINT16: np.uint16, 

48 self.INT16: np.int16, 

49 self.INT32: np.int32, 

50 self.FLOAT32: np.float32, 

51 self.FLOAT64: np.float64, 

52 } 

53 return dtype_map.get(self, None) 

54 
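# Illustrative sketch, not part of this module: mapping enum members to concrete
# numpy dtypes. PRESERVE_INPUT and NATIVE_OUTPUT have no fixed dtype, so
# numpy_dtype returns None for them and callers branch on that.
assert DtypeConversion.UINT16.numpy_dtype is np.uint16
assert DtypeConversion.PRESERVE_INPUT.numpy_dtype is None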

55 

56def _scale_and_convert_numpy(result, target_dtype): 

57 """Scale numpy results to target integer range and convert dtype.""" 

58 if not hasattr(result, 'dtype'): 

59 return result 

60 

61 # Check if result is floating point and target is integer 

62 result_is_float = np.issubdtype(result.dtype, np.floating) 

63 target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32] 

64 

65 if result_is_float and target_is_int: 

66 # Scale floating point results to integer range 

67 result_min = result.min() 

68 result_max = result.max() 

69 

70 if result_max > result_min: # Avoid division by zero 

71 # Normalize to [0, 1] range 

72 normalized = (result - result_min) / (result_max - result_min) 

73 

74 # Scale to target dtype range 

75 if target_dtype == np.uint8: 

76 scaled = normalized * 255.0 

77 elif target_dtype == np.uint16: 

78 scaled = normalized * 65535.0 

79 elif target_dtype == np.uint32: 

80 scaled = normalized * 4294967295.0 

81 elif target_dtype == np.int16: 

82 scaled = normalized * 65535.0 - 32768.0 

83 elif target_dtype == np.int32: 

84 scaled = normalized * 4294967295.0 - 2147483648.0 

85 else: 

86 scaled = normalized 

87 

88 return scaled.astype(target_dtype) 

89 else: 

90 # Constant image, just convert dtype 

91 return result.astype(target_dtype) 

92 else: 

93 # Direct conversion for compatible types 

94 return result.astype(target_dtype) 

95 
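# Illustrative sketch, not part of this module: what _scale_and_convert_numpy does
# when a floating-point result must become an integer dtype. The result is rescaled
# over its own min/max before the cast, rather than being clipped.
demo = np.array([[0.2, 0.5], [0.8, 1.4]], dtype=np.float64)
converted = _scale_and_convert_numpy(demo, np.uint16)
assert converted.dtype == np.uint16
assert converted.min() == 0 and converted.max() == 65535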

96 

97def _scale_and_convert_pyclesperanto(result, target_dtype): 

98 """Scale pyclesperanto results to target integer range and convert dtype.""" 

99 try: 

100 import pyclesperanto as cle 

101 except ImportError: 

102 return result 

103 

104 if not hasattr(result, 'dtype'): 

105 return result 

106 

107 # Check if result is floating point and target is integer 

108 result_is_float = np.issubdtype(result.dtype, np.floating) 

109 target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32] 

110 

111 if result_is_float and target_is_int: 

112 # Get min/max of result for proper scaling 

113 result_min = float(cle.minimum_of_all_pixels(result)) 

114 result_max = float(cle.maximum_of_all_pixels(result)) 

115 

116 if result_max > result_min: # Avoid division by zero 

117 # Normalize to [0, 1] range 

118 normalized = cle.subtract_image_from_scalar(result, scalar=result_min) 

119 range_val = result_max - result_min 

120 normalized = cle.multiply_image_and_scalar(normalized, scalar=1.0/range_val) 

121 

122 # Scale to target dtype range 

123 if target_dtype == np.uint8: 

124 scaled = cle.multiply_image_and_scalar(normalized, scalar=255.0) 

125 elif target_dtype == np.uint16: 

126 scaled = cle.multiply_image_and_scalar(normalized, scalar=65535.0) 

127 elif target_dtype == np.uint32: 

128 scaled = cle.multiply_image_and_scalar(normalized, scalar=4294967295.0) 

129 elif target_dtype == np.int16: 

130 scaled = cle.multiply_image_and_scalar(normalized, scalar=65535.0) 

131 scaled = cle.subtract_image_from_scalar(scaled, scalar=32768.0) 

132 elif target_dtype == np.int32: 

133 scaled = cle.multiply_image_and_scalar(normalized, scalar=4294967295.0) 

134 scaled = cle.subtract_image_from_scalar(scaled, scalar=2147483648.0) 

135 else: 

136 scaled = normalized 

137 

138 # Convert to target dtype using push/pull method 

139 scaled_cpu = cle.pull(scaled).astype(target_dtype) 

140 return cle.push(scaled_cpu) 

141 else: 

142 # Constant image, just convert dtype 

143 result_cpu = cle.pull(result).astype(target_dtype) 

144 return cle.push(result_cpu) 

145 else: 

146 # Direct conversion for compatible types 

147 result_cpu = cle.pull(result).astype(target_dtype) 

148 return cle.push(result_cpu) 

149 

150 

151def _scale_and_convert_cupy(result, target_dtype): 

152 """Scale CuPy results to target integer range and convert dtype.""" 

153 try: 

154 import cupy as cp 

155 except ImportError: 

156 return result 

157 

158 if not hasattr(result, 'dtype'): 

159 return result 

160 

161 # If result is floating point and target is integer, scale appropriately 

162 if cp.issubdtype(result.dtype, cp.floating) and not cp.issubdtype(target_dtype, cp.floating): 

163 # Clip to [0, 1] range and scale to integer range 

164 clipped = cp.clip(result, 0, 1) 

165 if target_dtype == cp.uint8: 

166 return (clipped * 255).astype(target_dtype) 

167 elif target_dtype == cp.uint16: 

168 return (clipped * 65535).astype(target_dtype) 

169 elif target_dtype == cp.uint32: 

170 return (clipped * 4294967295).astype(target_dtype) 

171 else: 

172 # For other integer types, just convert without scaling 

173 return result.astype(target_dtype) 

174 

175 # Direct conversion for same numeric type families 

176 return result.astype(target_dtype) 

177 
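# Illustrative sketch, not part of this module: note that the CuPy helper clips float
# results to [0, 1] before scaling, unlike the NumPy helper above, which rescales over
# the result's own min/max. Values outside [0, 1] therefore saturate.
_cp = optional_import("cupy")
if _cp is not None:
    demo = _cp.array([-0.5, 0.25, 2.0], dtype=_cp.float32)
    converted = _scale_and_convert_cupy(demo, _cp.uint8)
    # _cp.asnumpy(converted) -> array([  0,  63, 255], dtype=uint8)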

178 

179# GPU frameworks imported lazily to prevent thread explosion 

180# These will be imported only when actually needed by functions 

181_gpu_frameworks_cache = {} 

182 

183def _get_cupy(): 

184 """Lazy import CuPy only when needed.""" 

185 if 'cupy' not in _gpu_frameworks_cache: 

186 _gpu_frameworks_cache['cupy'] = optional_import("cupy") 

187 if _gpu_frameworks_cache['cupy'] is not None: 

188 logger.debug(f"🔧 Lazy imported CuPy in thread {threading.current_thread().name}") 

189 return _gpu_frameworks_cache['cupy'] 

190 

191def _get_torch(): 

192 """Lazy import PyTorch only when needed.""" 

193 if 'torch' not in _gpu_frameworks_cache: 

194 _gpu_frameworks_cache['torch'] = optional_import("torch") 

195 if _gpu_frameworks_cache['torch'] is not None: 

196 logger.debug(f"🔧 Lazy imported PyTorch in thread {threading.current_thread().name}") 

197 return _gpu_frameworks_cache['torch'] 

198 

199def _get_tensorflow(): 

200 """Lazy import TensorFlow only when needed.""" 

201 if 'tensorflow' not in _gpu_frameworks_cache: 

202 _gpu_frameworks_cache['tensorflow'] = optional_import("tensorflow") 

203 if _gpu_frameworks_cache['tensorflow'] is not None: 

204 logger.debug(f"🔧 Lazy imported TensorFlow in thread {threading.current_thread().name}") 

205 return _gpu_frameworks_cache['tensorflow'] 

206 

207def _get_jax(): 

208 """Lazy import JAX only when needed.""" 

209 if 'jax' not in _gpu_frameworks_cache: 

210 _gpu_frameworks_cache['jax'] = optional_import("jax") 

211 if _gpu_frameworks_cache['jax'] is not None: 

212 logger.debug(f"🔧 Lazy imported JAX in thread {threading.current_thread().name}") 

213 return _gpu_frameworks_cache['jax'] 

214 
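# Illustrative sketch, not part of this module: the lazy-import helpers cache the
# optional_import() result, including None when a framework is absent, so each
# framework is probed at most once per process and never at module import time.
cp = _get_cupy()           # first call performs optional_import("cupy") and caches it
assert _get_cupy() is cp   # subsequent calls return the cached object (or None)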

215# Thread-local storage for GPU streams and contexts 

216_thread_gpu_contexts = threading.local() 

217 

218class ThreadGPUContext: 

219 """Unified thread-local GPU context manager to prevent stream leaks.""" 

220 

221 def __init__(self): 

222 self._cupy_stream = None 

223 self._torch_stream = None 

224 self._thread_name = threading.current_thread().name 

225 

226 def get_cupy_stream(self): 

227 """Get or create the single CuPy stream for this thread.""" 

228 if self._cupy_stream is None: 

229 cp = _get_cupy() 

230 if cp is not None and hasattr(cp, 'cuda'): 

231 self._cupy_stream = cp.cuda.Stream() 

232 logger.debug(f"🔧 Created CuPy stream for thread {self._thread_name}") 

233 return self._cupy_stream 

234 

235 def get_torch_stream(self): 

236 """Get or create the single PyTorch stream for this thread.""" 

237 if self._torch_stream is None: 

238 torch = _get_torch() 

239 if torch is not None and hasattr(torch, 'cuda') and torch.cuda.is_available(): 

240 self._torch_stream = torch.cuda.Stream() 

241 logger.debug(f"🔧 Created PyTorch stream for thread {self._thread_name}") 

242 return self._torch_stream 

243 

244 def cleanup(self): 

245 """Clean up streams when thread exits.""" 

246 if self._cupy_stream is not None: 

247 logger.debug(f"🔧 Cleaning up CuPy stream for thread {self._thread_name}") 

248 self._cupy_stream = None 

249 

250 if self._torch_stream is not None: 

251 logger.debug(f"🔧 Cleaning up PyTorch stream for thread {self._thread_name}") 

252 self._torch_stream = None 

253 

254def get_thread_gpu_context() -> ThreadGPUContext: 

255 """Get the unified GPU context for the current thread.""" 

256 if not hasattr(_thread_gpu_contexts, 'gpu_context'): 

257 _thread_gpu_contexts.gpu_context = ThreadGPUContext() 

258 

259 # Register cleanup for when thread exits 

260 import weakref 

261 def cleanup_on_thread_exit(): 

262 if hasattr(_thread_gpu_contexts, 'gpu_context'): 

263 _thread_gpu_contexts.gpu_context.cleanup() 

264 

265 # Use weakref to avoid circular references 

266 current_thread = threading.current_thread() 

267 if hasattr(current_thread, '_cleanup_funcs'): 

268 current_thread._cleanup_funcs.append(cleanup_on_thread_exit) 

269 else: 

270 current_thread._cleanup_funcs = [cleanup_on_thread_exit] 

271 

272 return _thread_gpu_contexts.gpu_context 

273 
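# Illustrative sketch, not part of this module: every worker thread lazily builds one
# ThreadGPUContext, so it holds at most one CuPy stream and one PyTorch stream that
# are reused for all decorated calls made from that thread.
import threading

def _worker():
    ctx = get_thread_gpu_context()
    first = ctx.get_cupy_stream()
    second = ctx.get_cupy_stream()
    assert first is second      # the same stream (or None) is reused in this thread

threading.Thread(target=_worker).start()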

274 

275def memory_types(*, input_type: str, output_type: str) -> Callable[[F], F]: 

276 """ 

277 Decorator that explicitly declares the memory types for a function's input and output. 

278 

279 This decorator enforces Clause 106-A (Declared Memory Types) by requiring explicit 

280 memory type declarations for both input and output. 

281 

282 Args: 

283 input_type: The memory type for the function's input (e.g., "numpy", "cupy") 

284 output_type: The memory type for the function's output (e.g., "numpy", "cupy") 

285 

286 Returns: 

287 A decorator function that sets the memory type attributes 

288 

289 Raises: 

290 ValueError: If input_type or output_type is not a supported memory type 

291 """ 

292 # 🔒 Clause 88 — No Inferred Capabilities 

293 # Validate memory types at decoration time, not runtime 

294 if not input_type:    [294 ↛ 295: condition was never true]

295 raise ValueError( 

296 "Clause 106-A Violation: input_type must be explicitly declared. " 

297 "No default or inferred memory types are allowed." 

298 ) 

299 

300 if not output_type:    [300 ↛ 301: condition was never true]

301 raise ValueError( 

302 "Clause 106-A Violation: output_type must be explicitly declared. " 

303 "No default or inferred memory types are allowed." 

304 ) 

305 

306 # Validate that memory types are supported 

307 if input_type not in VALID_MEMORY_TYPES:    [307 ↛ 308: condition was never true]

308 raise ValueError( 

309 f"Clause 106-A Violation: input_type '{input_type}' is not supported. " 

310 f"Supported types are: {', '.join(sorted(VALID_MEMORY_TYPES))}" 

311 ) 

312 

313 if output_type not in VALID_MEMORY_TYPES:    [313 ↛ 314: condition was never true]

314 raise ValueError( 

315 f"Clause 106-A Violation: output_type '{output_type}' is not supported. " 

316 f"Supported types are: {', '.join(sorted(VALID_MEMORY_TYPES))}" 

317 ) 

318 

319 def decorator(func: F) -> F: 

320 """ 

321 Decorator function that sets memory type attributes on the function. 

322 

323 Args: 

324 func: The function to decorate 

325 

326 Returns: 

327 The decorated function with memory type attributes set 

328 

329 Raises: 

330 ValueError: If the function already has different memory type attributes 

331 """ 

332 # 🔒 Clause 66 — Immutability 

333 # Check if memory type attributes already exist 

334 if hasattr(func, 'input_memory_type') and func.input_memory_type != input_type:    [334 ↛ 335: condition was never true]

335 raise ValueError( 

336 f"Clause 66 Violation: Function '{func.__name__}' already has input_memory_type " 

337 f"'{func.input_memory_type}', cannot change to '{input_type}'." 

338 ) 

339 

340 if hasattr(func, 'output_memory_type') and func.output_memory_type != output_type:    [340 ↛ 341: condition was never true]

341 raise ValueError( 

342 f"Clause 66 Violation: Function '{func.__name__}' already has output_memory_type " 

343 f"'{func.output_memory_type}', cannot change to '{output_type}'." 

344 ) 

345 

346 # Set memory type attributes using canonical names 

347 # 🔒 Clause 106-A.2 — Canonical Memory Type Attributes 

348 func.input_memory_type = input_type 

349 func.output_memory_type = output_type 

350 

351 # Return the function unchanged (no wrapper) 

352 return func 

353 

354 return decorator 

355 
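# Illustrative usage sketch, not part of this module (assumes "numpy" is listed in
# VALID_MEMORY_TYPES): memory_types() validates at decoration time and attaches the
# two canonical attributes without wrapping the function, so there is no call overhead.
@memory_types(input_type="numpy", output_type="numpy")
def invert(image):
    return image.max() - image

assert invert.input_memory_type == "numpy"
assert invert.output_memory_type == "numpy"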

356 

357def numpy( 

358 func: Optional[F] = None, 

359 *, 

360 input_type: str = "numpy", 

361 output_type: str = "numpy" 

362) -> Any: 

363 """ 

364 Decorator that declares a function as operating on numpy arrays. 

365 

366 This is a convenience wrapper around memory_types with numpy defaults. 

367 

368 Args: 

369 func: The function to decorate (optional) 

370 input_type: The memory type for the function's input (default: "numpy") 

371 output_type: The memory type for the function's output (default: "numpy") 

372 

373 Returns: 

374 The decorated function with memory type attributes set 

375 

376 Raises: 

377 ValueError: If input_type or output_type is not a supported memory type 

378 """ 

379 def decorator_with_dtype_preservation(func: F) -> F: 

380 # Set memory type attributes 

381 memory_decorator = memory_types(input_type=input_type, output_type=output_type) 

382 func = memory_decorator(func) 

383 

384 # Apply dtype preservation wrapper 

385 func = _create_numpy_dtype_preserving_wrapper(func, func.__name__) 

386 

387 return func 

388 

389 # Handle both @numpy and @numpy(input_type=..., output_type=...) forms 

390 if func is None:    [390 ↛ 391: condition was never true]

391 return decorator_with_dtype_preservation 

392 

393 return decorator_with_dtype_preservation(func) 

394 
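# Illustrative usage sketch, not part of this module: @numpy adds the keyword-only
# dtype_conversion and slice_by_slice parameters on top of the declared memory types;
# skimage-style functions that return float64 are converted back to the input dtype.
@numpy
def smooth(image):
    return image.astype(np.float64) / 2.0     # would normally return float64

stack = np.arange(2 * 4 * 4, dtype=np.uint16).reshape(2, 4, 4)
assert smooth(stack).dtype == np.uint16                                  # PRESERVE_INPUT
assert smooth(stack, dtype_conversion=DtypeConversion.FLOAT32).dtype == np.float32
out = smooth(stack, slice_by_slice=True)                                 # per-slice processing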

395 

396def cupy(func: Optional[F] = None, *, input_type: str = "cupy", output_type: str = "cupy", oom_recovery: bool = True) -> Any: 

397 """ 

398 Decorator that declares a function as operating on cupy arrays. 

399 

400 This decorator provides automatic thread-local CUDA stream management for 

401 true parallelization across multiple threads. Each thread gets its own 

402 persistent CUDA stream that is reused for all CuPy operations. 

403 

404 Args: 

405 func: The function to decorate (optional) 

406 input_type: The memory type for the function's input (default: "cupy") 

407 output_type: The memory type for the function's output (default: "cupy") 

408 oom_recovery: Enable automatic OOM recovery (default: True) 

409 

410 Returns: 

411 The decorated function with memory type attributes and stream management 

412 

413 Raises: 

414 ValueError: If input_type or output_type is not a supported memory type 

415 """ 

416 def decorator(func: F) -> F: 

417 # Set memory type attributes 

418 memory_decorator = memory_types(input_type=input_type, output_type=output_type) 

419 func = memory_decorator(func) 

420 

421 # Apply dtype preservation wrapper 

422 func = _create_cupy_dtype_preserving_wrapper(func, func.__name__) 

423 

424 # Add CUDA stream wrapper if CuPy is available (lazy import) 

425 @functools.wraps(func) 

426 def wrapper(*args, **kwargs): 

427 cp = _get_cupy() 

428 if cp is not None and hasattr(cp, 'cuda'): 

429 # Get unified thread context and CuPy stream 

430 gpu_context = get_thread_gpu_context() 

431 cupy_stream = gpu_context.get_cupy_stream() 

432 

433 def execute_with_stream(): 

434 if cupy_stream is not None: 

435 # Execute function in stream context 

436 with cupy_stream: 

437 return func(*args, **kwargs) 

438 else: 

439 # No CUDA available, execute without stream 

440 return func(*args, **kwargs) 

441 

442 # Execute with OOM recovery if enabled 

443 if oom_recovery: 

444 return _execute_with_oom_recovery(execute_with_stream, input_type) 

445 else: 

446 return execute_with_stream() 

447 else: 

448 # CuPy not available, execute without stream 

449 return func(*args, **kwargs) 

450 

451 # Preserve memory type attributes 

452 wrapper.input_memory_type = func.input_memory_type 

453 wrapper.output_memory_type = func.output_memory_type 

454 

455 return wrapper 

456 

457 # Handle both @cupy and @cupy(input_type=..., output_type=...) forms 

458 if func is None:    [458 ↛ 459: condition was never true]

459 return decorator 

460 

461 return decorator(func) 

462 
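# Illustrative usage sketch, not part of this module: @cupy runs each call inside the
# thread's persistent CUDA stream and, by default, routes it through the OOM-recovery
# helper; pass oom_recovery=False to opt out. Without CuPy/CUDA the function is simply
# called directly, so the code degrades gracefully on CPU-only machines.
@cupy(oom_recovery=False)
def scale(image, factor: float = 2.0):
    return image * factor

assert scale.input_memory_type == "cupy"
assert scale.output_memory_type == "cupy"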

463 

464def torch( 

465 func: Optional[F] = None, 

466 *, 

467 input_type: str = "torch", 

468 output_type: str = "torch", 

469 oom_recovery: bool = True 

470) -> Any: 

471 """ 

472 Decorator that declares a function as operating on torch tensors. 

473 

474 This decorator provides automatic thread-local CUDA stream management for 

475 true parallelization across multiple threads. Each thread gets its own 

476 persistent CUDA stream that is reused for all PyTorch operations. 

477 

478 Args: 

479 func: The function to decorate (optional) 

480 input_type: The memory type for the function's input (default: "torch") 

481 output_type: The memory type for the function's output (default: "torch") 

482 oom_recovery: Enable automatic OOM recovery (default: True) 

483 

484 Returns: 

485 The decorated function with memory type attributes and stream management 

486 

487 Raises: 

488 ValueError: If input_type or output_type is not a supported memory type 

489 """ 

490 def decorator(func: F) -> F: 

491 # Set memory type attributes 

492 memory_decorator = memory_types(input_type=input_type, output_type=output_type) 

493 func = memory_decorator(func) 

494 

495 # Apply dtype preservation wrapper 

496 func = _create_torch_dtype_preserving_wrapper(func, func.__name__) 

497 

498 # Add CUDA stream wrapper if PyTorch is available and CUDA is available (lazy import) 

499 @functools.wraps(func) 

500 def wrapper(*args, **kwargs): 

501 torch = _get_torch() 

502 if torch is not None and hasattr(torch, 'cuda') and torch.cuda.is_available(): 

503 # Get unified thread context and PyTorch stream 

504 gpu_context = get_thread_gpu_context() 

505 torch_stream = gpu_context.get_torch_stream() 

506 

507 def execute_with_stream(): 

508 if torch_stream is not None: 

509 # Execute function in stream context 

510 with torch.cuda.stream(torch_stream): 

511 return func(*args, **kwargs) 

512 else: 

513 # No CUDA available, execute without stream 

514 return func(*args, **kwargs) 

515 

516 # Execute with OOM recovery if enabled 

517 if oom_recovery: 

518 return _execute_with_oom_recovery(execute_with_stream, input_type) 

519 else: 

520 return execute_with_stream() 

521 else: 

522 # PyTorch not available or CUDA not available, execute without stream 

523 return func(*args, **kwargs) 

524 

525 # Preserve memory type attributes 

526 wrapper.input_memory_type = func.input_memory_type 

527 wrapper.output_memory_type = func.output_memory_type 

528 

529 return wrapper 

530 

531 # Handle both @torch and @torch(input_type=..., output_type=...) forms 

532 if func is None:    [532 ↛ 533: condition was never true]

533 return decorator 

534 

535 return decorator(func) 

536 

537 

538def tensorflow( 

539 func: Optional[F] = None, 

540 *, 

541 input_type: str = "tensorflow", 

542 output_type: str = "tensorflow", 

543 oom_recovery: bool = True 

544) -> Any: 

545 """ 

546 Decorator that declares a function as operating on tensorflow tensors. 

547 

548 This decorator provides automatic thread-local GPU device context management 

549 for parallelization across multiple threads. TensorFlow manages CUDA streams 

550 internally, so we use device contexts for thread isolation. 

551 

552 Args: 

553 func: The function to decorate (optional) 

554 input_type: The memory type for the function's input (default: "tensorflow") 

555 output_type: The memory type for the function's output (default: "tensorflow") 

556 oom_recovery: Enable automatic OOM recovery (default: True) 

557 

558 Returns: 

559 The decorated function with memory type attributes and device management 

560 

561 Raises: 

562 ValueError: If input_type or output_type is not a supported memory type 

563 """ 

564 def decorator(func: F) -> F: 

565 # Set memory type attributes 

566 memory_decorator = memory_types(input_type=input_type, output_type=output_type) 

567 func = memory_decorator(func) 

568 

569 # Apply dtype preservation wrapper 

570 func = _create_tensorflow_dtype_preserving_wrapper(func, func.__name__) 

571 

572 # Add device context wrapper if TensorFlow is available and GPU is available (lazy import) 

573 @functools.wraps(func) 

574 def wrapper(*args, **kwargs): 

575 tf = _get_tensorflow() 

576 if tf is not None and tf.config.list_physical_devices('GPU'): 

577 def execute_with_device(): 

578 # Use GPU device context for thread isolation 

579 # TensorFlow manages internal CUDA streams automatically 

580 with tf.device('/GPU:0'): 

581 return func(*args, **kwargs) 

582 

583 # Execute with OOM recovery if enabled 

584 if oom_recovery: 

585 return _execute_with_oom_recovery(execute_with_device, input_type) 

586 else: 

587 return execute_with_device() 

588 else: 

589 # TensorFlow not available or GPU not available, execute without device context 

590 return func(*args, **kwargs) 

591 

592 # Preserve memory type attributes 

593 wrapper.input_memory_type = func.input_memory_type 

594 wrapper.output_memory_type = func.output_memory_type 

595 

596 return wrapper 

597 

598 # Handle both @tensorflow and @tensorflow(input_type=..., output_type=...) forms 

599 if func is None:    [599 ↛ 600: condition was never true]

600 return decorator 

601 

602 return decorator(func) 

603 

604 

605def jax( 

606 func: Optional[F] = None, 

607 *, 

608 input_type: str = "jax", 

609 output_type: str = "jax", 

610 oom_recovery: bool = True 

611) -> Any: 

612 """ 

613 Decorator that declares a function as operating on JAX arrays. 

614 

615 This decorator provides automatic thread-local GPU device placement for 

616 parallelization across multiple threads. JAX/XLA manages CUDA streams 

617 internally, so we use device placement for thread isolation. 

618 

619 Args: 

620 func: The function to decorate (optional) 

621 input_type: The memory type for the function's input (default: "jax") 

622 output_type: The memory type for the function's output (default: "jax") 

623 oom_recovery: Enable automatic OOM recovery (default: True) 

624 

625 Returns: 

626 The decorated function with memory type attributes and device management 

627 

628 Raises: 

629 ValueError: If input_type or output_type is not a supported memory type 

630 """ 

631 def decorator(func: F) -> F: 

632 # Set memory type attributes 

633 memory_decorator = memory_types(input_type=input_type, output_type=output_type) 

634 func = memory_decorator(func) 

635 

636 # Apply dtype preservation wrapper 

637 func = _create_jax_dtype_preserving_wrapper(func, func.__name__) 

638 

639 # Add device placement wrapper if JAX is available and GPU is available (lazy import) 

640 @functools.wraps(func) 

641 def wrapper(*args, **kwargs): 

642 jax_module = _get_jax() 

643 if jax_module is not None: 

644 devices = jax_module.devices() 

645 gpu_devices = [d for d in devices if d.platform == 'gpu'] 

646 

647 if gpu_devices: 

648 def execute_with_device(): 

649 # Use GPU device placement for thread isolation 

650 # JAX/XLA manages internal CUDA streams automatically 

651 with jax_module.default_device(gpu_devices[0]): 

652 return func(*args, **kwargs) 

653 

654 # Execute with OOM recovery if enabled 

655 if oom_recovery: 

656 return _execute_with_oom_recovery(execute_with_device, input_type) 

657 else: 

658 return execute_with_device() 

659 else: 

660 # No GPU devices available, execute without device placement 

661 return func(*args, **kwargs) 

662 else: 

663 # JAX not available, execute without device placement 

664 return func(*args, **kwargs) 

665 

666 # Preserve memory type attributes 

667 wrapper.input_memory_type = func.input_memory_type 

668 wrapper.output_memory_type = func.output_memory_type 

669 

670 return wrapper 

671 

672 # Handle both @jax and @jax(input_type=..., output_type=...) forms 

673 if func is None:    [673 ↛ 674: condition was never true]

674 return decorator 

675 

676 return decorator(func) 

677 

678 

679def pyclesperanto( 

680 func: Optional[F] = None, 

681 *, 

682 input_type: str = "pyclesperanto", 

683 output_type: str = "pyclesperanto", 

684 oom_recovery: bool = True 

685) -> Any: 

686 """ 

687 Decorator that declares a function as operating on pyclesperanto GPU arrays. 

688 

689 This decorator provides automatic OOM recovery for pyclesperanto functions. 

690 

691 Args: 

692 func: The function to decorate (optional) 

693 input_type: The memory type for the function's input (default: "pyclesperanto") 

694 output_type: The memory type for the function's output (default: "pyclesperanto") 

695 oom_recovery: Enable automatic OOM recovery (default: True) 

696 

697 Returns: 

698 The decorated function with memory type attributes and OOM recovery 

699 

700 Raises: 

701 ValueError: If input_type or output_type is not a supported memory type 

702 """ 

703 def decorator(func: F) -> F: 

704 # Set memory type attributes 

705 memory_decorator = memory_types(input_type=input_type, output_type=output_type) 

706 func = memory_decorator(func) 

707 

708 # Apply dtype preservation wrapper 

709 func = _create_pyclesperanto_dtype_preserving_wrapper(func, func.__name__) 

710 

711 # Add OOM recovery wrapper 

712 @functools.wraps(func) 

713 def wrapper(*args, **kwargs): 

714 if oom_recovery: 

715 return _execute_with_oom_recovery(lambda: func(*args, **kwargs), input_type) 

716 else: 

717 return func(*args, **kwargs) 

718 

719 # Preserve memory type attributes 

720 wrapper.input_memory_type = func.input_memory_type 

721 wrapper.output_memory_type = func.output_memory_type 

722 

723 # Make wrapper pickleable by preserving original function identity 

724 wrapper.__module__ = getattr(func, '__module__', wrapper.__module__) 

725 wrapper.__qualname__ = getattr(func, '__qualname__', wrapper.__qualname__) 

726 

727 # Store reference to original function for pickle support 

728 wrapper.__wrapped__ = func 

729 

730 return wrapper 

731 

732 # Handle both @pyclesperanto and @pyclesperanto(input_type=..., output_type=...) forms 

733 if func is None:    [733 ↛ 734: condition was never true]

734 return decorator 

735 

736 return decorator(func) 

737 
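# Illustrative sketch, not part of this module: @pyclesperanto additionally keeps the
# wrapper friendly to pickling (e.g. for multiprocessing) by copying __module__ and
# __qualname__ from the original function and exposing it as __wrapped__.
@pyclesperanto
def label_spots(image):
    return image

assert label_spots.output_memory_type == "pyclesperanto"
assert label_spots.__wrapped__ is not None     # original function is still reachable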

738 

739# ============================================================================ 

740# Dtype Preservation Wrapper Functions 

741# ============================================================================ 

742 

743def _create_numpy_dtype_preserving_wrapper(original_func, func_name): 

744 """ 

745 Create a wrapper that preserves input data type and adds slice_by_slice parameter for NumPy functions. 

746 

747 Many scikit-image functions return float64 regardless of input type. 

748 This wrapper ensures the output has the same dtype as the input and adds 

749 a slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

750 """ 

751 import numpy as np 

752 import inspect 

753 from functools import wraps 

754 

755 @wraps(original_func) 

756 def numpy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

757 # Set default dtype_conversion if not provided and DtypeConversion is available 

758 if dtype_conversion is None and DtypeConversion is not None:    [758 ↛ 761: condition was always true]

759 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

760 

761 try: 

762 # Store original dtype 

763 original_dtype = image.dtype 

764 

765 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities 

766 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3:    [766 ↛ 767: condition was never true]

767 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

768 

769 # Detect memory type and use proper OpenHCS utilities 

770 memory_type = _detect_memory_type(image) 

771 gpu_id = 0 # Default GPU ID for slice processing 

772 

773 # Unstack 3D array into 2D slices 

774 slices_2d = unstack_slices(image, memory_type, gpu_id) 

775 

776 # Process each slice and handle special outputs 

777 main_outputs = [] 

778 special_outputs_list = [] 

779 

780 for slice_2d in slices_2d: 

781 slice_result = original_func(slice_2d, *args, **kwargs) 

782 

783 # Check if result is a tuple (indicating special outputs) 

784 if isinstance(slice_result, tuple): 

785 main_outputs.append(slice_result[0]) # First element is main output 

786 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

787 else: 

788 main_outputs.append(slice_result) # Single output 

789 

790 # Stack main outputs back into 3D array 

791 result = stack_slices(main_outputs, memory_type, gpu_id) 

792 

793 # If we have special outputs, combine them and return tuple 

794 if special_outputs_list: 

795 # Combine special outputs from all slices 

796 combined_special_outputs = [] 

797 num_special_outputs = len(special_outputs_list[0]) 

798 

799 for i in range(num_special_outputs): 

800 # Collect the i-th special output from all slices 

801 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

802 combined_special_outputs.append(special_output_values) 

803 

804 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

805 result = (result, *combined_special_outputs) 

806 else: 

807 # Call the original function normally 

808 result = original_func(image, *args, **kwargs) 

809 

810 # Apply dtype conversion based on enum value 

811 if hasattr(result, 'dtype') and dtype_conversion is not None: 

812 if dtype_conversion == DtypeConversion.PRESERVE_INPUT:    [812 ↛ 816: condition was always true]

813 # Preserve input dtype 

814 if result.dtype != original_dtype:    [814 ↛ 815: condition was never true]

815 result = _scale_and_convert_numpy(result, original_dtype) 

816 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

817 # Return NumPy's native output dtype 

818 pass # No conversion needed 

819 else: 

820 # Force specific dtype 

821 target_dtype = dtype_conversion.numpy_dtype 

822 if target_dtype is not None: 

823 result = _scale_and_convert_numpy(result, target_dtype) 

824 

825 return result 

826 except Exception as e: 

827 logger.error(f"Error in NumPy dtype/slice preserving wrapper for {func_name}: {e}") 

828 # Return original result on error 

829 return original_func(image, *args, **kwargs) 

830 

831 # Update function signature to include new parameters 

832 try: 

833 original_sig = inspect.signature(original_func) 

834 new_params = list(original_sig.parameters.values()) 

835 

836 # Check if slice_by_slice parameter already exists 

837 param_names = [p.name for p in new_params] 

838 # Add dtype_conversion parameter first (before slice_by_slice) 

839 if 'dtype_conversion' not in param_names:    [839 ↛ 848: condition was always true]

840 dtype_param = inspect.Parameter( 

841 'dtype_conversion', 

842 inspect.Parameter.KEYWORD_ONLY, 

843 default=DtypeConversion.PRESERVE_INPUT, 

844 annotation=DtypeConversion 

845 ) 

846 new_params.append(dtype_param) 

847 

848 if 'slice_by_slice' not in param_names:    [848 ↛ 859: condition was always true]

849 # Add slice_by_slice parameter as keyword-only (after dtype_conversion) 

850 slice_param = inspect.Parameter( 

851 'slice_by_slice', 

852 inspect.Parameter.KEYWORD_ONLY, 

853 default=False, 

854 annotation=bool 

855 ) 

856 new_params.append(slice_param) 

857 

858 # Create new signature and override the @wraps signature 

859 new_sig = original_sig.replace(parameters=new_params) 

860 numpy_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

861 

862 # Set type annotations manually for get_type_hints() compatibility 

863 numpy_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

864 numpy_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

865 numpy_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

866 

867 except Exception: 

868 # If signature modification fails, continue without it 

869 pass 

870 

871 # Update docstring to mention slice_by_slice parameter 

872 original_doc = numpy_dtype_and_slice_preserving_wrapper.__doc__ or "" 

873 additional_doc = """ 

874 

875 Additional OpenHCS Parameters 

876 ----------------------------- 

877 slice_by_slice : bool, optional (default: False) 

878 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

879 If False, use original 3D behavior. Recommended for edge detection functions 

880 on stitched microscopy data to prevent artifacts at field boundaries. 

881 

882 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

883 Controls output data type conversion: 

884 

885 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

886 - NATIVE_OUTPUT: Use NumPy's native output dtype 

887 - UINT8: Force 8-bit unsigned integer (0-255 range) 

888 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

889 - INT16: Force 16-bit signed integer 

890 - INT32: Force 32-bit signed integer 

891 - FLOAT32: Force 32-bit float (GPU performance) 

892 - FLOAT64: Force 64-bit float (maximum precision) 

893 """ 

894 numpy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

895 

896 return numpy_dtype_and_slice_preserving_wrapper 

897 
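# Illustrative sketch, not part of this module: the wrapper also rewrites the public
# signature, so introspection-based tooling (inspect.signature, typing.get_type_hints)
# sees the injected keyword-only parameters.
import inspect

@numpy
def threshold(image, value: float = 0.5):
    return image > value

assert list(inspect.signature(threshold).parameters) == [
    "image", "value", "dtype_conversion", "slice_by_slice"
]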

898 

899def _create_cupy_dtype_preserving_wrapper(original_func, func_name): 

900 """ 

901 Create a wrapper that preserves input data type and adds slice_by_slice parameter for CuPy functions. 

902 

903 This uses the SAME pattern as scikit-image for consistency. CuPy functions generally preserve 

904 dtypes better than scikit-image, but this wrapper ensures consistent behavior and adds 

905 slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

906 """ 

907 import inspect 

908 from functools import wraps 

909 

910 @wraps(original_func) 

911 def cupy_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

912 # Set default dtype_conversion if not provided and DtypeConversion is available 

913 if dtype_conversion is None and DtypeConversion is not None: 

914 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

915 

916 try: 

917 cupy = optional_import("cupy") 

918 if cupy is None: 

919 return original_func(image, *args, **kwargs) 

920 

921 # Store original dtype 

922 original_dtype = image.dtype 

923 

924 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities 

925 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

926 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

927 

928 # Detect memory type and use proper OpenHCS utilities 

929 memory_type = _detect_memory_type(image) 

930 gpu_id = image.device.id if hasattr(image, 'device') else 0 

931 

932 # Unstack 3D array into 2D slices 

933 slices_2d = unstack_slices(image, memory_type, gpu_id) 

934 

935 # Process each slice and handle special outputs 

936 main_outputs = [] 

937 special_outputs_list = [] 

938 

939 for slice_2d in slices_2d: 

940 slice_result = original_func(slice_2d, *args, **kwargs) 

941 

942 # Check if result is a tuple (indicating special outputs) 

943 if isinstance(slice_result, tuple): 

944 main_outputs.append(slice_result[0]) # First element is main output 

945 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

946 else: 

947 main_outputs.append(slice_result) # Single output 

948 

949 # Stack main outputs back into 3D array 

950 result = stack_slices(main_outputs, memory_type, gpu_id) 

951 

952 # If we have special outputs, combine them and return tuple 

953 if special_outputs_list: 

954 # Combine special outputs from all slices 

955 combined_special_outputs = [] 

956 num_special_outputs = len(special_outputs_list[0]) 

957 

958 for i in range(num_special_outputs): 

959 # Collect the i-th special output from all slices 

960 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

961 combined_special_outputs.append(special_output_values) 

962 

963 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

964 result = (result, *combined_special_outputs) 

965 else: 

966 # Call the original function normally 

967 result = original_func(image, *args, **kwargs) 

968 

969 # Apply dtype conversion based on enum value 

970 if hasattr(result, 'dtype') and dtype_conversion is not None: 

971 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

972 # Preserve input dtype 

973 if result.dtype != original_dtype: 

974 result = _scale_and_convert_cupy(result, original_dtype) 

975 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

976 # Return CuPy's native output dtype 

977 pass # No conversion needed 

978 else: 

979 # Force specific dtype 

980 target_dtype = dtype_conversion.numpy_dtype 

981 if target_dtype is not None: 

982 result = _scale_and_convert_cupy(result, target_dtype) 

983 

984 return result 

985 except Exception as e: 

986 logger.error(f"Error in CuPy dtype/slice preserving wrapper for {func_name}: {e}") 

987 # Return original result on error 

988 return original_func(image, *args, **kwargs) 

989 

990 # Update function signature to include new parameters 

991 try: 

992 original_sig = inspect.signature(original_func) 

993 new_params = list(original_sig.parameters.values()) 

994 

995 # Check if slice_by_slice parameter already exists 

996 param_names = [p.name for p in new_params] 

997 # Add dtype_conversion parameter first (before slice_by_slice) 

998 if 'dtype_conversion' not in param_names:    [998 ↛ 1007: condition was always true]

999 dtype_param = inspect.Parameter( 

1000 'dtype_conversion', 

1001 inspect.Parameter.KEYWORD_ONLY, 

1002 default=DtypeConversion.PRESERVE_INPUT, 

1003 annotation=DtypeConversion 

1004 ) 

1005 new_params.append(dtype_param) 

1006 

1007 if 'slice_by_slice' not in param_names:    [1007 ↛ 1018: condition was always true]

1008 # Add slice_by_slice parameter as keyword-only (after dtype_conversion) 

1009 slice_param = inspect.Parameter( 

1010 'slice_by_slice', 

1011 inspect.Parameter.KEYWORD_ONLY, 

1012 default=False, 

1013 annotation=bool 

1014 ) 

1015 new_params.append(slice_param) 

1016 

1017 # Create new signature and override the @wraps signature 

1018 new_sig = original_sig.replace(parameters=new_params) 

1019 cupy_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1020 

1021 # Set type annotations manually for get_type_hints() compatibility 

1022 cupy_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1023 cupy_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1024 cupy_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1025 

1026 except Exception: 

1027 # If signature modification fails, continue without it 

1028 pass 

1029 

1030 # Update docstring to mention slice_by_slice parameter 

1031 original_doc = cupy_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1032 additional_doc = """ 

1033 

1034 Additional OpenHCS Parameters 

1035 ----------------------------- 

1036 slice_by_slice : bool, optional (default: False) 

1037 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1038 If False, use original 3D behavior. Recommended for edge detection functions 

1039 on stitched microscopy data to prevent artifacts at field boundaries. 

1040 

1041 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1042 Controls output data type conversion: 

1043 

1044 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1045 - NATIVE_OUTPUT: Use CuPy's native output dtype 

1046 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1047 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1048 - INT16: Force 16-bit signed integer 

1049 - INT32: Force 32-bit signed integer 

1050 - FLOAT32: Force 32-bit float (GPU performance) 

1051 - FLOAT64: Force 64-bit float (maximum precision) 

1052 """ 

1053 cupy_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1054 

1055 return cupy_dtype_and_slice_preserving_wrapper 

1056 

1057 

1058def _create_torch_dtype_preserving_wrapper(original_func, func_name): 

1059 """ 

1060 Create a wrapper that preserves input data type and adds slice_by_slice parameter for PyTorch functions. 

1061 

1062 This follows the same pattern as existing dtype preservation wrappers for consistency. 

1063 PyTorch functions generally preserve dtypes well, but this wrapper ensures consistent behavior 

1064 and adds slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

1065 """ 

1066 import inspect 

1067 from functools import wraps 

1068 

1069 @wraps(original_func) 

1070 def torch_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1071 # Set default dtype_conversion if not provided 

1072 if dtype_conversion is None: 

1073 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1074 

1075 try: 

1076 torch = optional_import("torch") 

1077 if torch is None: 

1078 return original_func(image, *args, **kwargs) 

1079 

1080 # Store original dtype 

1081 original_dtype = image.dtype if hasattr(image, 'dtype') else None 

1082 

1083 # Handle slice_by_slice processing for 3D arrays 

1084 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

1085 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1086 

1087 # Detect memory type and use proper OpenHCS utilities 

1088 memory_type = _detect_memory_type(image) 

1089 gpu_id = image.device.index if hasattr(image, 'device') and image.device.type == 'cuda' else 0 

1090 

1091 # Unstack 3D array into 2D slices 

1092 slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id) 

1093 

1094 # Process each slice and handle special outputs 

1095 main_outputs = [] 

1096 special_outputs_list = [] 

1097 

1098 for slice_2d in slices_2d: 

1099 slice_result = original_func(slice_2d, *args, **kwargs) 

1100 

1101 # Check if result is a tuple (indicating special outputs) 

1102 if isinstance(slice_result, tuple): 

1103 main_outputs.append(slice_result[0]) # First element is main output 

1104 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

1105 else: 

1106 main_outputs.append(slice_result) # Single output 

1107 

1108 # Stack main outputs back into 3D array 

1109 result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id) 

1110 

1111 # If we have special outputs, combine them and return tuple 

1112 if special_outputs_list: 

1113 # Combine special outputs from all slices 

1114 combined_special_outputs = [] 

1115 num_special_outputs = len(special_outputs_list[0]) 

1116 

1117 for i in range(num_special_outputs): 

1118 # Collect the i-th special output from all slices 

1119 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1120 combined_special_outputs.append(special_output_values) 

1121 

1122 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1123 result = (result, *combined_special_outputs) 

1124 else: 

1125 # Process normally 

1126 result = original_func(image, *args, **kwargs) 

1127 

1128 # Apply dtype conversion if result is a tensor and we have dtype conversion info 

1129 if (hasattr(result, 'dtype') and hasattr(result, 'shape') and 

1130 original_dtype is not None and dtype_conversion is not None): 

1131 

1132 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1133 # Preserve input dtype 

1134 if result.dtype != original_dtype: 

1135 result = result.to(original_dtype) 

1136 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1137 # Return PyTorch's native output dtype 

1138 pass # No conversion needed 

1139 else: 

1140 # Force specific dtype 

1141 target_dtype = dtype_conversion.numpy_dtype 

1142 if target_dtype is not None: 

1143 # Map numpy dtypes to torch dtypes 

1144 import numpy as np 

1145 numpy_to_torch = { 

1146 np.uint8: torch.uint8, 

1147 np.uint16: torch.int32, # PyTorch doesn't have uint16, use int32 

1148 np.int16: torch.int16, 

1149 np.int32: torch.int32, 

1150 np.float32: torch.float32, 

1151 np.float64: torch.float64, 

1152 } 

1153 torch_dtype = numpy_to_torch.get(target_dtype) 

1154 if torch_dtype is not None: 

1155 result = result.to(torch_dtype) 

1156 

1157 return result 

1158 

1159 except Exception as e: 

1160 logger.error(f"Error in PyTorch dtype/slice preserving wrapper for {func_name}: {e}") 

1161 # Return original result on error 

1162 return original_func(image, *args, **kwargs) 

1163 

1164 # Update function signature to include new parameters 

1165 try: 

1166 original_sig = inspect.signature(original_func) 

1167 new_params = list(original_sig.parameters.values()) 

1168 

1169 # Add dtype_conversion parameter first (before slice_by_slice) 

1170 param_names = [p.name for p in new_params] 

1171 if 'dtype_conversion' not in param_names:    [1171 ↛ 1181: condition was always true]

1172 dtype_param = inspect.Parameter( 

1173 'dtype_conversion', 

1174 inspect.Parameter.KEYWORD_ONLY, 

1175 default=DtypeConversion.PRESERVE_INPUT, 

1176 annotation=DtypeConversion 

1177 ) 

1178 new_params.append(dtype_param) 

1179 

1180 # Add slice_by_slice parameter after dtype_conversion 

1181 if 'slice_by_slice' not in param_names: 

1182 slice_param = inspect.Parameter( 

1183 'slice_by_slice', 

1184 inspect.Parameter.KEYWORD_ONLY, 

1185 default=False, 

1186 annotation=bool 

1187 ) 

1188 new_params.append(slice_param) 

1189 

1190 new_sig = original_sig.replace(parameters=new_params) 

1191 torch_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1192 

1193 # Set type annotations manually for get_type_hints() compatibility 

1194 torch_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1195 if DtypeConversion is not None:    [1195 ↛ 1197: condition was always true]

1196 torch_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1197 torch_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1198 

1199 except Exception: 

1200 # If signature modification fails, continue without it 

1201 pass 

1202 

1203 # Update docstring to mention new parameters 

1204 original_doc = torch_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1205 additional_doc = """ 

1206 

1207 Additional OpenHCS Parameters 

1208 ----------------------------- 

1209 slice_by_slice : bool, optional (default: False) 

1210 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1211 If False, use original 3D behavior. Recommended for edge detection functions 

1212 on stitched microscopy data to prevent artifacts at field boundaries. 

1213 

1214 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1215 Controls output data type conversion: 

1216 

1217 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1218 - NATIVE_OUTPUT: Use PyTorch's native output dtype 

1219 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1220 - UINT16: Force 16-bit unsigned integer (mapped to int32 in PyTorch) 

1221 - INT16: Force 16-bit signed integer 

1222 - INT32: Force 32-bit signed integer 

1223 - FLOAT32: Force 32-bit float (GPU performance) 

1224 - FLOAT64: Force 64-bit float (maximum precision) 

1225 """ 

1226 torch_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1227 

1228 return torch_dtype_and_slice_preserving_wrapper 

1229 
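# Illustrative sketch, not part of this module: PyTorch has no uint16 tensor dtype, so
# DtypeConversion.UINT16 is mapped to torch.int32 in the table above; the uint16 value
# range fits losslessly, only the container type widens.
_torch = _get_torch()
if _torch is not None:
    @torch
    def identity(image):
        return image

    t = _torch.zeros((2, 4, 4), dtype=_torch.float32)
    out = identity(t, dtype_conversion=DtypeConversion.UINT16)
    assert out.dtype == _torch.int32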

1230 

1231def _create_tensorflow_dtype_preserving_wrapper(original_func, func_name): 

1232 """ 

1233 Create a wrapper that preserves input data type and adds slice_by_slice parameter for TensorFlow functions. 

1234 

1235 This follows the same pattern as existing dtype preservation wrappers for consistency. 

1236 TensorFlow functions generally preserve dtypes well, but this wrapper ensures consistent behavior 

1237 and adds slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

1238 """ 

1239 import inspect 

1240 from functools import wraps 

1241 

1242 @wraps(original_func) 

1243 def tensorflow_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1244 # Set default dtype_conversion if not provided 

1245 if dtype_conversion is None: 

1246 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1247 

1248 try: 

1249 tf = optional_import("tensorflow") 

1250 if tf is None: 

1251 return original_func(image, *args, **kwargs) 

1252 

1253 # Store original dtype 

1254 original_dtype = image.dtype if hasattr(image, 'dtype') else None 

1255 

1256 # Handle slice_by_slice processing for 3D arrays 

1257 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

1258 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1259 

1260 # Detect memory type and use proper OpenHCS utilities 

1261 memory_type = _detect_memory_type(image) 

1262 gpu_id = 0 # TensorFlow manages GPU placement internally 

1263 

1264 # Unstack 3D array into 2D slices 

1265 slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id) 

1266 

1267 # Process each slice and handle special outputs 

1268 main_outputs = [] 

1269 special_outputs_list = [] 

1270 

1271 for slice_2d in slices_2d: 

1272 slice_result = original_func(slice_2d, *args, **kwargs) 

1273 

1274 # Check if result is a tuple (indicating special outputs) 

1275 if isinstance(slice_result, tuple): 

1276 main_outputs.append(slice_result[0]) # First element is main output 

1277 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

1278 else: 

1279 main_outputs.append(slice_result) # Single output 

1280 

1281 # Stack main outputs back into 3D array 

1282 result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id) 

1283 

1284 # If we have special outputs, combine them and return tuple 

1285 if special_outputs_list: 

1286 # Combine special outputs from all slices 

1287 combined_special_outputs = [] 

1288 num_special_outputs = len(special_outputs_list[0]) 

1289 

1290 for i in range(num_special_outputs): 

1291 # Collect the i-th special output from all slices 

1292 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1293 combined_special_outputs.append(special_output_values) 

1294 

1295 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1296 result = (result, *combined_special_outputs) 

1297 else: 

1298 # Process normally 

1299 result = original_func(image, *args, **kwargs) 

1300 

1301 # Apply dtype conversion if result is a tensor and we have dtype conversion info 

1302 if (hasattr(result, 'dtype') and hasattr(result, 'shape') and 

1303 original_dtype is not None and dtype_conversion is not None): 

1304 

1305 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1306 # Preserve input dtype 

1307 if result.dtype != original_dtype: 

1308 result = tf.cast(result, original_dtype) 

1309 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1310 # Return TensorFlow's native output dtype 

1311 pass # No conversion needed 

1312 else: 

1313 # Force specific dtype 

1314 target_dtype = dtype_conversion.numpy_dtype 

1315 if target_dtype is not None: 

1316 # Convert numpy dtype to tensorflow dtype 

1317 import numpy as np 

1318 numpy_to_tf = { 

1319 np.uint8: tf.uint8, 

1320 np.uint16: tf.uint16, 

1321 np.int16: tf.int16, 

1322 np.int32: tf.int32, 

1323 np.float32: tf.float32, 

1324 np.float64: tf.float64, 

1325 } 

1326 tf_dtype = numpy_to_tf.get(target_dtype) 

1327 if tf_dtype is not None: 

1328 result = tf.cast(result, tf_dtype) 

1329 

1330 return result 

1331 

1332 except Exception as e: 

1333 logger.error(f"Error in TensorFlow dtype/slice preserving wrapper for {func_name}: {e}") 

1334 # Return original result on error 

1335 return original_func(image, *args, **kwargs) 

1336 

1337 # Update function signature to include new parameters 

1338 try: 

1339 original_sig = inspect.signature(original_func) 

1340 new_params = list(original_sig.parameters.values()) 

1341 

1342 # Add slice_by_slice parameter if not already present 

1343 param_names = [p.name for p in new_params] 

1344 if 'slice_by_slice' not in param_names:    [1344 ↛ 1354: condition was always true]

1345 slice_param = inspect.Parameter( 

1346 'slice_by_slice', 

1347 inspect.Parameter.KEYWORD_ONLY, 

1348 default=False, 

1349 annotation=bool 

1350 ) 

1351 new_params.append(slice_param) 

1352 

1353 # Add dtype_conversion parameter if DtypeConversion is available 

1354 if DtypeConversion is not None and 'dtype_conversion' not in param_names:    [1354 ↛ 1363: condition was always true]

1355 dtype_param = inspect.Parameter( 

1356 'dtype_conversion', 

1357 inspect.Parameter.KEYWORD_ONLY, 

1358 default=DtypeConversion.PRESERVE_INPUT, 

1359 annotation=DtypeConversion 

1360 ) 

1361 new_params.append(dtype_param) 

1362 

1363 new_sig = original_sig.replace(parameters=new_params) 

1364 tensorflow_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1365 

1366 # Set type annotations manually for get_type_hints() compatibility 

1367 tensorflow_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1368 tensorflow_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1369 if DtypeConversion is not None: 1369 ↛ 1377: line 1369 didn't jump to line 1377 because the condition on line 1369 was always true

1370 tensorflow_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1371 

1372 except Exception: 

1373 # If signature modification fails, continue without it 

1374 pass 

1375 

1376 # Update docstring to mention new parameters 

1377 original_doc = tensorflow_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1378 additional_doc = """ 

1379 

1380 Additional OpenHCS Parameters 

1381 ----------------------------- 

1382 slice_by_slice : bool, optional (default: False) 

1383 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1384 If False, use the original 3D behavior. Slice-by-slice processing is recommended for edge detection functions 

1385 on stitched microscopy data to prevent artifacts at field boundaries. 

1386 

1387 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1388 Controls output data type conversion: 

1389 

1390 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1391 - NATIVE_OUTPUT: Use TensorFlow's native output dtype 

1392 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1393 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1394 - INT16: Force 16-bit signed integer 

1395 - INT32: Force 32-bit signed integer 

1396 - FLOAT32: Force 32-bit float (GPU performance) 

1397 - FLOAT64: Force 64-bit float (maximum precision) 

1398 """ 

1399 tensorflow_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1400 

1401 return tensorflow_dtype_and_slice_preserving_wrapper 

1402 

1403 
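The wrapper returned above adds two keyword-only parameters to every registered TensorFlow function. A minimal usage sketch follows (editor-added, not part of decorators.py): it assumes TensorFlow is installed, that the enclosing factory is named _create_tensorflow_dtype_preserving_wrapper (by analogy with the JAX and pyclesperanto factories below), and that halve_intensity is a made-up example function, not part of OpenHCS.

import numpy as np
import tensorflow as tf

def halve_intensity(image):
    return tf.cast(image, tf.float32) / 2.0  # native output dtype is float32

wrapped = _create_tensorflow_dtype_preserving_wrapper(halve_intensity, "halve_intensity")
stack = tf.constant(np.random.randint(0, 65535, size=(4, 32, 32), dtype=np.uint16))

out = wrapped(stack)  # default dtype_conversion=PRESERVE_INPUT
assert out.dtype == stack.dtype  # float32 intermediate cast back to uint16

out_u8 = wrapped(stack, dtype_conversion=DtypeConversion.UINT8)
assert out_u8.dtype == tf.uint8  # forced via the numpy-to-TF dtype mapping above (values are cast, not rescaled)
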

1404def _create_jax_dtype_preserving_wrapper(original_func, func_name): 

1405 """ 

1406 Create a wrapper that preserves the input data type and adds a slice_by_slice parameter for JAX functions. 

1407 

1408 This follows the same pattern as existing dtype preservation wrappers for consistency. 

1409 JAX functions generally preserve dtypes well, but this wrapper ensures consistent behavior 

1410 and adds a slice_by_slice parameter to avoid cross-slice contamination in 3D arrays. 

1411 """ 

1412 import inspect 

1413 from functools import wraps 

1414 

1415 @wraps(original_func) 

1416 def jax_dtype_and_slice_preserving_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1417 # Set default dtype_conversion if not provided 

1418 if dtype_conversion is None: 

1419 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1420 

1421 try: 

1422 jax = optional_import("jax") 

1423 jnp = optional_import("jax.numpy") if jax is not None else None 

1424 if jax is None or jnp is None: 

1425 return original_func(image, *args, **kwargs) 

1426 

1427 # Store original dtype 

1428 original_dtype = image.dtype if hasattr(image, 'dtype') else None 

1429 

1430 # Handle slice_by_slice processing for 3D arrays 

1431 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 

1432 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1433 

1434 # Detect memory type and use proper OpenHCS utilities 

1435 memory_type = _detect_memory_type(image) 

1436 gpu_id = 0 # JAX manages GPU placement internally 

1437 

1438 # Unstack 3D array into 2D slices 

1439 slices_2d = unstack_slices(image, memory_type=memory_type, gpu_id=gpu_id) 

1440 

1441 # Process each slice and handle special outputs 

1442 main_outputs = [] 

1443 special_outputs_list = [] 

1444 

1445 for slice_2d in slices_2d: 

1446 slice_result = original_func(slice_2d, *args, **kwargs) 

1447 

1448 # Check if result is a tuple (indicating special outputs) 

1449 if isinstance(slice_result, tuple): 

1450 main_outputs.append(slice_result[0]) # First element is main output 

1451 special_outputs_list.append(slice_result[1:]) # Rest are special outputs 

1452 else: 

1453 main_outputs.append(slice_result) # Single output 

1454 

1455 # Stack main outputs back into 3D array 

1456 result = stack_slices(main_outputs, memory_type=memory_type, gpu_id=gpu_id) 

1457 

1458 # If we have special outputs, combine them and return tuple 

1459 if special_outputs_list: 

1460 # Combine special outputs from all slices 

1461 combined_special_outputs = [] 

1462 num_special_outputs = len(special_outputs_list[0]) 

1463 

1464 for i in range(num_special_outputs): 

1465 # Collect the i-th special output from all slices 

1466 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1467 combined_special_outputs.append(special_output_values) 

1468 

1469 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1470 result = (result, *combined_special_outputs) 

1471 else: 

1472 # Process normally 

1473 result = original_func(image, *args, **kwargs) 

1474 

1475 # Apply dtype conversion if result is an array and we have dtype conversion info 

1476 if (hasattr(result, 'dtype') and hasattr(result, 'shape') and 

1477 original_dtype is not None and dtype_conversion is not None): 

1478 

1479 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1480 # Preserve input dtype 

1481 if result.dtype != original_dtype: 

1482 result = result.astype(original_dtype) 

1483 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1484 # Return JAX's native output dtype 

1485 pass # No conversion needed 

1486 else: 

1487 # Force specific dtype 

1488 target_dtype = dtype_conversion.numpy_dtype 

1489 if target_dtype is not None: 

1490 # JAX uses numpy-compatible dtypes 

1491 result = result.astype(target_dtype) 

1492 

1493 return result 

1494 

1495 except Exception as e: 

1496 logger.error(f"Error in JAX dtype/slice preserving wrapper for {func_name}: {e}") 

1497 # Fall back to the original (unwrapped) function on error 

1498 return original_func(image, *args, **kwargs) 

1499 

1500 # Update function signature to include new parameters 

1501 try: 

1502 original_sig = inspect.signature(original_func) 

1503 new_params = list(original_sig.parameters.values()) 

1504 

1505 # Add slice_by_slice parameter if not already present 

1506 param_names = [p.name for p in new_params] 

1507 if 'slice_by_slice' not in param_names: 

1508 slice_param = inspect.Parameter( 

1509 'slice_by_slice', 

1510 inspect.Parameter.KEYWORD_ONLY, 

1511 default=False, 

1512 annotation=bool 

1513 ) 

1514 new_params.append(slice_param) 

1515 

1516 # Add dtype_conversion parameter if DtypeConversion is available 

1517 if DtypeConversion is not None and 'dtype_conversion' not in param_names: 1517 ↛ 1526: line 1517 didn't jump to line 1526 because the condition on line 1517 was always true

1518 dtype_param = inspect.Parameter( 

1519 'dtype_conversion', 

1520 inspect.Parameter.KEYWORD_ONLY, 

1521 default=DtypeConversion.PRESERVE_INPUT, 

1522 annotation=DtypeConversion 

1523 ) 

1524 new_params.append(dtype_param) 

1525 

1526 new_sig = original_sig.replace(parameters=new_params) 

1527 jax_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1528 

1529 # Set type annotations manually for get_type_hints() compatibility 

1530 jax_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1531 jax_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1532 if DtypeConversion is not None: 1532 ↛ 1540: line 1532 didn't jump to line 1540 because the condition on line 1532 was always true

1533 jax_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1534 

1535 except Exception: 

1536 # If signature modification fails, continue without it 

1537 pass 

1538 

1539 # Update docstring to mention new parameters 

1540 original_doc = jax_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1541 additional_doc = """ 

1542 

1543 Additional OpenHCS Parameters 

1544 ----------------------------- 

1545 slice_by_slice : bool, optional (default: False) 

1546 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1547 If False, use the original 3D behavior. Slice-by-slice processing is recommended for edge detection functions 

1548 on stitched microscopy data to prevent artifacts at field boundaries. 

1549 

1550 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1551 Controls output data type conversion: 

1552 

1553 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1554 - NATIVE_OUTPUT: Use JAX's native output dtype 

1555 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1556 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1557 - INT16: Force 16-bit signed integer 

1558 - INT32: Force 32-bit signed integer 

1559 - FLOAT32: Force 32-bit float (GPU performance) 

1560 - FLOAT64: Force 64-bit float (maximum precision) 

1561 """ 

1562 jax_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1563 

1564 return jax_dtype_and_slice_preserving_wrapper 

1565 
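A similar editor-added sketch for the JAX wrapper defined above, contrasting PRESERVE_INPUT with NATIVE_OUTPUT. It assumes JAX is installed; sqrt_op is an illustrative function, not part of OpenHCS.

import jax.numpy as jnp

def sqrt_op(image):
    return jnp.sqrt(image)  # integer inputs are promoted to a float dtype

wrapped = _create_jax_dtype_preserving_wrapper(sqrt_op, "sqrt_op")
stack = jnp.arange(2 * 4 * 4, dtype=jnp.uint16).reshape(2, 4, 4)

preserved = wrapped(stack)  # default PRESERVE_INPUT: cast back to the input dtype
native = wrapped(stack, dtype_conversion=DtypeConversion.NATIVE_OUTPUT)
assert preserved.dtype == stack.dtype  # uint16 restored
# `native` keeps whatever dtype jnp.sqrt produced (float32 under default JAX precision settings)
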

1566 
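The slice_by_slice branches of all three wrappers share the same recombination rule for functions that return (main_output, *special_outputs) tuples: main outputs are restacked into a 3D array, and the i-th special output from every slice is collected into a list. A framework-agnostic NumPy sketch of that rule, using a hypothetical detect_spots function:

import numpy as np

def detect_spots(slice_2d):
    mask = slice_2d > slice_2d.mean()
    return mask.astype(slice_2d.dtype), np.argwhere(mask)  # (main output, special output)

stack = np.random.randint(0, 100, size=(3, 8, 8), dtype=np.uint16)
main_outputs, special_outputs_list = [], []
for slice_2d in stack:
    slice_result = detect_spots(slice_2d)
    main_outputs.append(slice_result[0])
    special_outputs_list.append(slice_result[1:])

result = np.stack(main_outputs)  # stacked main output, shape (3, 8, 8)
combined = [[per_slice[i] for per_slice in special_outputs_list]
            for i in range(len(special_outputs_list[0]))]
result = (result, *combined)  # (stack, [coords_z0, coords_z1, coords_z2])
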

1567def _create_pyclesperanto_dtype_preserving_wrapper(original_func, func_name): 

1568 """ 

1569 Create a wrapper that ensures array-in/array-out compliance and dtype preservation for pyclesperanto functions. 

1570 

1571 All OpenHCS functions must: 

1572 1. Take 3D pyclesperanto array as first argument 

1573 2. Return 3D pyclesperanto array as first output 

1574 3. Return additional outputs (values, coordinates) as the 2nd, 3rd, etc. return values 

1575 4. Preserve input dtype when appropriate 

1576 """ 

1577 import inspect 

1578 from functools import wraps 

1579 

1580 @wraps(original_func) 

1581 def pyclesperanto_dtype_and_slice_preserving_wrapper(image_3d, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

1582 # Set default dtype_conversion if not provided 

1583 if dtype_conversion is None: 

1584 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

1585 

1586 try: 

1587 cle = optional_import("pyclesperanto") 

1588 if cle is None: 

1589 return original_func(image_3d, *args, **kwargs) 

1590 

1591 # Store original dtype for preservation 

1592 original_dtype = image_3d.dtype 

1593 

1594 # Handle slice_by_slice processing for 3D arrays using OpenHCS stack utilities 

1595 if slice_by_slice and hasattr(image_3d, 'ndim') and image_3d.ndim == 3: 

1596 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

1597 

1598 # Detect memory type and use proper OpenHCS utilities 

1599 memory_type = _detect_memory_type(image_3d) 

1600 gpu_id = 0 # pyclesperanto manages GPU internally 

1601 

1602 # Process each slice and handle special outputs 

1603 slices = unstack_slices(image_3d, memory_type, gpu_id) 

1604 main_outputs = [] 

1605 special_outputs_list = [] 

1606 

1607 for slice_2d in slices: 

1608 # Apply function to 2D slice 

1609 result_slice = original_func(slice_2d, *args, **kwargs) 

1610 

1611 # Check if result is a tuple (indicating special outputs) 

1612 if isinstance(result_slice, tuple): 

1613 main_outputs.append(result_slice[0]) # First element is main output 

1614 special_outputs_list.append(result_slice[1:]) # Rest are special outputs 

1615 else: 

1616 main_outputs.append(result_slice) # Single output 

1617 

1618 # Stack main outputs back into 3D array 

1619 result = stack_slices(main_outputs, memory_type, gpu_id) 

1620 

1621 # If we have special outputs, combine them and return tuple 

1622 if special_outputs_list: 

1623 # Combine special outputs from all slices 

1624 combined_special_outputs = [] 

1625 num_special_outputs = len(special_outputs_list[0]) 

1626 

1627 for i in range(num_special_outputs): 

1628 # Collect the i-th special output from all slices 

1629 special_output_values = [slice_outputs[i] for slice_outputs in special_outputs_list] 

1630 combined_special_outputs.append(special_output_values) 

1631 

1632 # Return tuple: (stacked_main_output, combined_special_output1, combined_special_output2, ...) 

1633 result = (result, *combined_special_outputs) 

1634 else: 

1635 # Normal 3D processing 

1636 result = original_func(image_3d, *args, **kwargs) 

1637 

1638 # Check if result is 2D and needs expansion to 3D 

1639 if hasattr(result, 'ndim') and result.ndim == 2: 

1640 # Expand 2D result to 3D single slice 

1641 try: 

1642 # Concatenate with itself to create 3D, then take first slice 

1643 temp_3d = cle.concatenate_along_z(result, result) # Creates (2, Y, X) 

1644 result = temp_3d[0:1, :, :] # Take first slice to get (1, Y, X) 

1645 except Exception: 

1646 # If expansion fails, return original 2D result 

1647 # This maintains backward compatibility 

1648 pass 

1649 

1650 # Apply dtype conversion based on enum value 

1651 if hasattr(result, 'dtype') and hasattr(result, 'shape') and dtype_conversion is not None: 

1652 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 

1653 # Preserve input dtype 

1654 if result.dtype != original_dtype: 

1655 return _scale_and_convert_pyclesperanto(result, original_dtype) 

1656 return result 

1657 

1658 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

1659 # Return pyclesperanto's native output dtype 

1660 return result 

1661 

1662 else: 

1663 # Force specific dtype 

1664 target_dtype = dtype_conversion.numpy_dtype 

1665 if target_dtype is not None and result.dtype != target_dtype: 

1666 return _scale_and_convert_pyclesperanto(result, target_dtype) 

1667 return result 

1668 else: 

1669 # Non-array result, return as-is 

1670 return result 

1671 

1672 except Exception as e: 

1673 logger.error(f"Error in pyclesperanto dtype/slice preserving wrapper for {func_name}: {e}") 

1674 # If anything goes wrong, fall back to original function 

1675 return original_func(image_3d, *args, **kwargs) 

1676 

1677 # Update function signature to include new parameters 

1678 try: 

1679 original_sig = inspect.signature(original_func) 

1680 new_params = list(original_sig.parameters.values()) 

1681 

1682 # Add slice_by_slice parameter if not already present 

1683 param_names = [p.name for p in new_params] 

1684 if 'slice_by_slice' not in param_names: 1684 ↛ 1694: line 1684 didn't jump to line 1694 because the condition on line 1684 was always true

1685 slice_param = inspect.Parameter( 

1686 'slice_by_slice', 

1687 inspect.Parameter.KEYWORD_ONLY, 

1688 default=False, 

1689 annotation=bool 

1690 ) 

1691 new_params.append(slice_param) 

1692 

1693 # Add dtype_conversion parameter if DtypeConversion is available 

1694 if DtypeConversion is not None and 'dtype_conversion' not in param_names: 1694 ↛ 1703: line 1694 didn't jump to line 1703 because the condition on line 1694 was always true

1695 dtype_param = inspect.Parameter( 

1696 'dtype_conversion', 

1697 inspect.Parameter.KEYWORD_ONLY, 

1698 default=DtypeConversion.PRESERVE_INPUT, 

1699 annotation=DtypeConversion 

1700 ) 

1701 new_params.append(dtype_param) 

1702 

1703 new_sig = original_sig.replace(parameters=new_params) 

1704 pyclesperanto_dtype_and_slice_preserving_wrapper.__signature__ = new_sig 

1705 

1706 # Set type annotations manually for get_type_hints() compatibility 

1707 pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__ = getattr(original_func, '__annotations__', {}).copy() 

1708 pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__['slice_by_slice'] = bool 

1709 if DtypeConversion is not None: 1709 ↛ 1717: line 1709 didn't jump to line 1717 because the condition on line 1709 was always true

1710 pyclesperanto_dtype_and_slice_preserving_wrapper.__annotations__['dtype_conversion'] = DtypeConversion 

1711 

1712 except Exception: 

1713 # If signature modification fails, continue without it 

1714 pass 

1715 

1716 # Update docstring to mention additional parameters 

1717 original_doc = pyclesperanto_dtype_and_slice_preserving_wrapper.__doc__ or "" 

1718 additional_doc = """ 

1719 

1720 Additional OpenHCS Parameters 

1721 ----------------------------- 

1722 slice_by_slice : bool, optional (default: False) 

1723 If True, process 3D arrays slice-by-slice to avoid cross-slice contamination. 

1724 If False, use the original 3D behavior. Slice-by-slice processing is recommended for edge detection functions 

1725 on stitched microscopy data to prevent artifacts at field boundaries. 

1726 

1727 dtype_conversion : DtypeConversion, optional (default: PRESERVE_INPUT) 

1728 Controls output data type conversion: 

1729 

1730 - PRESERVE_INPUT: Keep input dtype (uint16 → uint16) 

1731 - NATIVE_OUTPUT: Use pyclesperanto's native output (often float32) 

1732 - UINT8: Force 8-bit unsigned integer (0-255 range) 

1733 - UINT16: Force 16-bit unsigned integer (microscopy standard) 

1734 - INT16: Force 16-bit signed integer 

1735 - INT32: Force 32-bit signed integer 

1736 - FLOAT32: Force 32-bit float (GPU performance) 

1737 - FLOAT64: Force 64-bit float (maximum precision) 

1738 """ 

1739 pyclesperanto_dtype_and_slice_preserving_wrapper.__doc__ = original_doc + additional_doc 

1740 

1741 return pyclesperanto_dtype_and_slice_preserving_wrapper 

1742 

1743 
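To illustrate the array-in/array-out coercion above, the editor-added sketch below wraps a hypothetical first_slice_only function that (incorrectly) returns a 2D image; the wrapper expands the result back to a (1, Y, X) stack and, under the default PRESERVE_INPUT mode, matches the input dtype. It assumes pyclesperanto and a working OpenCL device are available.

import numpy as np
import pyclesperanto as cle

def first_slice_only(image_3d):
    return cle.push(cle.pull(image_3d)[0])  # 2D result, violating the 3D contract

wrapped = _create_pyclesperanto_dtype_preserving_wrapper(first_slice_only, "first_slice_only")
stack = cle.push(np.random.randint(0, 255, size=(5, 16, 16), dtype=np.uint16))

out = wrapped(stack)
# Expected: out.shape == (1, 16, 16) and out.dtype == stack.dtype
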

1744def _scale_and_convert_pyclesperanto(result, target_dtype): 

1745 """ 

1746 Scale and convert pyclesperanto array to target dtype. 

1747 This is a simplified version of the helper function from pyclesperanto_registry.py 

1748 """ 

1749 try: 

1750 cle = optional_import("pyclesperanto") 

1751 if cle is None: 

1752 return result 

1753 

1754 import numpy as np 

1755 

1756 # If result is floating point and target is integer, scale appropriately 

1757 if np.issubdtype(result.dtype, np.floating) and not np.issubdtype(target_dtype, np.floating): 

1758 # Convert to numpy for scaling, then back to pyclesperanto 

1759 result_np = cle.pull(result) 

1760 

1761 # Clip to [0, 1] range and scale to integer range 

1762 clipped = np.clip(result_np, 0, 1) 

1763 if target_dtype == np.uint8: 

1764 scaled = (clipped * 255).astype(target_dtype) 

1765 elif target_dtype == np.uint16: 

1766 scaled = (clipped * 65535).astype(target_dtype) 

1767 elif target_dtype == np.uint32: 

1768 scaled = (clipped * 4294967295).astype(target_dtype) 

1769 else: 

1770 # For other integer types, just convert without scaling 

1771 scaled = clipped.astype(target_dtype) 

1772 

1773 # Push back to GPU 

1774 return cle.push(scaled) 

1775 else: 

1776 # Direct conversion for same numeric type families 

1777 result_np = cle.pull(result) 

1778 converted = result_np.astype(target_dtype) 

1779 return cle.push(converted) 

1780 

1781 except Exception: 

1782 # If conversion fails, return original result 

1783 return result
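
A worked NumPy example of the scaling rule implemented above: floating-point results are clipped to [0, 1] and stretched to the full range of the target integer dtype.

import numpy as np

float_result = np.array([[-0.2, 0.0], [0.5, 1.3]], dtype=np.float32)
clipped = np.clip(float_result, 0, 1)             # [[0.0, 0.0], [0.5, 1.0]]
as_uint16 = (clipped * 65535).astype(np.uint16)   # [[0, 0], [32767, 65535]]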