Coverage for openhcs/processing/backends/lib_registry/unified_registry.py: 60.8%

386 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 18:33 +0000

1""" 

2Unified registry base class for external library function registration. 

3 

4This module provides a common base class that eliminates ~70% of code duplication 

5across library registries (pyclesperanto, scikit-image, cupy, etc.) while enforcing 

6consistent behavior and making it impossible to skip dynamic testing or hardcode 

7function lists. 

8 

9Key Benefits: 

10- Eliminates ~1000+ lines of duplicated code 

11- Enforces consistent testing and registration patterns 

12- Makes adding new libraries trivial (60-120 lines vs 350-400) 

13- Centralizes bug fixes and improvements 

14- Type-safe abstract interface prevents shortcuts 

15 

16Architecture: 

17- LibraryRegistryBase: Abstract base class with common functionality 

18- ProcessingContract: Unified contract enum across all libraries 

19- Dimension error adapter factory for consistent error handling 

20- Integrated caching system using existing cache_utils.py patterns 

21""" 

22 

23import importlib 

24import inspect 

25import importlib 

26import json 

27import logging 

28import time 

29from abc import ABC, abstractmethod 

30from dataclasses import dataclass, field 

31from enum import Enum, auto 

32from functools import wraps 

33from pathlib import Path 

34from typing import Any, Callable, Dict, List, Optional, Tuple, Set 

35 

36import numpy as np 

37 

38from openhcs.core.utils import optional_import 

39from openhcs.core.xdg_paths import get_cache_file_path 

40from openhcs.core.memory.stack_utils import unstack_slices, stack_slices 

41 

42logger = logging.getLogger(__name__) 

43 

44 

45# Enums for OpenHCS principle compliance (replace magic strings) 

class ModuleFilterComponents(Enum):
    """Module-path components that are dropped when deriving tags.

    When tags are generated from a dotted module path, any path component
    whose text equals one of these member values is filtered out.
    """
    BACKENDS = "backends"
    PROCESSING = "processing"
    OPENHCS = "openhcs"

    @classmethod
    def should_skip(cls, component: str) -> bool:
        """Return True when *component* equals the value of some member."""
        skipped_values = [member.value for member in cls]
        return component in skipped_values

56 

57 

class ProcessingContract(Enum):
    """
    How a registered function consumes image data.

    Each member's value is the name of the registry method that runs
    functions under that contract; ``execute`` looks that method up by
    name on the registry and delegates to it.
    """
    PURE_3D = "_execute_pure_3d"
    PURE_2D = "_execute_pure_2d"
    FLEXIBLE = "_execute_flexible"
    VOLUMETRIC_TO_SLICE = "_execute_volumetric_to_slice"

    def execute(self, registry, func, image, *args, **kwargs):
        """Run *func* on *image* through ``registry.<self.value>``.

        Raises AttributeError if the registry lacks the named method.
        """
        return getattr(registry, self.value)(func, image, *args, **kwargs)

71 

72 

@dataclass(frozen=True)
class FunctionMetadata:
    """Clean metadata with no library-specific leakage.

    One instance describes one registered function. The dataclass is frozen,
    so attributes cannot be reassigned after construction. However, ``tags``
    is a list, so the dataclass-generated ``__hash__`` raises TypeError if an
    instance is ever hashed (e.g. put in a set or used as a dict key).
    """

    # Core fields only
    # Key under which the function is registered (may differ from original_name).
    name: str
    # Callable to invoke (registries may store a wrapped/adapted function here).
    func: Callable
    # Execution contract used to run `func`.
    contract: ProcessingContract
    registry: 'LibraryRegistryBase' # Reference to the registry that registered this function - REQUIRED
    # Module path of the underlying function ("" when unknown).
    module: str = ""
    # Short documentation text for the function.
    doc: str = ""
    # Free-form tags; default_factory gives each instance its own list.
    tags: List[str] = field(default_factory=list)
    original_name: str = "" # Original function name for cache reconstruction

86 

87 

88 

89 

