Coverage for openhcs/core/memory/decorators.py: 58.7%

177 statements  

coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Memory type declaration decorators for OpenHCS. 

3 

4This module provides decorators for explicitly declaring the memory interface 

5of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting 

6memory-type-aware dispatching and orchestration. 

7 

8These decorators annotate functions with input_memory_type and output_memory_type 

9attributes and provide automatic thread-local CUDA stream management for GPU 

10frameworks to enable true parallelization across multiple threads. 

11 

12REFACTORED: Uses enum-driven metaprogramming to eliminate 79% of code duplication. 

13""" 

14 

15import functools 

16import inspect 

17import logging 

18import threading 

19from typing import Any, Callable, Optional, TypeVar 

20 

21from openhcs.constants.constants import VALID_MEMORY_TYPES, MemoryType 

22from openhcs.core.utils import optional_import 

23from openhcs.core.memory.oom_recovery import _execute_with_oom_recovery 

24from openhcs.core.memory.framework_ops import _FRAMEWORK_OPS 

25from openhcs.core.memory.dtype_scaling import SCALING_FUNCTIONS 

26from openhcs.core.memory.slice_processing import process_slices 

27 

28logger = logging.getLogger(__name__) 

29 

30F = TypeVar('F', bound=Callable[..., Any]) 

31 

32# Dtype conversion enum and utilities for consistent dtype handling across all frameworks 

33from enum import Enum 

34import numpy as np 

35 

36class DtypeConversion(Enum): 

37 """Data type conversion modes for all memory type functions.""" 

38 

39 PRESERVE_INPUT = "preserve" # Keep input dtype (default) 

40 NATIVE_OUTPUT = "native" # Use framework's native output 

41 UINT8 = "uint8" # Force uint8 (0-255 range) 

42 UINT16 = "uint16" # Force uint16 (microscopy standard) 

43 INT16 = "int16" # Force int16 (signed microscopy data) 

44 INT32 = "int32" # Force int32 (large integer values) 

45 FLOAT32 = "float32" # Force float32 (GPU performance) 

46 FLOAT64 = "float64" # Force float64 (maximum precision) 

47 

48 @property 

49 def numpy_dtype(self): 

50 """Get the corresponding numpy dtype.""" 

51 dtype_map = { 

52 self.UINT8: np.uint8, 

53 self.UINT16: np.uint16, 

54 self.INT16: np.int16, 

55 self.INT32: np.int32, 

56 self.FLOAT32: np.float32, 

57 self.FLOAT64: np.float64, 

58 } 

59 return dtype_map.get(self, None) 

60 

61 
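# A minimal sketch of how the enum above resolves target dtypes; only names defined in
# this module are used. PRESERVE_INPUT and NATIVE_OUTPUT are intentionally absent from
# dtype_map, so numpy_dtype returns None for them and the dtype wrapper below takes its
# preserve/native branches instead of forcing a dtype.
assert DtypeConversion.UINT16.numpy_dtype is np.uint16
assert DtypeConversion.FLOAT32.numpy_dtype is np.float32
assert DtypeConversion.PRESERVE_INPUT.numpy_dtype is None
assert DtypeConversion.NATIVE_OUTPUT.numpy_dtype is None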

62# Module-level cache for lazy-loaded GPU frameworks (shared across threads)

63_gpu_frameworks_cache = {} 

64 

65 

66def _create_lazy_getter(framework_name: str): 

67 """Factory function that creates a lazy import getter for a framework.""" 

68 def getter(): 

69 if framework_name not in _gpu_frameworks_cache: 

70 _gpu_frameworks_cache[framework_name] = optional_import(framework_name) 

71 if _gpu_frameworks_cache[framework_name] is not None: 

72 logger.debug(f"🔧 Lazy imported {framework_name} in thread {threading.current_thread().name}") 

73 return _gpu_frameworks_cache[framework_name] 

74 return getter 

75 

76 

77# Auto-generate lazy getters for all GPU frameworks 

78for mem_type in MemoryType: 

79 ops = _FRAMEWORK_OPS[mem_type] 

80 if ops['lazy_getter'] is not None: 

81 getter_func = _create_lazy_getter(ops['import_name']) 

82 globals()[f"_get_{ops['import_name']}"] = getter_func 

83 

84 
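# Sketch of what the loop above generates, assuming _FRAMEWORK_OPS[MemoryType.CUPY]
# carries import_name "cupy" with a non-None lazy_getter entry (the real table lives in
# openhcs.core.memory.framework_ops): a module-level _get_cupy() that imports CuPy at
# most once per process and returns None when it is not installed.
#
#     cupy_mod = _get_cupy()    # optional_import("cupy"), cached in _gpu_frameworks_cache
#     if cupy_mod is None:
#         ...                   # CuPy unavailable; callers take their CPU fallback path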

85# Thread-local storage for GPU streams and contexts 

86_thread_gpu_contexts = threading.local() 

87 

88class ThreadGPUContext: 

89 """Thread-local GPU context manager for CUDA streams.""" 

90 

91 def __init__(self): 

92 self.cupy_stream = None 

93 self.torch_stream = None 

94 self.tensorflow_device = None 

95 self.jax_device = None 

96 

97 def get_cupy_stream(self): 

98 """Get or create thread-local CuPy stream.""" 

99 if self.cupy_stream is None: 

100 cupy = _get_cupy() 

101 if cupy is not None and hasattr(cupy, 'cuda'): 

102 self.cupy_stream = cupy.cuda.Stream() 

103 logger.debug(f"🔧 Created CuPy stream for thread {threading.current_thread().name}") 

104 return self.cupy_stream 

105 

106 def get_torch_stream(self): 

107 """Get or create thread-local PyTorch stream.""" 

108 if self.torch_stream is None: 

109 torch = _get_torch() 

110 if torch is not None and hasattr(torch, 'cuda') and torch.cuda.is_available(): 

111 self.torch_stream = torch.cuda.Stream() 

112 logger.debug(f"🔧 Created PyTorch stream for thread {threading.current_thread().name}") 

113 return self.torch_stream 

114 

115 

116def _get_thread_gpu_context(): 

117 """Get or create thread-local GPU context.""" 

118 if not hasattr(_thread_gpu_contexts, 'context'): 

119 _thread_gpu_contexts.context = ThreadGPUContext() 

120 return _thread_gpu_contexts.context 

121 

122 
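# Illustrative sketch of the per-thread isolation provided above: each worker thread
# that calls _get_thread_gpu_context() gets its own ThreadGPUContext, so CUDA streams
# are never shared between threads. The worker function itself is hypothetical.
def _worker_sketch():
    ctx = _get_thread_gpu_context()   # one context per thread via threading.local
    stream = ctx.get_cupy_stream()    # created lazily; None if CuPy/CUDA is unavailable
    return stream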

123def memory_types( 

124 input_type: str, 

125 output_type: str, 

126 contract: Optional[Callable[[Any], bool]] = None 

127) -> Callable[[F], F]: 

128 """ 

129 Base decorator for declaring memory types of a function. 

130  

131 This is the foundation decorator that all memory-type-specific decorators build upon. 

132 """ 

133 def decorator(func: F) -> F: 

134 @functools.wraps(func) 

135 def wrapper(*args, **kwargs): 

136 result = func(*args, **kwargs) 

137 

138 # Apply contract validation if provided 

139 if contract is not None and not contract(result): 139 ↛ 140: line 139 didn't jump to line 140 because the condition on line 139 was never true

140 raise ValueError(f"Function {func.__name__} violated its output contract") 

141 

142 return result 

143 

144 # Attach memory type metadata 

145 wrapper.input_memory_type = input_type 

146 wrapper.output_memory_type = output_type 

147 

148 return wrapper 

149 

150 return decorator 

151 

152 
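# Minimal usage sketch for the base decorator; normalize_stack and its body are
# hypothetical, the decorator and the attached attributes are this module's API.
@memory_types(input_type="numpy", output_type="numpy",
              contract=lambda out: out.ndim == 3)
def normalize_stack(stack):
    return (stack - stack.min()) / (stack.max() - stack.min() + 1e-12)

# The wrapper only forwards the call, checks the contract on the result, and attaches
# metadata: normalize_stack.input_memory_type == "numpy" and
# normalize_stack.output_memory_type == "numpy".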

153def _create_dtype_wrapper(func, mem_type: MemoryType, func_name: str): 

154 """ 

155 Auto-generate dtype preservation wrapper for any memory type. 

156  

157 This single function replaces 6 nearly-identical dtype wrapper functions. 

158 """ 

159 ops = _FRAMEWORK_OPS[mem_type] 

160 scale_func = SCALING_FUNCTIONS[mem_type.value] 

161 

162 @functools.wraps(func) 

163 def dtype_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = False, **kwargs): 

164 # Set default dtype_conversion if not provided 

165 if dtype_conversion is None: 

166 dtype_conversion = DtypeConversion.PRESERVE_INPUT 

167 

168 try: 

169 # Store original dtype 

170 original_dtype = image.dtype 

171 

172 # Handle slice_by_slice processing for 3D arrays 

173 if slice_by_slice and hasattr(image, 'ndim') and image.ndim == 3: 173 ↛ 174: line 173 didn't jump to line 174 because the condition on line 173 was never true

174 result = process_slices(image, func, args, kwargs) 

175 else: 

176 # Call the original function normally 

177 result = func(image, *args, **kwargs) 

178 

179 # Apply dtype conversion based on enum value 

180 if hasattr(result, 'dtype') and dtype_conversion is not None: 

181 if dtype_conversion == DtypeConversion.PRESERVE_INPUT: 181 ↛ 185: line 181 didn't jump to line 185 because the condition on line 181 was always true

182 # Preserve input dtype 

183 if result.dtype != original_dtype: 183 ↛ 184: line 183 didn't jump to line 184 because the condition on line 183 was never true

184 result = scale_func(result, original_dtype) 

185 elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT: 

186 # Return framework's native output dtype 

187 pass # No conversion needed 

188 else: 

189 # Force specific dtype 

190 target_dtype = dtype_conversion.numpy_dtype 

191 if target_dtype is not None: 

192 result = scale_func(result, target_dtype) 

193 

194 return result 

195 except Exception as e: 

196 logger.error(f"Error in {mem_type.value} dtype/slice preserving wrapper for {func_name}: {e}") 

197 # Fall back: re-run the original function without dtype/slice handling

198 return func(image, *args, **kwargs) 

199 

200 # Update function signature to include new parameters 

201 try: 

202 original_sig = inspect.signature(func) 

203 new_params = list(original_sig.parameters.values()) 

204 

205 # Check if parameters already exist 

206 param_names = [p.name for p in new_params] 

207 

208 # Add dtype_conversion parameter first (before slice_by_slice) 

209 if 'dtype_conversion' not in param_names: 209 ↛ 219: line 209 didn't jump to line 219 because the condition on line 209 was always true

210 dtype_param = inspect.Parameter( 

211 'dtype_conversion', 

212 inspect.Parameter.KEYWORD_ONLY, 

213 default=DtypeConversion.PRESERVE_INPUT, 

214 annotation=Optional[DtypeConversion] 

215 ) 

216 new_params.append(dtype_param) 

217 

218 # Add slice_by_slice parameter 

219 if 'slice_by_slice' not in param_names: 

220 slice_param = inspect.Parameter( 

221 'slice_by_slice', 

222 inspect.Parameter.KEYWORD_ONLY, 

223 default=False, 

224 annotation=bool 

225 ) 

226 new_params.append(slice_param) 

227 

228 # Create new signature 

229 new_sig = original_sig.replace(parameters=new_params) 

230 dtype_wrapper.__signature__ = new_sig 

231 

232 # Update docstring 

233 if dtype_wrapper.__doc__: 

234 dtype_wrapper.__doc__ += f"\n\n Additional Parameters (added by {mem_type.value} decorator):\n" 

235 dtype_wrapper.__doc__ += " dtype_conversion (DtypeConversion, optional): How to handle output dtype.\n" 

236 dtype_wrapper.__doc__ += " Defaults to PRESERVE_INPUT (match input dtype).\n" 

237 dtype_wrapper.__doc__ += " slice_by_slice (bool, optional): Process 3D arrays slice-by-slice.\n" 

238 dtype_wrapper.__doc__ += " Defaults to False. Prevents cross-slice contamination.\n" 

239 

240 except Exception as e: 

241 logger.warning(f"Could not update signature for {func_name}: {e}") 

242 

243 return dtype_wrapper 

244 

245 
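# Call-site sketch for a function wrapped by _create_dtype_wrapper: both keyword-only
# parameters are injected by the wrapper rather than declared on the user function.
# gaussian_blur is hypothetical; DtypeConversion and process_slices are this module's.
#
#     out = gaussian_blur(stack)                                       # input dtype preserved
#     out = gaussian_blur(stack, dtype_conversion=DtypeConversion.FLOAT32)
#     out = gaussian_blur(stack, slice_by_slice=True)                  # 3D input routed through process_slices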

246def _create_gpu_wrapper(func, mem_type: MemoryType, oom_recovery: bool): 

247 """ 

248 Auto-generate GPU stream/device wrapper for any GPU memory type. 

249  

250 This function creates the GPU-specific wrapper with stream management and OOM recovery. 

251 """ 

252 ops = _FRAMEWORK_OPS[mem_type] 

253 framework_name = ops['import_name'] 

254 lazy_getter = globals().get(ops['lazy_getter']) 

255 

256 @functools.wraps(func) 

257 def gpu_wrapper(*args, **kwargs): 

258 framework = lazy_getter() 

259 

260 # Check if GPU is available for this framework 

261 if framework is not None: 

262 gpu_check_expr = ops['gpu_check'].format(mod=framework_name) 

263 try: 

264 gpu_available = eval(gpu_check_expr, {framework_name: framework}) 

265 except Exception:

266 gpu_available = False 

267 

268 if gpu_available: 

269 # Get thread-local context 

270 ctx = _get_thread_gpu_context() 

271 

272 # Get stream if framework supports it 

273 stream = None 

274 if mem_type == MemoryType.CUPY: 

275 stream = ctx.get_cupy_stream() 

276 elif mem_type == MemoryType.TORCH: 

277 stream = ctx.get_torch_stream() 

278 

279 # Define execution function that captures args/kwargs 

280 def execute_with_stream(): 

281 if stream is not None: 

282 with stream: 

283 return func(*args, **kwargs) 

284 else: 

285 return func(*args, **kwargs) 

286 

287 # Execute with OOM recovery if enabled 

288 if oom_recovery and ops['has_oom_recovery']: 

289 return _execute_with_oom_recovery(execute_with_stream, mem_type.value) 

290 else: 

291 return execute_with_stream() 

292 

293 # CPU fallback or framework not available 

294 return func(*args, **kwargs) 

295 

296 # Preserve memory type attributes 

297 gpu_wrapper.input_memory_type = func.input_memory_type 

298 gpu_wrapper.output_memory_type = func.output_memory_type 

299 

300 return gpu_wrapper 

301 

302 
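# Sketch of the gpu_check contract assumed by the wrapper above: each GPU entry in
# _FRAMEWORK_OPS supplies an expression template with a {mod} placeholder, which is
# formatted with the framework's import name and eval'd against the imported module.
# The template string below is hypothetical; the real ones live in framework_ops.
#
#     gpu_check_template = "{mod}.cuda.is_available()"      # e.g. a torch-style check
#     expr = gpu_check_template.format(mod="torch")          # -> "torch.cuda.is_available()"
#     gpu_available = eval(expr, {"torch": torch_module})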

303def _create_memory_decorator(mem_type: MemoryType): 

304 """ 

305 Factory function that creates a decorator for a specific memory type. 

306 

307 This single factory replaces 6 nearly-identical decorator functions. 

308 """ 

309 ops = _FRAMEWORK_OPS[mem_type] 

310 

311 def decorator(func=None, *, input_type=mem_type.value, output_type=mem_type.value, 

312 oom_recovery=True, contract=None): 

313 """ 

314 Decorator for {mem_type} memory type functions. 

315 

316 Args: 

317 func: Function to decorate (when used as @decorator) 

318 input_type: Expected input memory type (default: {mem_type}) 

319 output_type: Expected output memory type (default: {mem_type}) 

320 oom_recovery: Enable automatic OOM recovery (default: True) 

321 contract: Optional validation function for outputs 

322 

323 Returns: 

324 Decorated function with memory type metadata and dtype preservation 

325 """ 

326 def inner_decorator(func): 

327 # Apply base memory_types decorator 

328 memory_decorator = memory_types(input_type=input_type, output_type=output_type, contract=contract) 

329 func = memory_decorator(func) 

330 

331 # Apply dtype preservation wrapper 

332 func = _create_dtype_wrapper(func, mem_type, func.__name__) 

333 

334 # Apply GPU wrapper if this is a GPU memory type 

335 if ops['gpu_check'] is not None: 

336 func = _create_gpu_wrapper(func, mem_type, oom_recovery) 

337 

338 return func 

339 

340 # Handle both @decorator and @decorator() forms 

341 if func is None: 341 ↛ 342: line 341 didn't jump to line 342 because the condition on line 341 was never true

342 return inner_decorator 

343 return inner_decorator(func) 

344 

345 # Set proper function name and docstring 

346 decorator.__name__ = mem_type.value 

347 decorator.__doc__ = decorator.__doc__.format(mem_type=ops['display_name']) 

348 

349 return decorator 

350 

351 

352# Auto-generate all 6 memory type decorators 

353for mem_type in MemoryType: 

354 decorator_func = _create_memory_decorator(mem_type) 

355 globals()[mem_type.value] = decorator_func 

356 

357 
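# Usage sketch for the generated decorators: after the loop above, numpy, cupy, torch,
# tensorflow, jax and pyclesperanto exist as module-level decorators (assuming the
# MemoryType values match those names, as __all__ below suggests). sobel_filter is
# hypothetical.
#
#     from openhcs.core.memory.decorators import cupy
#
#     @cupy                      # or @cupy(oom_recovery=False)
#     def sobel_filter(image):
#         ...
#
#     sobel_filter.input_memory_type   # "cupy"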

358# Export all decorators 

359__all__ = [ 

360 'memory_types', 

361 'DtypeConversion', 

362 'numpy', 

363 'cupy', 

364 'torch', 

365 'tensorflow', 

366 'jax', 

367 'pyclesperanto', 

368] 

369