Coverage for openhcs/processing/func_registry.py: 38.0%

231 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 18:33 +0000

"""
Function registry for processing backends.

This module maintains a registry of functions that can be executed by
different processing backends (numpy, cupy, torch, etc.). The registry is
populated from ``RegistryService``, which supplies both native OpenHCS
functions and decorated external-library functions; functions are grouped
by the library name of the registry that provided them.

The function registry is a module-level global that is auto-initialized when
this module is imported (unless the ``OPENHCS_SUBPROCESS_NO_GPU`` environment
variable is set) and is shared across all components.

Valid memory types:
- numpy
- cupy
- torch
- tensorflow
- jax
- pyclesperanto

Thread Safety:
    The entry points that read or mutate the global registry
    (``initialize_registry``, ``load_prebuilt_registry``,
    ``register_function``, ``is_registry_initialized`` and
    ``get_functions_by_memory_type``) acquire ``_registry_lock``.
    Import-time auto-initialization runs without the lock.
"""
from __future__ import annotations

import builtins
import importlib
import inspect
import logging
import os
import pkgutil
import sys
import threading
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

logger = logging.getLogger(__name__)

# Lock guarding registry access. threading.Lock is not reentrant, so code that
# already holds it must not call anything that acquires it again.
_registry_lock = threading.Lock()

# Remnants of the (removed) import hook system; not read by the functions in
# this module. ``builtins.__import__`` is used instead of
# ``__builtins__['__import__']`` because ``__builtins__`` is a CPython
# implementation detail: a dict in imported modules, but the module object
# itself in ``__main__``, where subscripting it raises TypeError.
_original_import = builtins.__import__
_decoration_applied = set()
_import_hook_installed = False

# Global registry of functions grouped by registry (library) name.
# Structure: {registry_name: [function1, function2, ...]}
FUNC_REGISTRY: Dict[str, List[Callable]] = {}

# Valid memory types
VALID_MEMORY_TYPES = {"numpy", "cupy", "torch", "tensorflow", "jax", "pyclesperanto"}

# CPU-only memory types (for CI/testing without GPU)
CPU_ONLY_MEMORY_TYPES = {"numpy"}

# CPU-only mode is enabled only when OPENHCS_CPU_ONLY is "true" (case-insensitive)
CPU_ONLY_MODE = os.getenv('OPENHCS_CPU_ONLY', 'false').lower() == 'true'

# Flag to track if the registry has been initialized
_registry_initialized = False

# Flag to track if we're currently in the initialization process (prevent recursion).
# NOTE(review): not read by any function in this module.
_registry_initializing = False

62 

63 

64# Import hook system removed - using existing comprehensive registries with clean decoration 

65 

66 

67# Import hook decoration functions removed - using existing registries 

68 

69 