class LibraryRegistryBase(ABC):
    """
    Minimal ABC for all library registries.

    Provides only essential contracts that all registries must implement,
    regardless of whether they use runtime testing or explicit contracts,
    plus machinery shared by every registry:

    * contract wrappers (``apply_contract_wrapper``) and the execution
      methods that ``ProcessingContract.execute`` dispatches to by name
      (``_execute_pure_3d``, ``_execute_pure_2d``, ...);
    * a JSON metadata cache that is invalidated when the library version
      changes or the cache is older than ``CACHE_MAX_AGE_DAYS``.
    """

    # Common exclusions across all libraries
    COMMON_EXCLUSIONS = {
        'imread', 'imsave', 'load', 'save', 'read', 'write',
        'show', 'imshow', 'plot', 'display', 'view', 'visualize',
        'info', 'help', 'version', 'test', 'benchmark'
    }

    # Cached metadata older than this many days is discarded and rediscovered.
    CACHE_MAX_AGE_DAYS: float = 7

    # Opt-in switch for _inject_optional_dataclass_params (disabled by default).
    # Subclasses may set this to True to re-enable lazy-config parameter injection.
    ENABLE_CONFIG_INJECTION: bool = False

    # Abstract class attributes - each implementation must define these
    MODULES_TO_SCAN: List[str]
    MEMORY_TYPE: str  # Memory type string value (e.g., "pyclesperanto", "cupy", "numpy")
    FLOAT_DTYPE: Any  # Library-specific float32 type (np.float32, cp.float32, etc.)

    def __init__(self, library_name: str):
        """
        Initialize registry for a specific library.

        Args:
            library_name: Name of the library (e.g., "pyclesperanto", "skimage").
                Also used to name this library's metadata cache file.
        """
        self.library_name = library_name
        self._cache_path = get_cache_file_path(f"{library_name}_function_metadata.json")

    # ===== ESSENTIAL ABC METHODS =====

    # ===== LIBRARY IDENTIFICATION =====
    @abstractmethod
    def get_library_version(self) -> str:
        """Get library version for cache validation."""
        pass

    @abstractmethod
    def is_library_available(self) -> bool:
        """Check if the library is available for import."""
        pass

    # ===== FUNCTION DISCOVERY =====
    @abstractmethod
    def discover_functions(self) -> Dict[str, FunctionMetadata]:
        """Discover and return function metadata. Must be implemented by subclasses."""
        pass

    # ===== CONTRACT HANDLING =====
    def apply_contract_wrapper(self, func: Callable, contract: ProcessingContract) -> Callable:
        """
        Wrap *func* so that calling it executes through *contract*.

        FLEXIBLE functions get a wrapper whose signature gains a keyword-only
        ``slice_by_slice: bool = False`` parameter (with an explicit ``bool``
        annotation for UI generation). All other contracts get a plain
        wrapper that calls ``contract.execute(self, func, ...)``.

        Args:
            func: Function to wrap (for external libraries, already adapted).
            contract: Processing contract governing execution.

        Returns:
            The wrapper, or *func* itself (with its ``slice_by_slice``
            annotation forced to ``bool``) when it is FLEXIBLE and already
            declares ``slice_by_slice``.
        """
        if contract != ProcessingContract.FLEXIBLE:
            @wraps(func)
            def contract_wrapper(image, *args, **kwargs):
                return contract.execute(self, func, image, *args, **kwargs)

            return contract_wrapper

        original_sig = inspect.signature(func)

        if 'slice_by_slice' in original_sig.parameters:
            # Function already has slice_by_slice, just ensure type hint is correct
            if hasattr(func, '__annotations__'):
                func.__annotations__['slice_by_slice'] = bool
            else:
                func.__annotations__ = {'slice_by_slice': bool}
            return func

        new_params = list(original_sig.parameters.values())
        slice_param = inspect.Parameter(
            'slice_by_slice',
            inspect.Parameter.KEYWORD_ONLY,
            default=False,
            annotation=bool,  # Explicit bool type hint for UI
        )
        # A valid signature requires keyword-only parameters before **kwargs.
        insert_index = next(
            (i for i, p in enumerate(new_params) if p.kind == inspect.Parameter.VAR_KEYWORD),
            len(new_params),
        )
        new_params.insert(insert_index, slice_param)

        @wraps(func)
        def flexible_wrapper(image, *args, slice_by_slice: bool = False, **kwargs):
            # _execute_flexible reads this attribute back from `func` to choose
            # slice-wise vs whole-volume execution.
            # NOTE(review): shared mutable state on `func`; concurrent calls
            # with different slice_by_slice values could interfere.
            func.slice_by_slice = slice_by_slice
            return contract.execute(self, func, image, *args, **kwargs)

        # Apply the modified signature AFTER @wraps so it is not overwritten.
        flexible_wrapper.__signature__ = original_sig.replace(parameters=new_params)

        # Copy (not share) annotations, then add slice_by_slice.
        if hasattr(func, '__annotations__'):
            flexible_wrapper.__annotations__ = func.__annotations__.copy()
        else:
            flexible_wrapper.__annotations__ = {}
        flexible_wrapper.__annotations__['slice_by_slice'] = bool

        flexible_wrapper.slice_by_slice = False
        return flexible_wrapper

    def _inject_optional_dataclass_params(self, func: Callable) -> Callable:
        """Inject optional lazy dataclass parameters into function signature.

        Disabled unless ``ENABLE_CONFIG_INJECTION`` is True (the class
        default is False), in which case *func* is returned unchanged.

        When enabled, keyword-only parameters ``napari_streaming_config``,
        ``fiji_streaming_config`` and ``step_materialization_config``
        (default None) are added to the advertised signature for any of
        them *func* does not already declare. The injected values are
        accepted and dropped; they are never forwarded to *func*.
        """
        if not self.ENABLE_CONFIG_INJECTION:
            return func  # Return function unchanged when disabled

        # Import existing lazy config types (only needed when injection is on)
        from openhcs.core.config import (
            LazyFijiStreamingConfig,
            LazyNapariStreamingConfig,
            LazyStepMaterializationConfig,
        )

        # Common lazy dataclass parameters to inject, in insertion order
        dataclass_params = {
            'napari_streaming_config': LazyNapariStreamingConfig,
            'fiji_streaming_config': LazyFijiStreamingConfig,
            'step_materialization_config': LazyStepMaterializationConfig,
        }

        original_sig = inspect.signature(func)
        original_params = list(original_sig.parameters.values())
        existing_param_names = {p.name for p in original_params}
        params_to_add = [
            (name, lazy_class)
            for name, lazy_class in dataclass_params.items()
            if name not in existing_param_names
        ]
        if not params_to_add:
            return func  # No parameters to add

        # Insert before **kwargs if it exists, otherwise append
        new_params = original_params.copy()
        insert_index = next(
            (i for i, p in enumerate(new_params) if p.kind == inspect.Parameter.VAR_KEYWORD),
            len(new_params),
        )
        for offset, (param_name, lazy_class) in enumerate(params_to_add):
            new_params.insert(insert_index + offset, inspect.Parameter(
                param_name,
                inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=Optional[lazy_class],  # Use actual type object, not string
            ))

        # Only strip parameters we injected: ones the function declares
        # itself must still reach it.
        injected_names = frozenset(name for name, _ in params_to_add)

        @wraps(func)
        def enhanced_wrapper(*args, **kwargs):
            regular_kwargs = {k: v for k, v in kwargs.items() if k not in injected_names}
            return func(*args, **regular_kwargs)

        enhanced_wrapper.__signature__ = original_sig.replace(parameters=new_params)

        if hasattr(func, '__annotations__'):
            enhanced_wrapper.__annotations__ = func.__annotations__.copy()
        else:
            enhanced_wrapper.__annotations__ = {}
        for param_name, lazy_class in params_to_add:
            enhanced_wrapper.__annotations__[param_name] = Optional[lazy_class]

        return enhanced_wrapper

    def _get_function_by_name(self, module_path: str, func_name: str):
        """Get function object by module path and name.

        Raises ImportError/ValueError if the module cannot be imported and
        AttributeError if it has no attribute *func_name*.
        """
        module = importlib.import_module(module_path)
        return getattr(module, func_name)

    # ===== PROCESSING CONTRACT EXECUTION METHODS =====
    def _execute_slice_by_slice(self, func, image, *args, **kwargs):
        """Shared slice-by-slice execution logic.

        For a 3D image, splits it with ``unstack_slices`` (memory type detected
        from the image), applies *func* to each slice and restacks the results.
        Any other rank is passed to *func* unchanged.
        """
        if image.ndim == 3:
            from openhcs.core.memory.stack_utils import _detect_memory_type
            mem = _detect_memory_type(image)
            slices = unstack_slices(image, mem, 0)
            results = [func(sl, *args, **kwargs) for sl in slices]
            return stack_slices(results, mem, 0)
        return func(image, *args, **kwargs)

    def _execute_pure_3d(self, func, image, *args, **kwargs):
        """Execute 3D→3D function directly (no change)."""
        return func(image, *args, **kwargs)

    def _execute_pure_2d(self, func, image, *args, **kwargs):
        """Execute 2D→2D function with unstack/restack wrapper.

        Requires ``func.output_memory_type`` (set on decorated/adapted functions).
        """
        memory_type = func.output_memory_type
        slices = unstack_slices(image, memory_type, 0)
        results = [func(sl, *args, **kwargs) for sl in slices]
        return stack_slices(results, memory_type, 0)

    def _execute_flexible(self, func, image, *args, **kwargs):
        """Execute function that handles both 3D→3D and 2D→2D with toggle.

        The toggle is the ``slice_by_slice`` attribute on *func* (default False).
        """
        if getattr(func, 'slice_by_slice', False):
            # Reuse the 2D-only execution logic (unstack -> process -> restack)
            return self._execute_pure_2d(func, image, *args, **kwargs)
        # Use 3D-only execution logic (no modification)
        return self._execute_pure_3d(func, image, *args, **kwargs)

    def _execute_volumetric_to_slice(self, func, image, *args, **kwargs):
        """Execute 3D→2D function and stack its single 2D result into a 3D array.

        Requires ``func.output_memory_type`` (set on decorated/adapted functions).
        """
        memory_type = func.output_memory_type
        result_2d = func(image, *args, **kwargs)
        return stack_slices([result_2d], memory_type, 0)

    # ===== CACHING METHODS =====
    def _load_or_discover_functions(self) -> Dict[str, FunctionMetadata]:
        """Load functions from cache, or discover them (and re-cache) if the cache is unusable."""
        cached_functions = self._load_from_cache()
        if cached_functions is not None:
            logger.info(f"✅ Loaded {len(cached_functions)} {self.library_name} functions from cache")
            return cached_functions

        logger.info(f"🔍 Cache miss for {self.library_name} - performing full discovery")
        functions = self.discover_functions()
        self._save_to_cache(functions)
        return functions

    def _load_from_cache(self) -> Optional[Dict[str, FunctionMetadata]]:
        """Load function metadata from cache with validation.

        Returns None (meaning: rediscover) when the cache file is missing,
        unreadable, malformed, written for a different library version,
        older than ``CACHE_MAX_AGE_DAYS``, or when any entry can no longer be
        reconstructed (module/function gone, unknown contract name, ...).
        A cache file that is not valid JSON is deleted.
        """
        if not self._cache_path.exists():
            return None

        try:
            with open(self._cache_path, 'r', encoding='utf-8') as f:
                cache_data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logger.warning(f"Corrupt cache file {self._cache_path}, rebuilding")
            self._cache_path.unlink(missing_ok=True)
            return None
        except OSError as e:
            logger.warning(f"Could not read cache file {self._cache_path} ({e}), rebuilding")
            return None

        if not isinstance(cache_data, dict) or not isinstance(cache_data.get('functions'), dict):
            return None

        cached_version = cache_data.get('library_version', 'unknown')
        current_version = self.get_library_version()
        if cached_version != current_version:
            logger.info(f"{self.library_name} version changed ({cached_version}{current_version}) - cache invalid")
            return None

        cache_timestamp = cache_data.get('timestamp', 0)
        if not isinstance(cache_timestamp, (int, float)):
            cache_timestamp = 0  # Malformed timestamp: treat the cache as expired
        cache_age_days = (time.time() - cache_timestamp) / (24 * 3600)
        if cache_age_days > self.CACHE_MAX_AGE_DAYS:
            logger.info(f"Cache is {cache_age_days:.1f} days old - rebuilding")
            return None

        functions = {}
        for func_name, cached_data in cache_data['functions'].items():
            try:
                functions[func_name] = self._metadata_from_cache_entry(func_name, cached_data)
            except (ImportError, AttributeError, KeyError, TypeError, ValueError) as e:
                # A stale entry must not crash registry loading; rediscover instead.
                logger.warning(f"Stale {self.library_name} cache entry '{func_name}' ({e}) - cache invalid")
                return None

        return functions

    def _metadata_from_cache_entry(self, func_name: str, cached_data: Dict[str, Any]) -> FunctionMetadata:
        """Rebuild one FunctionMetadata from a cache entry, re-applying discovery's wrappers.

        Propagates ImportError/AttributeError/KeyError/ValueError from the
        lookup or wrapping steps; the caller treats them as an invalid cache.
        """
        original_name = cached_data.get('original_name', func_name)
        func = self._get_function_by_name(cached_data['module'], original_name)
        contract = ProcessingContract[cached_data['contract']]

        # Apply the same wrappers as during discovery
        if hasattr(self, 'create_library_adapter'):
            # External library: library adapter first, then contract wrapper + param injection
            func = self.create_library_adapter(func, contract)
        # OpenHCS registries (no adapter) get contract wrapper + param injection only
        contract_wrapped_func = self.apply_contract_wrapper(func, contract)
        final_func = self._inject_optional_dataclass_params(contract_wrapped_func)

        return FunctionMetadata(
            name=func_name,
            func=final_func,
            contract=contract,
            registry=self,
            module=cached_data.get('module', ''),
            doc=cached_data.get('doc', ''),
            tags=cached_data.get('tags', []),
            original_name=original_name,
        )

    def _save_to_cache(self, functions: Dict[str, FunctionMetadata]) -> None:
        """Save function metadata to the JSON cache.

        Only serializable fields are stored; callables are re-imported and
        re-wrapped on load. The data is written to a sibling ``.tmp`` file and
        then renamed over the cache file, so an interrupted write does not
        leave a truncated cache. Write failures (e.g. a read-only cache
        directory) are logged and ignored: the discovered functions remain
        usable without a cache.
        """
        cache_data = {
            'cache_version': '1.0',
            'library_version': self.get_library_version(),
            'timestamp': time.time(),
            'functions': {
                func_name: {
                    'name': metadata.name,
                    'original_name': metadata.original_name,
                    'module': metadata.module,
                    'contract': metadata.contract.name,
                    'doc': metadata.doc,
                    'tags': metadata.tags
                }
                for func_name, metadata in functions.items()
            }
        }

        tmp_path = self._cache_path.with_name(self._cache_path.name + '.tmp')
        try:
            self._cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(cache_data, f, indent=2)
            tmp_path.replace(self._cache_path)
        except OSError as e:
            logger.warning(f"Could not write {self.library_name} function cache to {self._cache_path}: {e}")
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass  # Best-effort cleanup of the partial temp file
            return

        logger.info(f"💾 Saved {len(functions)} {self.library_name} functions to cache")

    def get_memory_type(self) -> str:
        """Get the memory type string value for this library."""
        return self.MEMORY_TYPE

    def get_module_patterns(self) -> List[str]:
        """Get module patterns that identify this library (can be overridden by implementations)."""
        # Default: just the library name
        return [self.library_name.lower()]

    def get_display_name(self) -> str:
        """Get display name for this library (can be overridden by implementations)."""
        # Default: capitalize library name
        return self.library_name.title()

    # ===== FUNCTION DISCOVERY =====
    def get_modules_to_scan(self) -> List[Tuple[str, Any]]:
        """
        Get list of (module_name, module_object) tuples to scan for functions.
        Uses the MODULES_TO_SCAN class attribute and library object from get_library_object().

        Returns:
            List of (name, module) pairs where name is for identification
            and module is the actual module object to scan. An empty string
            in MODULES_TO_SCAN maps to ("main", library).
        """
        library = self.get_library_object()
        modules = []
        for module_name in self.MODULES_TO_SCAN:
            if module_name == "":
                # Empty string means scan the main library namespace
                modules.append(("main", library))
            else:
                modules.append((module_name, getattr(library, module_name)))
        return modules

    @abstractmethod
    def get_library_object(self):
        """Get the main library object to scan for modules. Library-specific implementation."""
        pass

