Coverage for openhcs/processing/backends/lib_registry/unified_registry.py: 66.2%

253 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Unified registry base class for external library function registration. 

3 

4This module provides a common base class that eliminates ~70% of code duplication 

5across library registries (pyclesperanto, scikit-image, cupy, etc.) while enforcing 

6consistent behavior and making it impossible to skip dynamic testing or hardcode 

7function lists. 

8 

9Key Benefits: 

10- Eliminates ~1000+ lines of duplicated code 

11- Enforces consistent testing and registration patterns 

12- Makes adding new libraries trivial (60-120 lines vs 350-400) 

13- Centralizes bug fixes and improvements 

14- Type-safe abstract interface prevents shortcuts 

15 

16Architecture: 

17- LibraryRegistryBase: Abstract base class with common functionality 

18- ProcessingContract: Unified contract enum across all libraries 

19- Dimension error adapter factory for consistent error handling 

20- Integrated caching system using existing cache_utils.py patterns 

21""" 

22 

23 import importlib

24 import inspect

25 import json

26 import logging

27 import time

28 from abc import ABC, abstractmethod

29 from dataclasses import dataclass, field

30 from enum import Enum, auto

31 from functools import wraps

32 from pathlib import Path

33 from typing import Any, Callable, Dict, List, Optional, Tuple, Set

34

35 import numpy as np

36

37 from openhcs.core.utils import optional_import

38 from openhcs.core.xdg_paths import get_cache_file_path

39 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices

40

41 logger = logging.getLogger(__name__)

42 

43 

44 class ProcessingContract(Enum):

45 """ 

46 Unified contract classification with direct method execution. 

47 """ 

48 PURE_3D = "_execute_pure_3d" 

49 PURE_2D = "_execute_pure_2d" 

50 FLEXIBLE = "_execute_flexible" 

51 VOLUMETRIC_TO_SLICE = "_execute_volumetric_to_slice" 

52 

53 def execute(self, registry, func, image, *args, **kwargs): 

54 """Execute the contract method on the registry.""" 

55 method = getattr(registry, self.value) 

56 return method(func, image, *args, **kwargs) 

57 
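As a minimal sketch of the dispatch above: execute looks up the method named by the enum value on the registry object and calls it with the wrapped function. The DummyRegistry and double names below are hypothetical stand-ins, assuming only that this module is importable at the path shown in the report header.

import numpy as np
from openhcs.processing.backends.lib_registry.unified_registry import ProcessingContract

class DummyRegistry:
    # Only the one executor needed here; real registries provide all four.
    def _execute_pure_3d(self, func, image, *args, **kwargs):
        return func(image, *args, **kwargs)

def double(image):
    return image * 2

volume = np.ones((3, 20, 20), dtype=np.float32)
result = ProcessingContract.PURE_3D.execute(DummyRegistry(), double, volume)
assert result.shape == (3, 20, 20)  # dispatched to _execute_pure_3d, which ran func directly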

58 

59 @dataclass(frozen=True)

60 class FunctionMetadata:

61 """Clean metadata with no library-specific leakage.""" 

62 

63 # Core fields only 

64 name: str 

65 func: Callable 

66 contract: ProcessingContract 

67 module: str = "" 

68 doc: str = "" 

69 tags: List[str] = field(default_factory=list) 

70 original_name: str = "" # Original function name for cache reconstruction 

71 

72 

73 

74 
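For illustration, a populated record as this dataclass would hold it after discovery; the concrete values (the skimage.filters module, the lambda placeholder, the doc line) are hypothetical and not taken from a real registry run.

from openhcs.processing.backends.lib_registry.unified_registry import (
    FunctionMetadata, ProcessingContract)

meta = FunctionMetadata(
    name="gaussian",                       # name exposed to the OpenHCS registry
    func=lambda img: img,                  # placeholder for the real library callable
    contract=ProcessingContract.FLEXIBLE,  # result of behavior classification
    module="skimage.filters",              # used to re-import the function from cache
    doc="Multi-dimensional Gaussian filter.",
    tags=["gaussian"],
    original_name="gaussian",
)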

75 class LibraryRegistryBase(ABC):

76 """ 

77 Clean abstraction with essential contracts only. 

78 

79 Enforces only essential behavior contracts, not library-specific details. 

80 Each registry implements the contract its own way while returning unified 

81 ProcessingContract and FunctionMetadata types. 

82 

83 Essential contracts: 

84 - Test function behavior to determine: 3D→3D, 2D→2D only, 3D→2D, etc. 

85 - Create adapters based on contract classification 

86 - Filter functions using consolidated logic 

87 - Provide library identification and discovery 

88 """ 

89 

90 # Common exclusions across all libraries 

91 COMMON_EXCLUSIONS = { 

92 'imread', 'imsave', 'load', 'save', 'read', 'write', 

93 'show', 'imshow', 'plot', 'display', 'view', 'visualize', 

94 'info', 'help', 'version', 'test', 'benchmark' 

95 } 

96 

97 # Abstract class attributes - each implementation must define these 

98 MODULES_TO_SCAN: List[str] 

99 MEMORY_TYPE: str # Memory type string value (e.g., "pyclesperanto", "cupy", "numpy") 

100 FLOAT_DTYPE: Any # Library-specific float32 type (np.float32, cp.float32, etc.) 

101 

102 def __init__(self, library_name: str): 

103 """ 

104 Initialize registry for a specific library. 

105  

106 Args: 

107 library_name: Name of the library (e.g., "pyclesperanto", "skimage") 

108 """ 

109 self.library_name = library_name 

110 self._cache_path = get_cache_file_path(f"{library_name}_function_metadata.json") 

111 

112 # ===== ESSENTIAL ABC METHODS ===== 

113 

114 # ===== LIBRARY IDENTIFICATION ===== 

115 @abstractmethod 

116 def get_library_version(self) -> str: 

117 """Get library version for cache validation.""" 

118 pass 

119 

120 @abstractmethod 

121 def is_library_available(self) -> bool: 

122 """Check if the library is available for import.""" 

123 pass 

124 

125 def get_memory_type(self) -> str: 

126 """Get the memory type string value for this library.""" 

127 return self.MEMORY_TYPE 

128 

129 # ===== FUNCTION DISCOVERY ===== 

130 def get_modules_to_scan(self) -> List[Tuple[str, Any]]: 

131 """ 

132 Get list of (module_name, module_object) tuples to scan for functions. 

133 Uses the MODULES_TO_SCAN class attribute and library object from get_library_object(). 

134 

135 Returns: 

136 List of (name, module) pairs where name is for identification 

137 and module is the actual module object to scan. 

138 """ 

139 library = self.get_library_object() 

140 modules = [] 

141 for module_name in self.MODULES_TO_SCAN: 

142 if module_name == "": 

143 # Empty string means scan the main library namespace 

144 module = library 

145 modules.append(("main", module)) 

146 else: 

147 module = getattr(library, module_name) 

148 modules.append((module_name, module)) 

149 return modules 

150 

151 @abstractmethod 

152 def get_library_object(self): 

153 """Get the main library object to scan for modules. Library-specific implementation.""" 

154 pass 

155 

156 def create_test_arrays(self) -> Tuple[Any, Any]: 

157 """ 

158 Create test arrays appropriate for this library. 

159 

160 Returns: 

161 Tuple of (test_3d, test_2d) arrays for behavior testing 

162 """ 

163 test_3d = self._create_array((3, 20, 20), self._get_float_dtype()) 

164 test_2d = self._create_array((20, 20), self._get_float_dtype()) 

165 return test_3d, test_2d 

166 

167 @abstractmethod 

168 def _create_array(self, shape: Tuple[int, ...], dtype): 

169 """Create array with specified shape and dtype. Library-specific implementation.""" 

170 pass 

171 

172 def _get_float_dtype(self): 

173 """Get the appropriate float dtype for this library.""" 

174 return self.FLOAT_DTYPE 

175 

176 # ===== CORE BEHAVIOR CONTRACT ===== 

177 def classify_function_behavior(self, func: Callable) -> Tuple[ProcessingContract, bool]: 

178 """Classify function behavior by testing 3D and 2D inputs.""" 

179 test_3d, test_2d = self.create_test_arrays() 

180 

181 def test_function(test_array): 

182 """Test function with array, return (success, result).""" 

183 try: 

184 result = func(test_array) 

185 return True, result 

186 except Exception:

187 return False, None 

188 

189 works_3d, result_3d = test_function(test_3d) 

190 works_2d, _ = test_function(test_2d) 

191 

192 # Classification lookup table 

193 classification_map = { 

194 (True, True): self._classify_dual_support(result_3d), 

195 (True, False): ProcessingContract.PURE_3D, 

196 (False, True): ProcessingContract.PURE_2D, 

197 (False, False): None # Invalid function 

198 } 

199 

200 contract = classification_map[(works_3d, works_2d)] 

201 is_valid = works_3d or works_2d 

202 

203 return contract, is_valid 

204 

205 def _classify_dual_support(self, result_3d): 

206 """Classify functions that work on both 3D and 2D inputs.""" 

207 if result_3d is not None and result_3d.ndim == 2:    [branch 207 ↛ 208 never taken: condition was never true]

208 return ProcessingContract.VOLUMETRIC_TO_SLICE 

209 return ProcessingContract.FLEXIBLE 

210 
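The probe below mirrors the decision table above using plain NumPy arrays; max_projection is a hypothetical library function used only to show how a dual-support, 3D-to-2D result lands on VOLUMETRIC_TO_SLICE.

import numpy as np

def probe(func, arr):
    # Same success/result pairing as test_function above.
    try:
        return True, func(arr)
    except Exception:
        return False, None

def max_projection(a):
    return a.max(axis=0)

test_3d = np.zeros((3, 20, 20), dtype=np.float32)
test_2d = np.zeros((20, 20), dtype=np.float32)

works_3d, result_3d = probe(max_projection, test_3d)   # True, result shape (20, 20)
works_2d, _ = probe(max_projection, test_2d)           # True (reduces to 1D, still succeeds)

# (True, True) with result_3d.ndim == 2 is exactly what _classify_dual_support
# maps to ProcessingContract.VOLUMETRIC_TO_SLICE.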

211 @abstractmethod 

212 def _stack_2d_results(self, func, test_3d): 

213 """Stack 2D results. Library-specific implementation required.""" 

214 pass 

215 

216 @abstractmethod 

217 def _arrays_close(self, arr1, arr2): 

218 """Compare arrays. Library-specific implementation required.""" 

219 pass 

220 

221 def create_library_adapter(self, original_func: Callable, contract: ProcessingContract) -> Callable: 

222 """Create adapter based on contract classification.""" 

223 func_name = getattr(original_func, '__name__', 'unknown') 

224 

225 @wraps(original_func) 

226 def unified_adapter(image, *args, slice_by_slice: bool = False, **kwargs): 

227 # Library-specific preprocessing 

228 processed_image = self._preprocess_input(image, func_name) 

229 

230 # Contract-based execution; forward the slice toggle only to the FLEXIBLE contract, the only executor that accepts it

 if contract is ProcessingContract.FLEXIBLE:

     kwargs['slice_by_slice'] = slice_by_slice

231 result = contract.execute(self, original_func, processed_image, *args, **kwargs)

232 

233 # Library-specific postprocessing 

234 return self._postprocess_output(result, image, func_name) 

235 

236 return unified_adapter 

237 

238 @abstractmethod 

239 def _preprocess_input(self, image, func_name: str): 

240 """Preprocess input image. Library-specific implementation.""" 

241 pass 

242 

243 @abstractmethod 

244 def _postprocess_output(self, result, original_image, func_name: str): 

245 """Postprocess output result. Library-specific implementation.""" 

246 pass 

247 

248 # ===== BASIC FILTERING ===== 

249 def should_include_function(self, func: Callable, func_name: str) -> bool: 

250 """Single method for all filtering logic (blacklist, signature, etc.)""" 

251 # Skip private functions 

252 if func_name.startswith('_'):    [branch 252 ↛ 253 never taken: condition was never true]

253 return False 

254 

255 # Skip exclusions (check both common and library-specific) 

256 exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS) 

257 if func_name.lower() in exclusions:    [branch 257 ↛ 258 never taken: condition was never true]

258 return False 

259 

260 # Skip classes and types 

261 if inspect.isclass(func) or isinstance(func, type): 

262 return False 

263 

264 # Must be callable 

265 if not callable(func): 

266 return False 

267 

268 # Pure functions must have at least one parameter 

269 sig = inspect.signature(func) 

270 params = list(sig.parameters.values()) 

271 if not params:    [branch 271 ↛ 272 never taken: condition was never true]

272 return False 

273 

274 # Library-specific signature validation 

275 return self._check_first_parameter(params[0], func_name) 

276 

277 

278 

279 @abstractmethod 

280 def _check_first_parameter(self, first_param, func_name: str) -> bool: 

281 """Check if first parameter meets library-specific criteria. Library-specific implementation.""" 

282 pass 

283 

284 # ===== SHARED IMPLEMENTATION LOGIC ===== 

285 def discover_functions(self) -> Dict[str, FunctionMetadata]: 

286 """Discover and classify all library functions with detailed logging.""" 

287 functions = {} 

288 modules = self.get_modules_to_scan() 

289 logger.info(f"🔍 Starting function discovery for {self.library_name}") 

290 logger.info(f"📦 Scanning {len(modules)} modules: {[name for name, _ in modules]}") 

291 

292 total_tested = 0 

293 total_accepted = 0 

294 

295 for module_name, module in modules: 

296 logger.info(f" 📦 Analyzing {module_name} ({module})...") 

297 module_tested = 0 

298 module_accepted = 0 

299 

300 for name in dir(module): 

301 if name.startswith("_"): 

302 continue 

303 

304 func = getattr(module, name) 

305 full_path = self._get_full_function_path(module, name, module_name) 

306 

307 if not self.should_include_function(func, name): 

308 rejection_reason = self._get_rejection_reason(func, name) 

309 if rejection_reason != "private":    [branch 309 ↛ 311 never taken: condition was always true]

310 logger.info(f" 🚫 Skipping {full_path}: {rejection_reason}") 

311 continue 

312 

313 module_tested += 1 

314 total_tested += 1 

315 

316 contract, is_valid = self.classify_function_behavior(func) 

317 logger.info(f" 🧪 Testing {full_path}") 

318 logger.info(f" Classification: {contract.name if contract else contract}") 

319 

320 if not is_valid: 

321 logger.info(f" ❌ Rejected: Invalid classification") 

322 continue 

323 

324 doc_lines = (func.__doc__ or "").splitlines() 

325 first_line_doc = doc_lines[0] if doc_lines else "" 

326 func_name = self._generate_function_name(name, module_name) 

327 

328 metadata = FunctionMetadata( 

329 name=func_name, 

330 func=func, 

331 contract=contract, 

332 module=func.__module__ or "", 

333 doc=first_line_doc, 

334 tags=self._generate_tags(name), 

335 original_name=name 

336 ) 

337 

338 functions[func_name] = metadata 

339 module_accepted += 1 

340 total_accepted += 1 

341 logger.info(f" ✅ Accepted as '{func_name}'") 

342 

343 logger.info(f" 📊 Module {module_name}: {module_accepted}/{module_tested} functions accepted") 

344 

345 logger.info(f"✅ Discovery complete: {total_accepted}/{total_tested} functions accepted") 

346 return functions 

347 

348 def _get_full_function_path(self, module, func_name: str, module_name: str) -> str: 

349 """Generate full module path for logging.""" 

350 if module_name == "main":    [branch 350 ↛ 351 never taken: condition was never true]

351 return f"{self.library_name}.{func_name}" 

352 else: 

353 # Extract clean module path 

354 module_str = str(module) 

355 if "'" in module_str: 355 ↛ 359line 355 didn't jump to line 359 because the condition on line 355 was always true

356 clean_path = module_str.split("'")[1] 

357 return f"{clean_path}.{func_name}" 

358 else: 

359 return f"{module_name}.{func_name}" 

360 

361 def _get_rejection_reason(self, func: Callable, func_name: str) -> str: 

362 """Get detailed reason why a function was rejected.""" 

363 # Check each rejection criteria in order 

364 if func_name.startswith('_'):    [branch 364 ↛ 365 never taken: condition was never true]

365 return "private" 

366 

367 exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS) 

368 if func_name.lower() in exclusions:    [branch 368 ↛ 369 never taken: condition was never true]

369 return "blacklisted" 

370 

371 if inspect.isclass(func) or isinstance(func, type): 

372 return "is class/type" 

373 

374 if not callable(func): 

375 return "not callable" 

376 

377 try: 

378 sig = inspect.signature(func) 

379 params = list(sig.parameters.values()) 

380 if not params:    [branch 380 ↛ 381 never taken: condition was never true]

381 return "no parameters (not pure function)" 

382 except (ValueError, TypeError): 

383 return "invalid signature" 

384 

385 return "unknown" 

386 

387 # ===== CACHING METHODS ===== 

388 def _load_or_discover_functions(self) -> Dict[str, FunctionMetadata]: 

389 """Load functions from cache or discover them if cache is invalid.""" 

390 cached_functions = self._load_from_cache() 

391 if cached_functions is not None:    [branch 391 ↛ 392 never taken: condition was never true]

392 logger.info(f"✅ Loaded {len(cached_functions)} {self.library_name} functions from cache") 

393 return cached_functions 

394 

395 logger.info(f"🔍 Cache miss for {self.library_name} - performing full discovery") 

396 functions = self.discover_functions() 

397 self._save_to_cache(functions) 

398 return functions 

399 

400 def _load_from_cache(self) -> Optional[Dict[str, FunctionMetadata]]: 

401 """Load function metadata from cache with validation.""" 

402 if not self._cache_path.exists(): 

403 return None 

404 

405 with open(self._cache_path, 'r') as f: 

406 cache_data = json.load(f) 

407 

408 if 'functions' not in cache_data: 

409 return None 

410 

411 cached_version = cache_data.get('library_version', 'unknown') 

412 current_version = self.get_library_version() 

413 if cached_version != current_version: 

414 logger.info(f"{self.library_name} version changed ({cached_version} → {current_version}) - cache invalid")

415 return None 

416 

417 cache_timestamp = cache_data.get('timestamp', 0) 

418 cache_age_days = (time.time() - cache_timestamp) / (24 * 3600) 

419 if cache_age_days > 7: 

420 logger.info(f"Cache is {cache_age_days:.1f} days old - rebuilding") 

421 return None 

422 

423 functions = {} 

424 for func_name, cached_data in cache_data['functions'].items(): 

425 original_name = cached_data.get('original_name', func_name) 

426 func = self._get_function_by_name(cached_data['module'], original_name) 

427 contract = ProcessingContract[cached_data['contract']] 

428 

429 metadata = FunctionMetadata( 

430 name=func_name, 

431 func=func, 

432 contract=contract, 

433 module=cached_data.get('module', ''), 

434 doc=cached_data.get('doc', ''), 

435 tags=cached_data.get('tags', []), 

436 original_name=cached_data.get('original_name', func_name) 

437 ) 

438 functions[func_name] = metadata 

439 

440 return functions 

441 

442 def register_functions_direct(self): 

443 """Register functions directly with OpenHCS function registry using shared logic.""" 

444 from openhcs.processing.func_registry import _apply_unified_decoration, _register_function 

445 from openhcs.constants import MemoryType 

446 

447 functions = self._load_or_discover_functions() 

448 registered_count = 0 

449 

450 for func_name, metadata in functions.items(): 

451 adapted = self.create_library_adapter(metadata.func, metadata.contract) 

452 memory_type_enum = MemoryType(self.get_memory_type()) 

453 wrapper_func = _apply_unified_decoration( 

454 original_func=adapted, 

455 func_name=metadata.name, 

456 memory_type=memory_type_enum, 

457 create_wrapper=True 

458 ) 

459 

460 _register_function(wrapper_func, self.get_memory_type()) 

461 registered_count += 1 

462 

463 logger.info(f"Registered {registered_count} {self.library_name} functions") 

464 return registered_count 

465 

466 # ===== SHARED ADAPTER LOGIC ===== 

467 def _execute_slice_by_slice(self, func, image, *args, **kwargs): 

468 """Shared slice-by-slice execution logic.""" 

469 if image.ndim == 3: 

470 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type 

471 mem = _detect_memory_type(image) 

472 slices = unstack_slices(image, mem, 0) 

473 results = [func(sl, *args, **kwargs) for sl in slices] 

474 return stack_slices(results, mem, 0) 

475 return func(image, *args, **kwargs) 

476 

477 # ===== PROCESSING CONTRACT EXECUTION METHODS ===== 

478 def _execute_pure_3d(self, func, image, *args, **kwargs): 

479 """Execute 3D→3D function directly (no change).""" 

480 return func(image, *args, **kwargs) 

481 

482 def _execute_pure_2d(self, func, image, *args, **kwargs): 

483 """Execute 2D→2D function with unstack/restack wrapper.""" 

484 slices = unstack_slices(image, self.MEMORY_TYPE, 0) 

485 results = [func(sl, *args, **kwargs) for sl in slices] 

486 return stack_slices(results, self.MEMORY_TYPE, 0) 

487 

488 def _execute_flexible(self, func, image, *args, slice_by_slice: bool = False, **kwargs): 

489 """Execute function that handles both 3D→3D and 2D→2D with toggle.""" 

490 if slice_by_slice: 

491 return self._execute_pure_2d(func, image, *args, **kwargs) 

492 else: 

493 return self._execute_pure_3d(func, image, *args, **kwargs) 

494 

495 def _execute_volumetric_to_slice(self, func, image, *args, **kwargs): 

496 """Execute 3D→2D function returning slice 3D array.""" 

497 result_2d = func(image, *args, **kwargs) 

498 return stack_slices([result_2d], self.MEMORY_TYPE, 0) 

499 
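A small sketch of what the PURE_2D path does, assuming the "numpy" memory-type string documented in the class attributes above and the stack utilities this module already imports; fake_2d_filter is a hypothetical stand-in for a 2D-only library call.

import numpy as np
from openhcs.core.memory.stack_utils import unstack_slices, stack_slices

def fake_2d_filter(sl):
    return sl + 1.0

volume = np.arange(3 * 4 * 4, dtype=np.float32).reshape(3, 4, 4)
slices = unstack_slices(volume, "numpy", 0)        # one 2D array per plane along axis 0
results = [fake_2d_filter(sl) for sl in slices]    # call the 2D-only function per slice
restacked = stack_slices(results, "numpy", 0)      # back to a (3, 4, 4) stack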

500 # ===== CUSTOMIZATION HOOKS ===== 

501 def _generate_function_name(self, name: str, module_name: str) -> str: 

502 """Generate function name. Override in subclasses for custom naming.""" 

503 return name 

504 

505 def _generate_tags(self, func_name: str) -> List[str]: 

506 """Generate tags. Override in subclasses for custom tags.""" 

507 return func_name.lower().replace("_", " ").split() 

508 

509 def _promote_2d_to_3d(self, result): 

510 """Promote 2D results to 3D using library-specific expansion method.""" 

511 if isinstance(result, tuple) and result[0].ndim == 2:  # check tuple first: tuples have no .ndim

512 expanded_first = self._expand_2d_to_3d(result[0])

513 return (expanded_first, *result[1:])

514 elif result.ndim == 2:

515 return self._expand_2d_to_3d(result)

516 return result

517 

518 @abstractmethod 

519 def _expand_2d_to_3d(self, array_2d): 

520 """Expand 2D array to 3D. Library-specific implementation required.""" 

521 pass 

522 

523 def _save_to_cache(self, functions: Dict[str, FunctionMetadata]) -> None: 

524 """Save function metadata to cache.""" 

525 cache_data = { 

526 'cache_version': '1.0', 

527 'library_version': self.get_library_version(), 

528 'timestamp': time.time(), 

529 'functions': { 

530 func_name: { 

531 'name': metadata.name, 

532 'original_name': metadata.original_name, 

533 'module': metadata.module, 

534 'contract': metadata.contract.name, 

535 'doc': metadata.doc, 

536 'tags': metadata.tags 

537 } 

538 for func_name, metadata in functions.items() 

539 } 

540 } 

541 

542 self._cache_path.parent.mkdir(parents=True, exist_ok=True) 

543 with open(self._cache_path, 'w') as f: 

544 json.dump(cache_data, f, indent=2) 

545 
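For reference, the JSON written above has roughly the following shape; only the keys mirror _save_to_cache, while the concrete values (library version, timestamp, the skimage entry) are made-up illustrations.

example_cache = {
    "cache_version": "1.0",
    "library_version": "0.21.0",          # from get_library_version()
    "timestamp": 1723600000.0,            # time.time() at save
    "functions": {
        "gaussian": {
            "name": "gaussian",
            "original_name": "gaussian",  # lets _get_function_by_name re-import it
            "module": "skimage.filters",
            "contract": "FLEXIBLE",       # ProcessingContract member name
            "doc": "Multi-dimensional Gaussian filter.",
            "tags": ["gaussian"],
        },
    },
}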

546 def _get_function_by_name(self, module_path: str, func_name: str) -> Optional[Callable]: 

547 """Reconstruct function object from module path and function name.""" 

548 module = importlib.import_module(module_path) 

549 return getattr(module, func_name) 

550 

551 

552
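To make the "adding new libraries" claim in the module docstring concrete, here is a minimal, hypothetical subclass backed by NumPy. Every name in it (NumpyRegistry, max_projection, add_one) is illustrative; it assumes only the abstract interface shown above plus the "numpy" memory-type string and stack utilities this module already uses.

import numpy as np

from openhcs.processing.backends.lib_registry.unified_registry import (
    LibraryRegistryBase, ProcessingContract)


class NumpyRegistry(LibraryRegistryBase):
    """Hypothetical registry that would scan the top-level numpy namespace."""

    MODULES_TO_SCAN = [""]        # "" means: scan the main library namespace
    MEMORY_TYPE = "numpy"
    FLOAT_DTYPE = np.float32

    # --- library identification ---
    def get_library_version(self) -> str:
        return np.__version__

    def is_library_available(self) -> bool:
        return True

    def get_library_object(self):
        return np

    # --- array helpers ---
    def _create_array(self, shape, dtype):
        return np.zeros(shape, dtype=dtype)

    def _stack_2d_results(self, func, test_3d):
        return np.stack([func(sl) for sl in test_3d], axis=0)

    def _arrays_close(self, arr1, arr2):
        return np.allclose(arr1, arr2)

    def _expand_2d_to_3d(self, array_2d):
        return array_2d[np.newaxis, ...]

    # --- adapter hooks ---
    def _preprocess_input(self, image, func_name):
        return np.asarray(image, dtype=np.float32)

    def _postprocess_output(self, result, original_image, func_name):
        return result

    # --- filtering ---
    def _check_first_parameter(self, first_param, func_name):
        return True


registry = NumpyRegistry("numpy")
volume = np.random.rand(3, 20, 20).astype(np.float32)

def max_projection(a):          # behaves as 3D -> 2D
    return a.max(axis=0)

contract, ok = registry.classify_function_behavior(max_projection)
adapted = registry.create_library_adapter(max_projection, contract)
print(contract, ok, adapted(volume).shape)   # VOLUMETRIC_TO_SLICE, True, (1, 20, 20)

def add_one(a):                 # works the same on 2D and 3D, i.e. FLEXIBLE
    return a + 1.0

adapted_add = registry.create_library_adapter(add_one, ProcessingContract.FLEXIBLE)
print(adapted_add(volume, slice_by_slice=True).shape)   # (3, 20, 20)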