def _create_virtual_modules() -> None:
    """Create virtual modules that mirror external library structure under the openhcs namespace.

    Every function from RegistryService that has a ``slice_by_slice``
    attribute, has no ``__processing_contract__`` and does not already live
    under ``openhcs.`` is exposed on a module named
    ``openhcs.<original module>``. That module, and any missing intermediate
    modules, are created as plain ``types.ModuleType`` objects in
    ``sys.modules``; modules already present there are reused, not replaced.

    Each newly created module is also bound as an attribute of its parent, so
    both ``from openhcs.x.y import f`` and ``openhcs.x.y.f`` (after
    ``import openhcs.x.y``) resolve.
    """
    import sys
    import types
    from openhcs.processing.backends.lib_registry.registry_service import RegistryService

    prefix = 'openhcs.'

    # Get all registered functions
    all_functions = RegistryService.get_all_functions_with_metadata()

    # Group functions by the virtual module path they should appear under
    functions_by_module = {}
    for metadata in all_functions.values():
        func = metadata.func
        # Only external library functions with slice_by_slice get virtual modules
        if not hasattr(func, 'slice_by_slice') or hasattr(func, '__processing_contract__'):
            continue
        original_module = getattr(func, '__module__', None)
        # Callables without a usable module path cannot be mirrored; skip them
        # rather than aborting the whole registry initialization.
        if not isinstance(original_module, str) or original_module.startswith(prefix):
            continue
        virtual_module = f'{prefix}{original_module}'
        functions_by_module.setdefault(virtual_module, {})[func.__name__] = func

    # Collect all module paths including intermediate ones (starting at 'openhcs.xxx')
    all_virtual_modules = set()
    for virtual_module in functions_by_module:
        parts = virtual_module.split('.')
        for i in range(2, len(parts) + 1):
            all_virtual_modules.add('.'.join(parts[:i]))

    # Create modules in sorted order: a parent path always sorts before its children
    created_modules = []
    for virtual_module in sorted(all_virtual_modules):
        if virtual_module in sys.modules:
            continue
        module = types.ModuleType(virtual_module)
        module.__doc__ = (
            f"Virtual module mirroring {virtual_module[len(prefix):]} with OpenHCS decorations"
        )
        sys.modules[virtual_module] = module
        created_modules.append(virtual_module)

        # Bind the module on its parent; a sys.modules entry alone does not make
        # 'parent.child' attribute access or 'from parent import child' work.
        # vars() is used instead of hasattr() to avoid triggering a module __getattr__.
        parent_name, _, child_name = virtual_module.rpartition('.')
        parent = sys.modules.get(parent_name)
        if parent is not None and child_name not in vars(parent):
            setattr(parent, child_name, module)

    # Then add functions to the leaf modules
    for virtual_module, functions in functions_by_module.items():
        module = sys.modules.get(virtual_module)
        if module is None:
            continue
        for func_name, func in functions.items():
            setattr(module, func_name, func)

    if created_modules:
        logger.info(f"Created {len(created_modules)} virtual modules: {', '.join(created_modules)}")

122 

123 

def _auto_initialize_registry() -> None:
    """
    Populate FUNC_REGISTRY once, at module import time.

    Follows the same pattern as storage_registry in openhcs.io.base. Returns
    immediately if the registry is already initialized. This function does
    not acquire ``_registry_lock`` itself.

    Raises:
        Exception: Any error raised while populating the registry is logged
            and re-raised unchanged.
    """
    global _registry_initialized

    if _registry_initialized:
        return

    try:
        FUNC_REGISTRY.clear()

        # Phase 1: everything RegistryService knows about (OpenHCS + external
        # libraries), bucketed by the library name of the supplying registry.
        from openhcs.processing.backends.lib_registry.registry_service import RegistryService
        discovered = RegistryService.get_all_functions_with_metadata()

        for meta in discovered.values():
            FUNC_REGISTRY.setdefault(meta.registry.library_name, [])
        for meta in discovered.values():
            FUNC_REGISTRY[meta.registry.library_name].append(meta.func)

        # Phase 2: restrict to CPU-capable functions when requested
        if CPU_ONLY_MODE:
            logger.info("CPU-only mode enabled - filtering to numpy functions only")
            _apply_cpu_only_filtering()

        logger.info(
            "Function registry auto-initialized with %d functions across %d registries",
            sum(map(len, FUNC_REGISTRY.values())),
            len(FUNC_REGISTRY),
        )

        _registry_initialized = True

        # Expose decorated external-library functions under the openhcs namespace
        _create_virtual_modules()
    except Exception as exc:
        logger.error(f"Failed to auto-initialize function registry: {exc}")
        raise

176 

177 

def initialize_registry(force: bool = False) -> None:
    """
    Initialize the function registry from RegistryService.

    The registry normally auto-initializes on import, so calling this is
    optional. By default it is a no-op when the registry is already
    initialized; pass ``force=True`` to discard the current contents and
    rebuild the registry. A forced rebuild drops anything previously added
    via ``register_function`` or ``load_prebuilt_registry``.

    Thread-safe: the whole (re)build runs while holding ``_registry_lock``.

    Args:
        force: Rebuild even if the registry is already initialized.
    """
    global _registry_initialized

    with _registry_lock:
        if _registry_initialized and not force:
            logger.info("Function registry already initialized, skipping manual initialization")
            return

        # Clear the registry; mark it uninitialized until the rebuild succeeds
        FUNC_REGISTRY.clear()
        _registry_initialized = False

        # Phase 1: Register all functions from RegistryService (includes OpenHCS and external libraries)
        from openhcs.processing.backends.lib_registry.registry_service import RegistryService
        all_functions = RegistryService.get_all_functions_with_metadata()

        # Bucket functions by the library name of the registry that supplied them
        for metadata in all_functions.values():
            FUNC_REGISTRY.setdefault(metadata.registry.library_name, []).append(metadata.func)

        # Phase 2: Apply CPU-only filtering if enabled
        if CPU_ONLY_MODE:
            logger.info("CPU-only mode enabled - filtering to numpy functions only")
            _apply_cpu_only_filtering()

        logger.info(
            "Function registry initialized with %d functions across %d registries",
            sum(len(funcs) for funcs in FUNC_REGISTRY.values()),
            len(FUNC_REGISTRY),
        )

        _registry_initialized = True

        # Create virtual modules for external library functions
        _create_virtual_modules()