479 

480 

class RuntimeTestingRegistryBase(LibraryRegistryBase):
    """
    Extended ABC for libraries that require runtime testing.

    Adds runtime testing methods for libraries that don't have explicit
    processing contracts and need behavioral classification through testing:
    each candidate function is called on a small 3D and a small 2D test array,
    and its contract is inferred from which calls succeed.
    """

    def create_test_arrays(self) -> Tuple[Any, Any]:
        """
        Create test arrays appropriate for this library.

        Returns:
            Tuple of (test_3d, test_2d) arrays of shapes (3, 20, 20) and
            (20, 20), both created with this library's float dtype.
        """
        test_3d = self._create_array((3, 20, 20), self._get_float_dtype())
        test_2d = self._create_array((20, 20), self._get_float_dtype())
        return test_3d, test_2d

    @abstractmethod
    def _create_array(self, shape: Tuple[int, ...], dtype):
        """Create array with specified shape and dtype. Library-specific implementation."""
        pass

    def _get_float_dtype(self):
        """Get the appropriate float dtype for this library."""
        return self.FLOAT_DTYPE

    # ===== CORE BEHAVIOR CONTRACT =====
    def classify_function_behavior(
        self, func: Callable, declared_contract: Optional[ProcessingContract] = None
    ) -> Tuple[Optional[ProcessingContract], bool]:
        """Classify function behavior by testing 3D and 2D inputs, or use declared contract if provided.

        Returns:
            (contract, is_valid):
            - works on 3D and 2D: ``_classify_dual_support`` of the 3D result;
            - 3D only: PURE_3D;
            - 2D only: PURE_2D;
            - neither: (None, False).
        """
        # Fast path: If explicit contract is declared, use it directly (skip runtime testing)
        if declared_contract is not None:
            return declared_contract, True

        test_3d, test_2d = self.create_test_arrays()

        def test_function(test_array):
            """Test function with array, return (success, result)."""
            try:
                return True, func(test_array)
            except Exception:
                # Any ordinary failure means "does not accept this input".
                # KeyboardInterrupt/SystemExit must still propagate.
                return False, None

        works_3d, result_3d = test_function(test_3d)
        works_2d, _ = test_function(test_2d)

        if works_3d and works_2d:
            contract = self._classify_dual_support(result_3d)
        elif works_3d:
            contract = ProcessingContract.PURE_3D
        elif works_2d:
            contract = ProcessingContract.PURE_2D
        else:
            contract = None  # Invalid function

        return contract, works_3d or works_2d

    def _classify_dual_support(self, result_3d):
        """Classify functions that work on both 3D and 2D inputs.

        A 2D result for 3D input (or a tuple whose first element is 2D) means
        VOLUMETRIC_TO_SLICE; anything else is FLEXIBLE.
        """
        if result_3d is not None:
            # Handle tuple results (some functions return multiple arrays)
            if isinstance(result_3d, tuple):
                first_result = result_3d[0] if len(result_3d) > 0 else None
                if hasattr(first_result, 'ndim') and first_result.ndim == 2:
                    return ProcessingContract.VOLUMETRIC_TO_SLICE
            # Handle single array results
            elif hasattr(result_3d, 'ndim') and result_3d.ndim == 2:
                return ProcessingContract.VOLUMETRIC_TO_SLICE
        return ProcessingContract.FLEXIBLE

    @abstractmethod
    def _stack_2d_results(self, func, test_3d):
        """Stack 2D results. Library-specific implementation required."""
        pass

    @abstractmethod
    def _arrays_close(self, arr1, arr2):
        """Compare arrays. Library-specific implementation required."""
        pass

    def create_library_adapter(self, original_func: Callable, contract: ProcessingContract) -> Callable:
        """Create adapter with library-specific processing only.

        The adapter runs ``_preprocess_input`` on the image, executes
        *original_func* through *contract*, then ``_postprocess_output``.
        It keeps *original_func*'s signature, copies its annotations
        (augmented from the docstring where missing) and, when MEMORY_TYPE
        is not None, tags itself with input/output memory types.
        """
        func_name = getattr(original_func, '__name__', 'unknown')

        # Get original signature to preserve it
        original_sig = inspect.signature(original_func)

        def adapter(image, *args, **kwargs):
            processed_image = self._preprocess_input(image, func_name)
            result = contract.execute(self, original_func, processed_image, *args, **kwargs)
            return self._postprocess_output(result, image, func_name)

        # Apply wraps and preserve signature
        wrapped_adapter = wraps(original_func)(adapter)
        wrapped_adapter.__signature__ = original_sig

        # Copy (not share) annotations so the docstring-derived hints added
        # below never mutate the library function's own __annotations__.
        if hasattr(original_func, '__annotations__'):
            wrapped_adapter.__annotations__ = original_func.__annotations__.copy()
        else:
            wrapped_adapter.__annotations__ = {}

        # Extract type hints from docstring if annotations are missing
        self._enhance_annotations_from_docstring(wrapped_adapter, original_func)

        # Set memory type attributes for contract execution compatibility
        # Only set if registry has a specific memory type (external libraries)
        if self.MEMORY_TYPE is not None:
            wrapped_adapter.input_memory_type = self.MEMORY_TYPE
            wrapped_adapter.output_memory_type = self.MEMORY_TYPE
        wrapped_adapter.stream_to_napari = False

        return wrapped_adapter

    def _enhance_annotations_from_docstring(self, wrapped_func: Callable, original_func: Callable):
        """Best-effort: add type hints parsed from *original_func*'s docstring.

        Only parameters that are missing from ``wrapped_func.__annotations__``
        and have a description are considered. The description's first
        alternative ("X or Y" -> "X", ", optional" removed) is mapped to a
        type: ``{...}`` choice sets -> str, sequence-like wording -> list,
        otherwise the first substring match in TYPE_PATTERNS. The hint is
        kept only if the parameter's default is None or type-compatible.

        Any failure (including importing SignatureAnalyzer) leaves the
        annotations unchanged and is logged at DEBUG level.
        """
        try:
            from openhcs.textual_tui.widgets.shared.signature_analyzer import SignatureAnalyzer

            # Substring -> type; dict order matters (first match wins)
            TYPE_PATTERNS = {'ndarray': np.ndarray, 'array': np.ndarray, 'array_like': np.ndarray,
                             'int': int, 'integer': int, 'float': float, 'scalar': float,
                             'bool': bool, 'boolean': bool, 'str': str, 'string': str,
                             'tuple': tuple, 'list': list, 'dict': dict, 'sequence': list}

            # Default-value types accepted for a given inferred type
            COMPATIBLE_DEFAULTS = {float: (int, float, range), int: (int, float),
                                   list: (list, tuple, range), tuple: (list, tuple, range)}

            param_info = SignatureAnalyzer.analyze(original_func, skip_first_param=False)

            for param_name, info in param_info.items():
                if param_name in wrapped_func.__annotations__ or not info.description:
                    continue

                desc = info.description.lower().replace(', optional', '').replace(' optional', '').split(' or ')[0].strip()
                if desc.startswith('{') and '}' in desc:
                    python_type = str
                elif any(p in desc for p in ['sequence', 'iterable', 'array of', 'list of']):
                    python_type = list
                else:
                    python_type = next((t for pattern, t in TYPE_PATTERNS.items() if pattern in desc), None)

                if python_type and (info.default_value is None or
                                    type(info.default_value) in COMPATIBLE_DEFAULTS.get(python_type, (python_type,))):
                    wrapped_func.__annotations__[param_name] = python_type
        except Exception as e:
            # Deliberately best-effort, but leave a trace for debugging.
            logger.debug(
                f"Docstring type-hint extraction skipped for "
                f"{getattr(original_func, '__name__', original_func)!r}: {e}"
            )

    @abstractmethod
    def _preprocess_input(self, image, func_name: str):
        """Preprocess input image. Library-specific implementation."""
        pass

    @abstractmethod
    def _postprocess_output(self, result, original_image, func_name: str):
        """Postprocess output result. Library-specific implementation."""
        pass

    # ===== BASIC FILTERING =====
    def should_include_function(self, func: Callable, func_name: str) -> bool:
        """Single method for all filtering logic (blacklist, signature, etc.)

        Rejects private names, excluded names, classes, non-callables,
        objects without an introspectable signature, zero-parameter
        callables, functions with unresolvable type hints, and functions
        whose first parameter fails ``_check_first_parameter``.
        """
        # Skip private functions
        if func_name.startswith('_'):
            return False

        # A subclass-level EXCLUSIONS set replaces (does not extend) COMMON_EXCLUSIONS
        exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS)
        if func_name.lower() in exclusions:
            return False

        # Skip classes and types
        if inspect.isclass(func) or isinstance(func, type):
            return False

        # Must be callable
        if not callable(func):
            return False

        # Builtins/C extensions may have no introspectable signature; reject
        # them here rather than letting the error abort discovery
        # (_get_rejection_reason reports these as "invalid signature").
        try:
            params = list(inspect.signature(func).parameters.values())
        except (ValueError, TypeError):
            return False

        # Pure functions must have at least one parameter
        if not params:
            return False

        # Validate that type hints can be resolved (skip functions with missing dependencies)
        if not self._validate_type_hints(func, func_name):
            return False

        # Library-specific signature validation
        return self._check_first_parameter(params[0], func_name)

    def _validate_type_hints(self, func: Callable, func_name: str) -> bool:
        """
        Validate that function type hints can be resolved.

        Returns False if type hints reference missing dependencies (e.g., torch when not installed).
        This prevents functions with unresolvable type hints from being registered.
        """
        try:
            from typing import get_type_hints
            # Try to resolve type hints - this will fail if dependencies are missing
            get_type_hints(func)
            return True
        except NameError as e:
            # Type hint references a missing dependency (e.g., 'torch' not defined)
            logger.warning(f"Skipping function '{func_name}' due to unresolvable type hints: {e}")
            return False
        except Exception:
            # Other type hint resolution errors - be conservative and allow the function
            # (this handles edge cases where get_type_hints fails for other reasons)
            return True

    @abstractmethod
    def _check_first_parameter(self, first_param, func_name: str) -> bool:
        """Check if first parameter meets library-specific criteria. Library-specific implementation."""
        pass

    # ===== RUNTIME TESTING IMPLEMENTATION =====
    def discover_functions(self) -> Dict[str, FunctionMetadata]:
        """Discover and classify all library functions with runtime testing.

        For every public attribute of every module from get_modules_to_scan():
        filter it, classify it by runtime testing, wrap it (library adapter ->
        contract wrapper -> optional param injection) and record its metadata.
        Attributes whose access itself raises are skipped.

        Returns:
            Mapping of generated function name -> FunctionMetadata.
        """
        functions = {}
        modules = self.get_modules_to_scan()
        logger.info(f"🔍 Starting function discovery for {self.library_name}")
        logger.info(f"📦 Scanning {len(modules)} modules: {[name for name, _ in modules]}")

        total_tested = 0
        total_accepted = 0

        for module_name, module in modules:
            logger.info(f" 📦 Analyzing {module_name} ({module})...")
            module_tested = 0
            module_accepted = 0

            for name in dir(module):
                if name.startswith("_"):
                    continue

                # Lazily-loaded/deprecated attributes can raise on access
                # (e.g. ImportError for a missing optional dependency); one bad
                # attribute must not abort discovery for the whole library.
                try:
                    func = getattr(module, name)
                except Exception as e:
                    logger.info(f" 🚫 Skipping {module_name}.{name}: attribute access failed ({e})")
                    continue

                full_path = self._get_full_function_path(module, name, module_name)

                if not self.should_include_function(func, name):
                    rejection_reason = self._get_rejection_reason(func, name)
                    if rejection_reason != "private":
                        logger.info(f" 🚫 Skipping {full_path}: {rejection_reason}")
                    continue

                module_tested += 1
                total_tested += 1

                contract, is_valid = self.classify_function_behavior(func)
                logger.info(f" 🧪 Testing {full_path}")
                logger.info(f" Classification: {contract.name if contract else contract}")

                if not is_valid:
                    logger.info(" ❌ Rejected: Invalid classification")
                    continue

                doc_lines = (func.__doc__ or "").splitlines()
                first_line_doc = doc_lines[0] if doc_lines else ""
                func_name = self._generate_function_name(name, module_name)

                # Apply library adapter (preprocessing/postprocessing)
                adapted_func = self.create_library_adapter(func, contract)

                # Apply contract wrapper (slice_by_slice for FLEXIBLE)
                contract_wrapped_func = self.apply_contract_wrapper(adapted_func, contract)

                # Inject optional dataclass parameters
                final_func = self._inject_optional_dataclass_params(contract_wrapped_func)

                functions[func_name] = FunctionMetadata(
                    name=func_name,
                    func=final_func,
                    contract=contract,
                    registry=self,
                    module=getattr(func, '__module__', None) or "",
                    doc=first_line_doc,
                    tags=self._generate_tags(name),
                    original_name=name
                )
                module_accepted += 1
                total_accepted += 1
                logger.info(f" ✅ Accepted as '{func_name}'")

            logger.info(f" 📊 Module {module_name}: {module_accepted}/{module_tested} functions accepted")

        logger.info(f"✅ Discovery complete: {total_accepted}/{total_tested} functions accepted")
        return functions

    def _get_full_function_path(self, module, func_name: str, module_name: str) -> str:
        """Generate full dotted path for logging.

        Uses the library name for the "main" namespace, the module's
        ``__name__`` for real modules, otherwise the quoted name in the
        object's repr, falling back to *module_name*.
        """
        if module_name == "main":
            return f"{self.library_name}.{func_name}"
        if inspect.ismodule(module):
            return f"{module.__name__}.{func_name}"
        # Non-module namespace: extract the first quoted name from its repr
        module_str = str(module)
        if "'" in module_str:
            clean_path = module_str.split("'")[1]
            return f"{clean_path}.{func_name}"
        return f"{module_name}.{func_name}"

    def _get_rejection_reason(self, func: Callable, func_name: str) -> str:
        """Get detailed reason why a function was rejected.

        Mirrors the early checks of should_include_function; rejections from
        type-hint or first-parameter validation are reported as "unknown".
        """
        if func_name.startswith('_'):
            return "private"

        # A subclass-level EXCLUSIONS set replaces (does not extend) COMMON_EXCLUSIONS
        exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS)
        if func_name.lower() in exclusions:
            return "blacklisted"

        if inspect.isclass(func) or isinstance(func, type):
            return "is class/type"

        if not callable(func):
            return "not callable"

        try:
            sig = inspect.signature(func)
            params = list(sig.parameters.values())
            if not params:
                return "no parameters (not pure function)"
        except (ValueError, TypeError):
            return "invalid signature"

        return "unknown"

    # ===== CUSTOMIZATION HOOKS =====
    def _generate_function_name(self, name: str, module_name: str) -> str:
        """Generate function name. Override in subclasses for custom naming."""
        return name

    def _generate_tags(self, func_name: str) -> List[str]:
        """Generate tags using library name."""
        return [self.library_name]