Coverage for src/arraybridge/decorators.py: 63% (177 statements)
coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1""" 

2Memory type declaration decorators for OpenHCS. 

3 

4This module provides decorators for explicitly declaring the memory interface 

5of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting 

6memory-type-aware dispatching and orchestration. 

7 

8These decorators annotate functions with input_memory_type and output_memory_type 

9attributes and provide automatic thread-local CUDA stream management for GPU 

10frameworks to enable true parallelization across multiple threads. 

11 

12REFACTORED: Uses enum-driven metaprogramming to eliminate 79% of code duplication. 

13""" 

14 

15import functools 

16import inspect 

17import logging 

18import threading 

19from enum import Enum 

20from typing import Any, Callable, Optional, TypeVar 

21 

22import numpy as np 

23 

24from arraybridge.dtype_scaling import SCALING_FUNCTIONS 

25from arraybridge.framework_ops import _FRAMEWORK_OPS 

26from arraybridge.oom_recovery import _execute_with_oom_recovery 

27from arraybridge.slice_processing import process_slices 

28from arraybridge.types import MemoryType 

29from arraybridge.utils import optional_import 

30 

31logger = logging.getLogger(__name__) 

32 

33F = TypeVar("F", bound=Callable[..., Any]) 

34 

35 

class DtypeConversion(Enum):
    """Data type conversion modes for all memory type functions."""

    PRESERVE_INPUT = "preserve"  # Keep input dtype (default)
    NATIVE_OUTPUT = "native"  # Use framework's native output
    UINT8 = "uint8"  # Force uint8 (0-255 range)
    UINT16 = "uint16"  # Force uint16 (microscopy standard)
    INT16 = "int16"  # Force int16 (signed microscopy data)
    INT32 = "int32"  # Force int32 (large integer values)
    FLOAT32 = "float32"  # Force float32 (GPU performance)
    FLOAT64 = "float64"  # Force float64 (maximum precision)

    @property
    def numpy_dtype(self):
        """Get the corresponding numpy dtype."""
        dtype_map = {
            DtypeConversion.UINT8: np.uint8,
            DtypeConversion.UINT16: np.uint16,
            DtypeConversion.INT16: np.int16,
            DtypeConversion.INT32: np.int32,
            DtypeConversion.FLOAT32: np.float32,
            DtypeConversion.FLOAT64: np.float64,
        }
        return dtype_map.get(self, None)

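# Example (a minimal sketch, not part of the library API): the forcing members
# map to concrete numpy dtypes, while the two policy members map to None.
#
#     assert DtypeConversion.UINT16.numpy_dtype is np.uint16
#     assert DtypeConversion.PRESERVE_INPUT.numpy_dtype is None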

# Process-wide cache for lazy-loaded GPU framework modules (a plain dict shared
# across threads, not thread-local storage)
_gpu_frameworks_cache = {}


def _create_lazy_getter(framework_name: str):
    """Factory function that creates a lazy import getter for a framework."""

    def getter():
        if framework_name not in _gpu_frameworks_cache:
            _gpu_frameworks_cache[framework_name] = optional_import(framework_name)
            if _gpu_frameworks_cache[framework_name] is not None:
                logger.debug(
                    f"🔧 Lazy imported {framework_name} in thread "
                    f"{threading.current_thread().name}"
                )
        return _gpu_frameworks_cache[framework_name]

    return getter


# Auto-generate lazy getters for all GPU frameworks
for mem_type in MemoryType:
    ops = _FRAMEWORK_OPS[mem_type]
    if ops["lazy_getter"] is not None:
        getter_func = _create_lazy_getter(ops["import_name"])
        globals()[f"_get_{ops['import_name']}"] = getter_func

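# Illustration (hedged; the exact set depends on _FRAMEWORK_OPS): after this
# loop, module globals hold getters such as _get_cupy and _get_torch. Each one
# imports its framework at most once per process and caches the result, so a
# safe lookup pattern is:
#
#     cupy = globals().get("_get_cupy", lambda: None)()  # None if CuPy absent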

# Thread-local storage for GPU streams and contexts
_thread_gpu_contexts = threading.local()


class ThreadGPUContext:
    """Thread-local GPU context manager for CUDA streams."""

    def __init__(self):
        self.cupy_stream = None
        self.torch_stream = None
        self.tensorflow_device = None
        self.jax_device = None

    def get_cupy_stream(self):
        """Get or create thread-local CuPy stream."""
        if self.cupy_stream is None:
            cupy = globals().get("_get_cupy", lambda: None)()  # noqa: F821
            if cupy is not None and hasattr(cupy, "cuda"):
                self.cupy_stream = cupy.cuda.Stream()
                logger.debug(f"🔧 Created CuPy stream for thread {threading.current_thread().name}")
        return self.cupy_stream

    def get_torch_stream(self):
        """Get or create thread-local PyTorch stream."""
        if self.torch_stream is None:
            torch = globals().get("_get_torch", lambda: None)()  # noqa: F821
            if torch is not None and hasattr(torch, "cuda") and torch.cuda.is_available():
                self.torch_stream = torch.cuda.Stream()
                logger.debug(
                    f"🔧 Created PyTorch stream for thread {threading.current_thread().name}"
                )
        return self.torch_stream


def _get_thread_gpu_context():
    """Get or create thread-local GPU context."""
    if not hasattr(_thread_gpu_contexts, "context"):
        _thread_gpu_contexts.context = ThreadGPUContext()
    return _thread_gpu_contexts.context

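# Usage sketch (assumes CuPy plus a CUDA device; without them the getter
# returns None). Each thread receives its own ThreadGPUContext, hence its own
# stream, so kernels launched from different threads do not serialize on the
# default stream:
#
#     from concurrent.futures import ThreadPoolExecutor
#
#     def stream_ptr(_):
#         stream = _get_thread_gpu_context().get_cupy_stream()
#         return None if stream is None else stream.ptr
#
#     with ThreadPoolExecutor(max_workers=2) as pool:
#         print(list(pool.map(stream_ptr, range(2))))  # two distinct pointers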

def memory_types(
    input_type: str, output_type: str, contract: Optional[Callable[[Any], bool]] = None
) -> Callable[[F], F]:
    """
    Base decorator for declaring memory types of a function.

    This is the foundation decorator that all memory-type-specific decorators build upon.
    """

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)

            # Apply contract validation if provided
            if contract is not None and not contract(result):
                raise ValueError(f"Function {func.__name__} violated its output contract")

            return result

        # Attach memory type metadata
        wrapper.input_memory_type = input_type
        wrapper.output_memory_type = output_type

        return wrapper

    return decorator

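# Usage sketch (hypothetical function; here "numpy" is a memory type string,
# not the auto-generated decorator defined at the bottom of this module):
#
#     @memory_types(input_type="numpy", output_type="numpy",
#                   contract=lambda out: out.ndim == 3)
#     def normalize(stack):
#         return stack / stack.max()
#
#     normalize.input_memory_type     # "numpy"
#     normalize(np.ones((2, 4, 4)))   # OK; a non-3D return would raise ValueError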

def _create_dtype_wrapper(func, mem_type: MemoryType, func_name: str):
    """
    Auto-generate dtype preservation wrapper for any memory type.

    This single function replaces 6 nearly-identical dtype wrapper functions.
    """

    scale_func = SCALING_FUNCTIONS[mem_type.value]

    @functools.wraps(func)
    def dtype_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs):
        # Set default dtype_conversion if not provided
        if dtype_conversion is None:
            dtype_conversion = DtypeConversion.PRESERVE_INPUT

        try:
            # Store original dtype
            original_dtype = image.dtype

            # Handle slice_by_slice processing for 3D arrays
            if slice_by_slice and hasattr(image, "ndim") and image.ndim == 3:
                result = process_slices(image, func, args, kwargs)
            else:
                # Call the original function normally
                result = func(image, *args, **kwargs)

            # Apply dtype conversion based on enum value
            if hasattr(result, "dtype") and dtype_conversion is not None:
                if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
                    # Preserve input dtype
                    if result.dtype != original_dtype:
                        result = scale_func(result, original_dtype)
                elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
                    # Return framework's native output dtype
                    pass  # No conversion needed
                else:
                    # Force specific dtype
                    target_dtype = dtype_conversion.numpy_dtype
                    if target_dtype is not None:
                        result = scale_func(result, target_dtype)

            return result

        except Exception as e:
            logger.error(
                f"Error in {mem_type.value} dtype/slice preserving wrapper for {func_name}: {e}"
            )
            # Fall back to an undecorated call; note this re-executes the
            # function rather than returning a previously computed result
            return func(image, *args, **kwargs)


    # Update function signature to include new parameters
    try:
        original_sig = inspect.signature(func)
        new_params = list(original_sig.parameters.values())

        # Check if parameters already exist
        param_names = [p.name for p in new_params]

        # Add dtype_conversion parameter first (before slice_by_slice)
        if "dtype_conversion" not in param_names:
            dtype_param = inspect.Parameter(
                "dtype_conversion",
                inspect.Parameter.KEYWORD_ONLY,
                default=DtypeConversion.PRESERVE_INPUT,
                annotation=Optional[DtypeConversion],
            )
            new_params.append(dtype_param)

        # Add slice_by_slice parameter
        if "slice_by_slice" not in param_names:
            slice_param = inspect.Parameter(
                "slice_by_slice", inspect.Parameter.KEYWORD_ONLY, default=False, annotation=bool
            )
            new_params.append(slice_param)

        # Create new signature
        new_sig = original_sig.replace(parameters=new_params)
        dtype_wrapper.__signature__ = new_sig

        # Update docstring
        if dtype_wrapper.__doc__:
            dtype_wrapper.__doc__ += (
                f"\n\n    Additional Parameters (added by {mem_type.value} decorator):\n"
            )
            dtype_wrapper.__doc__ += (
                "        dtype_conversion (DtypeConversion, optional): "
                "How to handle output dtype.\n"
            )
            dtype_wrapper.__doc__ += (
                "            Defaults to PRESERVE_INPUT (match input dtype).\n"
            )
            dtype_wrapper.__doc__ += (
                "        slice_by_slice (bool, optional): Process 3D arrays slice-by-slice.\n"
            )
            dtype_wrapper.__doc__ += (
                "            Defaults to False. Prevents cross-slice contamination.\n"
            )

    except Exception as e:
        logger.warning(f"Could not update signature for {func_name}: {e}")

    return dtype_wrapper

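# Behavior sketch (hypothetical function; it uses the @numpy decorator that is
# auto-generated at the bottom of this module, and the rescaling behavior
# comes from arraybridge.dtype_scaling):
#
#     @numpy
#     def blur(image):
#         return image.astype(np.float32) / 3.0  # framework-native float32
#
#     stack = np.ones((2, 8, 8), dtype=np.uint16)
#     blur(stack).dtype                          # uint16 via PRESERVE_INPUT
#     blur(stack, dtype_conversion=DtypeConversion.NATIVE_OUTPUT).dtype  # float32
#     blur(stack, slice_by_slice=True)           # each Z-slice processed alone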

def _create_gpu_wrapper(func, mem_type: MemoryType, oom_recovery: bool):
    """
    Auto-generate GPU stream/device wrapper for any GPU memory type.

    This function creates the GPU-specific wrapper with stream management and OOM recovery.
    """
    ops = _FRAMEWORK_OPS[mem_type]
    framework_name = ops["import_name"]
    lazy_getter = globals().get(ops["lazy_getter"])

    @functools.wraps(func)
    def gpu_wrapper(*args, **kwargs):
        framework = lazy_getter()

        # Check if GPU is available for this framework
        if framework is not None:
            gpu_check_expr = ops["gpu_check"].format(mod=framework_name)
            try:
                gpu_available = eval(gpu_check_expr, {framework_name: framework})
            except Exception:
                gpu_available = False

            if gpu_available:
                # Get thread-local context
                ctx = _get_thread_gpu_context()

                # Get stream if framework supports it
                stream = None
                if mem_type == MemoryType.CUPY:
                    stream = ctx.get_cupy_stream()
                elif mem_type == MemoryType.TORCH:
                    stream = ctx.get_torch_stream()

                # Define execution function that captures args/kwargs
                def execute_with_stream():
                    if stream is not None:
                        with stream:
                            return func(*args, **kwargs)
                    else:
                        return func(*args, **kwargs)

                # Execute with OOM recovery if enabled
                if oom_recovery and ops["has_oom_recovery"]:
                    return _execute_with_oom_recovery(execute_with_stream, mem_type.value)
                else:
                    return execute_with_stream()

        # CPU fallback or framework not available
        return func(*args, **kwargs)

    # Preserve memory type attributes
    gpu_wrapper.input_memory_type = func.input_memory_type
    gpu_wrapper.output_memory_type = func.output_memory_type

    return gpu_wrapper

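# Flow sketch (hedged; the real expressions live in _FRAMEWORK_OPS): for
# MemoryType.CUPY the gpu_check template could be "{mod}.cuda.is_available()",
# which .format(mod="cupy") turns into an expression evaluated against the
# lazily imported module. When it is True, the wrapped call runs inside the
# thread's private stream, optionally under OOM recovery:
#
#     _execute_with_oom_recovery(execute_with_stream, "cupy")
#
# If the framework is missing or reports no GPU, execution falls through to
# the plain func(*args, **kwargs) path.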

317 

318def _create_memory_decorator(mem_type: MemoryType): 

319 """ 

320 Factory function that creates a decorator for a specific memory type. 

321 

322 This single factory replaces 6 nearly-identical decorator functions. 

323 """ 

324 ops = _FRAMEWORK_OPS[mem_type] 

325 

326 def decorator( 

327 func=None, 

328 *, 

329 input_type=mem_type.value, 

330 output_type=mem_type.value, 

331 oom_recovery=True, 

332 contract=None, 

333 ): 

334 """ 

335 Decorator for {mem_type} memory type functions. 

336 

337 Args: 

338 func: Function to decorate (when used as @decorator) 

339 input_type: Expected input memory type (default: {mem_type}) 

340 output_type: Expected output memory type (default: {mem_type}) 

341 oom_recovery: Enable automatic OOM recovery (default: True) 

342 contract: Optional validation function for outputs 

343 

344 Returns: 

345 Decorated function with memory type metadata and dtype preservation 

346 """ 

347 

348 def inner_decorator(func): 

349 # Apply base memory_types decorator 

350 memory_decorator = memory_types( 

351 input_type=input_type, output_type=output_type, contract=contract 

352 ) 

353 func = memory_decorator(func) 

354 

355 # Apply dtype preservation wrapper 

356 func = _create_dtype_wrapper(func, mem_type, func.__name__) 

357 

358 # Apply GPU wrapper if this is a GPU memory type 

359 if ops["gpu_check"] is not None: 

360 func = _create_gpu_wrapper(func, mem_type, oom_recovery) 

361 

362 return func 

363 

364 # Handle both @decorator and @decorator() forms 

365 if func is None: 

366 return inner_decorator 

367 return inner_decorator(func) 

368 

369 # Set proper function name and docstring 

370 decorator.__name__ = mem_type.value 

371 decorator.__doc__ = decorator.__doc__.format(mem_type=ops["display_name"]) 

372 

373 return decorator 

374 

375 

# Auto-generate all 6 memory type decorators
for mem_type in MemoryType:
    decorator_func = _create_memory_decorator(mem_type)
    globals()[mem_type.value] = decorator_func

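# After this loop the module exposes one decorator per MemoryType value:
# numpy, cupy, torch, tensorflow, jax, and pyclesperanto (matching __all__
# below). Usage sketch with a hypothetical function:
#
#     @torch(oom_recovery=False)
#     def denoise(volume):
#         return volume.clamp(min=0)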

# Export all decorators
__all__ = [
    "memory_types",
    "DtypeConversion",
    "numpy",  # noqa: F822
    "cupy",  # noqa: F822
    "torch",  # noqa: F822
    "tensorflow",  # noqa: F822
    "jax",  # noqa: F822
    "pyclesperanto",  # noqa: F822
]
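
# Import sketch (hedged): because the generated names shadow the framework
# packages themselves, aliasing on import keeps caller namespaces unambiguous:
#
#     from arraybridge.decorators import DtypeConversion
#     from arraybridge.decorators import numpy as numpy_func
#
#     @numpy_func
#     def invert(image):
#         return image.max() - image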