233 

234 

def load_prebuilt_registry(registry_data: Dict) -> None:
    """
    Load a pre-built function registry from serialized data.

    Lets subprocess workers skip function discovery by installing a registry
    that was built in the main process. The current contents of
    FUNC_REGISTRY are replaced, and the registry is marked initialized.

    Each function list is copied, so later registrations (which append to
    FUNC_REGISTRY's lists) do not mutate the caller's ``registry_data``.

    Thread-safe: runs while holding ``_registry_lock``.

    Args:
        registry_data: Mapping of registry name to the functions in that
            registry (presumably the FUNC_REGISTRY of another process).
    """
    global _registry_initialized

    with _registry_lock:
        FUNC_REGISTRY.clear()
        FUNC_REGISTRY.update(
            (registry_name, list(functions))
            for registry_name, functions in registry_data.items()
        )
        _registry_initialized = True

        total_functions = sum(len(funcs) for funcs in FUNC_REGISTRY.values())
        logger.info(f"Loaded pre-built registry with {total_functions} functions")

254 

255 

def _scan_and_register_functions() -> None:
    """
    Scan the processing package for native OpenHCS functions and register them.

    Recursively imports every module under ``openhcs.processing`` and
    registers, in the "openhcs" registry, each function whose
    ``input_memory_type`` and ``output_memory_type`` are both in
    VALID_MEMORY_TYPES (OpenHCS functions may have mixed input/output types).

    Best-effort: a module that fails to import, or fails while being scanned,
    is logged and skipped. Packages are imported but not scanned themselves;
    their submodules are visited individually.

    NOTE(review): the auto-initialization path in this module populates the
    registry from RegistryService and does not call this function.
    """
    from openhcs import processing

    processing_path = os.path.dirname(processing.__file__)
    processing_package = "openhcs.processing"

    logger.info("Phase 1: Scanning for native OpenHCS functions in %s", processing_path)

    # _register_function indexes FUNC_REGISTRY[registry_name] directly; make sure
    # the target bucket exists so registration cannot fail with KeyError.
    FUNC_REGISTRY.setdefault("openhcs", [])

    for _, module_name, is_pkg in pkgutil.walk_packages([processing_path], f"{processing_package}."):
        logger.debug(f"Scanning module: {module_name}")
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            logger.warning("Error importing module %s: %s", module_name, e)
            continue

        # Skip packages (their modules are processed separately)
        if is_pkg:
            logger.debug(f"Skipping package: {module_name}")
            continue

        function_count = 0
        try:
            for _name, obj in inspect.getmembers(module, inspect.isfunction):
                # Missing attributes read as None, which is never a valid memory type
                input_type = getattr(obj, "input_memory_type", None)
                output_type = getattr(obj, "output_memory_type", None)
                if input_type in VALID_MEMORY_TYPES and output_type in VALID_MEMORY_TYPES:
                    _register_function(obj, "openhcs")
                    function_count += 1
        except Exception as e:
            # Report scan failures as such instead of as import errors
            logger.warning("Error scanning module %s: %s", module_name, e)
            continue

        logger.debug(f"Module {module_name}: found {function_count} registerable functions")

302 

303 

def _apply_unified_decoration(original_func, func_name, memory_type, create_wrapper=True):
    """
    Apply the shared OpenHCS decoration scheme to an external library function.

    Dtype preservation lives in the memory-type decorators (decorators.py);
    this helper only wires the pieces together:

    1. Tag ``original_func`` in place with ``input_memory_type`` and
       ``output_memory_type`` equal to ``memory_type.value`` (for subprocess
       compatibility).
    2. Unless ``create_wrapper`` is False, wrap it with the memory-type
       decorator that matches ``memory_type``.
    3. If the function's defining module is in ``sys.modules`` and already
       has an attribute ``func_name``, replace that attribute with the
       wrapped function (for user experience and pickling compatibility).

    Args:
        original_func: The original external library function.
        func_name: Attribute name to replace on the defining module.
        memory_type: MemoryType enum value (NUMPY, CUPY, PYCLESPERANTO, TORCH, TENSORFLOW, JAX).
        create_wrapper: Whether to apply the memory type decorator (default: True).

    Returns:
        The wrapped function; ``original_func`` itself when ``create_wrapper``
        is False or when no decorator matches ``memory_type``.
    """
    from openhcs.constants import MemoryType
    import sys

    type_name = memory_type.value

    # Tag the raw function so it carries its memory type even unwrapped
    original_func.input_memory_type = type_name
    original_func.output_memory_type = type_name

    if not create_wrapper:
        return original_func

    from openhcs.core.memory.decorators import numpy, cupy, torch, tensorflow, jax, pyclesperanto

    decorator_for = {
        MemoryType.NUMPY: numpy,
        MemoryType.CUPY: cupy,
        MemoryType.TORCH: torch,
        MemoryType.TENSORFLOW: tensorflow,
        MemoryType.JAX: jax,
        MemoryType.PYCLESPERANTO: pyclesperanto,
    }
    decorator = decorator_for.get(memory_type)

    if decorator is not None:
        wrapper_func = decorator(original_func)
    else:
        # Unknown memory type: fall back to the tagged original function
        wrapper_func = original_func
        wrapper_func.input_memory_type = type_name
        wrapper_func.output_memory_type = type_name

    # Swap the enhanced function into the defining module, if it is loaded
    module_name = original_func.__module__
    if module_name not in sys.modules:
        return wrapper_func

    target_module = sys.modules[module_name]
    if hasattr(target_module, func_name):
        setattr(target_module, func_name, wrapper_func)
        logger.debug(f"Replaced {module_name}.{func_name} with enhanced function")

    return wrapper_func

365 

366 

367 

368 

def register_function(func: Callable, backend: Optional[str] = None, **kwargs) -> None:
    """
    Manually register a function with the function registry.

    This is the public API for registering functions that are not
    auto-discovered (e.g., dynamically decorated functions). The registry is
    initialized first if that has not happened yet.

    Thread-safe: runs while holding ``_registry_lock``.

    Args:
        func: The function to register. It must have ``input_memory_type`` and
            ``output_memory_type`` attributes whose values are in VALID_MEMORY_TYPES.
        backend: Name of an existing registry to add the function to
            (defaults to "openhcs").
        **kwargs: Additional metadata (ignored for compatibility).

    Raises:
        ValueError: If the function lacks the memory type attributes.
        ValueError: If either memory type is invalid.
        ValueError: If ``backend`` does not name an existing registry.
    """
    with _registry_lock:
        # Ensure registry is initialized (_auto_initialize_registry does not take the lock)
        if not _registry_initialized:
            _auto_initialize_registry()

        # Not every callable (e.g. functools.partial) has a __name__
        func_name = getattr(func, "__name__", repr(func))

        # Validate function has required attributes
        if not hasattr(func, "input_memory_type") or not hasattr(func, "output_memory_type"):
            raise ValueError(
                f"Function '{func_name}' must have input_memory_type and output_memory_type attributes"
            )

        input_type = func.input_memory_type
        output_type = func.output_memory_type

        # Validate memory types
        if input_type not in VALID_MEMORY_TYPES:
            raise ValueError(f"Invalid input memory type: {input_type}")
        if output_type not in VALID_MEMORY_TYPES:
            raise ValueError(f"Invalid output memory type: {output_type}")

        # Use backend if specified, otherwise register as openhcs
        registry_name = backend or "openhcs"
        if registry_name not in FUNC_REGISTRY:
            raise ValueError(f"Invalid registry name: {registry_name}")

        _register_function(func, registry_name)

412 

413 

414def _apply_cpu_only_filtering() -> None: 

415 """Filter registry to only include numpy-compatible functions when CPU_ONLY_MODE is enabled.""" 

416 for registry_name, functions in list(FUNC_REGISTRY.items()): 

417 filtered_functions = [] 

418 for func in functions: 

419 # Only keep functions with numpy memory types 

420 if hasattr(func, 'output_memory_type') and func.output_memory_type == "numpy": 

421 filtered_functions.append(func) 

422 

423 # Update registry with filtered functions, remove empty registries 

424 if filtered_functions: 424 ↛ 427line 424 didn't jump to line 427 because the condition on line 424 was always true

425 FUNC_REGISTRY[registry_name] = filtered_functions 

426 else: 

427 del FUNC_REGISTRY[registry_name] 

428 

429 

def _register_function(func: Callable, registry_name: str) -> None:
    """
    Add a function to an existing registry bucket.

    Internal helper used during automatic scanning and manual registration.
    A function already present in the bucket is left alone. A newly added
    function is tagged with a ``registry`` attribute naming its bucket.

    Args:
        func: The function to register.
        registry_name: The registry name (e.g., "openhcs", "skimage", "pyclesperanto");
            must already be a key of FUNC_REGISTRY.

    Raises:
        KeyError: If ``registry_name`` is not a key of FUNC_REGISTRY.
    """
    registry = FUNC_REGISTRY[registry_name]
    # Not every callable (e.g. functools.partial) has a __name__
    func_name = getattr(func, "__name__", repr(func))

    # Skip if function is already registered
    if func in registry:
        logger.debug(
            "Function '%s' already registered for registry '%s'",
            func_name, registry_name
        )
        return

    registry.append(func)

    # Add registry_name attribute for easier inspection
    setattr(func, "registry", registry_name)

    # The previous message referenced an undefined name (memory_type), which
    # raised NameError after the function had already been appended.
    logger.debug(
        "Registered function '%s' for registry '%s'",
        func_name, registry_name
    )

457 ) 

458 

459 

def get_functions_by_memory_type(memory_type: str) -> List[Callable]:
    """
    Get all functions for a specific memory type.

    Functions are collected from RegistryService:

    * if the metadata's registry defines ``MEMORY_TYPE``, the function is
      included when that equals ``memory_type``;
    * otherwise, for metadata tagged ``openhcs``, the function is included
      when its ``input_memory_type`` equals ``memory_type`` or, failing that,
      when its ``backend`` attribute does.

    If FUNC_REGISTRY is initialized and has a bucket named ``memory_type``,
    those functions are added too (backward compatibility). Each function
    object appears at most once, in first-seen order.

    Args:
        memory_type: The memory type (e.g., "numpy", "cupy", "torch")

    Returns:
        A list of functions for the specified memory type

    Raises:
        ValueError: If the memory type is not valid
    """
    if memory_type not in VALID_MEMORY_TYPES:
        raise ValueError(
            f"Invalid memory type: {memory_type}. "
            f"Valid types are: {', '.join(sorted(VALID_MEMORY_TYPES))}"
        )

    from openhcs.processing.backends.lib_registry.registry_service import RegistryService
    all_functions = RegistryService.get_all_functions_with_metadata()

    functions = []
    seen_ids = set()

    def _add(func) -> None:
        # FUNC_REGISTRY buckets are keyed by library name (e.g. "pyclesperanto"),
        # which can coincide with a memory type, so the same function object may
        # arrive from both sources. De-duplicate by identity, since not every
        # callable is hashable.
        if id(func) not in seen_ids:
            seen_ids.add(id(func))
            functions.append(func)

    for metadata in all_functions.values():
        # 1. Runtime Testing Libraries: Use registry's MEMORY_TYPE attribute
        if hasattr(metadata, 'registry') and hasattr(metadata.registry, 'MEMORY_TYPE'):
            if metadata.registry.MEMORY_TYPE == memory_type:
                _add(metadata.func)

        # 2. OpenHCS Native Functions: Check function's own memory type attributes
        elif metadata.tags and 'openhcs' in metadata.tags:
            func = metadata.func
            if hasattr(func, 'input_memory_type') and func.input_memory_type == memory_type:
                _add(func)
            elif hasattr(func, 'backend') and func.backend == memory_type:
                _add(func)

    # Also include legacy FUNC_REGISTRY functions for backward compatibility
    with _registry_lock:
        if _registry_initialized and memory_type in FUNC_REGISTRY:
            for func in FUNC_REGISTRY[memory_type]:
                _add(func)

    return functions

509 

510 

def get_function_info(func: Callable) -> Dict[str, Any]:
    """
    Describe a function that carries OpenHCS memory type attributes.

    Args:
        func: Function exposing ``input_memory_type`` and ``output_memory_type``.

    Returns:
        A dict with keys "name", "input_memory_type", "output_memory_type",
        "backend" (the function's ``backend`` attribute, falling back to its
        input memory type), "doc" and "module".

    Raises:
        ValueError: If either memory type attribute is missing.
    """
    required_attrs = ("input_memory_type", "output_memory_type")
    if not all(hasattr(func, attr) for attr in required_attrs):
        raise ValueError(
            f"Function '{func.__name__}' does not have memory type attributes"
        )

    info = {
        "name": func.__name__,
        "input_memory_type": func.input_memory_type,
        "output_memory_type": func.output_memory_type,
    }
    info["backend"] = getattr(func, "backend", func.input_memory_type)
    info["doc"] = func.__doc__
    info["module"] = func.__module__
    return info

537 

538 

def is_registry_initialized() -> bool:
    """
    Report whether the function registry has been initialized.

    The flag is read while holding ``_registry_lock``, so the answer is never
    taken in the middle of a locked (re)initialization.

    Returns:
        True once the registry has been initialized, otherwise False.
    """
    with _registry_lock:
        initialized = _registry_initialized
    return initialized

550 

551 

def get_valid_memory_types() -> Set[str]:
    """
    Return the recognised memory type names.

    A fresh set is built on every call, so callers may modify the result
    without affecting VALID_MEMORY_TYPES.

    Returns:
        A new set containing every valid memory type name.
    """
    return set(VALID_MEMORY_TYPES)

560 

561 

562# Import hook system removed - using existing comprehensive registries 

563 

564 

def get_function_by_name(function_name: str, memory_type: str) -> Optional[Callable]:
    """
    Find a function by name among those available for a memory type.

    Searches the list returned by ``get_functions_by_memory_type(memory_type)``
    and returns the first entry whose ``__name__`` equals ``function_name``.
    Callables without a ``__name__`` are skipped instead of raising.

    Args:
        function_name: Name of the function to find
        memory_type: The memory type (e.g., "numpy", "cupy", "torch")

    Returns:
        The first matching function, or None if there is none.

    Raises:
        ValueError: If the memory type is not valid (raised by
            ``get_functions_by_memory_type``).
    """
    for func in get_functions_by_memory_type(memory_type):
        name = getattr(func, "__name__", None)
        if name is not None and name == function_name:
            return func
    return None

587 

588 

def get_all_function_names(memory_type: str) -> List[str]:
    """
    Get the names of all functions available for a specific memory type.

    Names come from ``get_functions_by_memory_type(memory_type)``, in the
    same order. Callables without a ``__name__`` are skipped instead of raising.

    Args:
        memory_type: The memory type (e.g., "numpy", "cupy", "torch")

    Returns:
        A list of function names

    Raises:
        ValueError: If the memory type is not valid (raised by
            ``get_functions_by_memory_type``).
    """
    return [
        func.__name__
        for func in get_functions_by_memory_type(memory_type)
        if hasattr(func, "__name__")
    ]

605 

606 

# Populate the registry as a side effect of importing this module (mirrors the
# storage_registry pattern). Setting OPENHCS_SUBPROCESS_NO_GPU to any non-empty
# value skips this step for faster subprocess-runner startup. ``os`` is already
# imported at the top of the module.
if not os.getenv('OPENHCS_SUBPROCESS_NO_GPU'):
    _auto_initialize_registry()