Coverage for openhcs/processing/backends/lib_registry/unified_registry.py: 60.8%
386 statements
coverage.py v7.10.7, created at 2025-10-01 18:33 +0000
1"""
2Unified registry base class for external library function registration.
4This module provides a common base class that eliminates ~70% of code duplication
5across library registries (pyclesperanto, scikit-image, cupy, etc.) while enforcing
6consistent behavior and making it impossible to skip dynamic testing or hardcode
7function lists.
9Key Benefits:
10- Eliminates over 1,000 lines of duplicated code
11- Enforces consistent testing and registration patterns
12- Makes adding new libraries trivial (60-120 lines vs 350-400)
13- Centralizes bug fixes and improvements
14- Type-safe abstract interface prevents shortcuts
16Architecture:
17- LibraryRegistryBase: Abstract base class with common functionality
18- ProcessingContract: Unified contract enum across all libraries
19- Dimension error adapter factory for consistent error handling
20- Integrated caching system using existing cache_utils.py patterns
21"""
23import importlib
24import inspect
26import json
27import logging
28import time
29from abc import ABC, abstractmethod
30from dataclasses import dataclass, field
31from enum import Enum, auto
32from functools import wraps
33from pathlib import Path
34from typing import Any, Callable, Dict, List, Optional, Tuple, Set
36import numpy as np
38from openhcs.core.utils import optional_import
39from openhcs.core.xdg_paths import get_cache_file_path
40from openhcs.core.memory.stack_utils import unstack_slices, stack_slices
42logger = logging.getLogger(__name__)
45# Enums for OpenHCS principle compliance (replace magic strings)
46class ModuleFilterComponents(Enum):
47 """Components to filter out when generating tags from module paths."""
48 BACKENDS = "backends"
49 PROCESSING = "processing"
50 OPENHCS = "openhcs"
52 @classmethod
53 def should_skip(cls, component: str) -> bool:
54 """Check if component should be skipped in tag generation."""
55 return any(component == item.value for item in cls)
58class ProcessingContract(Enum):
59 """
60 Unified contract classification with direct method execution.
61 """
62 PURE_3D = "_execute_pure_3d"
63 PURE_2D = "_execute_pure_2d"
64 FLEXIBLE = "_execute_flexible"
65 VOLUMETRIC_TO_SLICE = "_execute_volumetric_to_slice"
67 def execute(self, registry, func, image, *args, **kwargs):
68 """Execute the contract method on the registry."""
69 method = getattr(registry, self.value)
70 return method(func, image, *args, **kwargs)
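# Illustrative usage (not part of the original module): each member's value names an
# executor method on the registry, so dispatch is a single getattr call. Assuming a
# hypothetical registry instance and a wrapped function `blur`:
#   ProcessingContract.PURE_2D.execute(registry, blur, stack)
#   # equivalent to registry._execute_pure_2d(blur, stack)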
73@dataclass(frozen=True)
74class FunctionMetadata:
75 """Clean metadata with no library-specific leakage."""
77 # Core fields only
78 name: str
79 func: Callable
80 contract: ProcessingContract
81 registry: 'LibraryRegistryBase' # Reference to the registry that registered this function - REQUIRED
82 module: str = ""
83 doc: str = ""
84 tags: List[str] = field(default_factory=list)
85 original_name: str = "" # Original function name for cache reconstruction
90class LibraryRegistryBase(ABC):
91 """
92 Minimal ABC for all library registries.
94 Provides only essential contracts that all registries must implement,
95 regardless of whether they use runtime testing or explicit contracts.
96 """
98 # Common exclusions across all libraries
99 COMMON_EXCLUSIONS = {
100 'imread', 'imsave', 'load', 'save', 'read', 'write',
101 'show', 'imshow', 'plot', 'display', 'view', 'visualize',
102 'info', 'help', 'version', 'test', 'benchmark'
103 }
105 # Abstract class attributes - each implementation must define these
106 MODULES_TO_SCAN: List[str]
107 MEMORY_TYPE: str # Memory type string value (e.g., "pyclesperanto", "cupy", "numpy")
108 FLOAT_DTYPE: Any # Library-specific float32 type (np.float32, cp.float32, etc.)
110 def __init__(self, library_name: str):
111 """
112 Initialize registry for a specific library.
114 Args:
115 library_name: Name of the library (e.g., "pyclesperanto", "skimage")
116 """
117 self.library_name = library_name
118 self._cache_path = get_cache_file_path(f"{library_name}_function_metadata.json")
124 # ===== ESSENTIAL ABC METHODS =====
126 # ===== LIBRARY IDENTIFICATION =====
127 @abstractmethod
128 def get_library_version(self) -> str:
129 """Get library version for cache validation."""
130 pass
132 @abstractmethod
133 def is_library_available(self) -> bool:
134 """Check if the library is available for import."""
135 pass
137 # ===== FUNCTION DISCOVERY =====
138 @abstractmethod
139 def discover_functions(self) -> Dict[str, FunctionMetadata]:
140 """Discover and return function metadata. Must be implemented by subclasses."""
141 pass
143 # ===== CONTRACT HANDLING =====
144 def apply_contract_wrapper(self, func: Callable, contract: ProcessingContract) -> Callable:
145 """Apply contract-specific wrapper for all contract types."""
146 from functools import wraps
147 import inspect
149 if contract == ProcessingContract.FLEXIBLE:
150 # Check if function already has slice_by_slice parameter
151 original_sig = inspect.signature(func)
152 param_names = [p.name for p in original_sig.parameters.values()]
154 if 'slice_by_slice' in param_names: 154 ↛ 156: line 154 didn't jump to line 156 because the condition on line 154 was never true
155 # Function already has slice_by_slice, just ensure type hint is correct
156 if hasattr(func, '__annotations__'):
157 func.__annotations__['slice_by_slice'] = bool
158 else:
159 func.__annotations__ = {'slice_by_slice': bool}
160 return func
162 # Add slice_by_slice parameter
163 new_params = list(original_sig.parameters.values())
164 slice_param = inspect.Parameter(
165 'slice_by_slice',
166 inspect.Parameter.KEYWORD_ONLY,
167 default=False,
168 annotation=bool # Explicit bool type hint for UI
169 )
171 # Insert before any VAR_KEYWORD (**kwargs) parameter
172 insert_index = len(new_params)
173 for i, param in enumerate(new_params):
174 if param.kind == inspect.Parameter.VAR_KEYWORD:
175 insert_index = i
176 break
178 new_params.insert(insert_index, slice_param)
180 # Create the wrapper function
181 @wraps(func)
182 def flexible_wrapper(image, *args, slice_by_slice: bool = False, **kwargs):
183 func.slice_by_slice = slice_by_slice
184 return contract.execute(self, func, image, *args, **kwargs)
186 # Apply the modified signature AFTER @wraps
187 new_sig = original_sig.replace(parameters=new_params)
188 flexible_wrapper.__signature__ = new_sig
190 # Preserve original annotations and add slice_by_slice
191 if hasattr(func, '__annotations__'): 191 ↛ 194: line 191 didn't jump to line 194 because the condition on line 191 was always true
192 flexible_wrapper.__annotations__ = func.__annotations__.copy()
193 else:
194 flexible_wrapper.__annotations__ = {}
195 flexible_wrapper.__annotations__['slice_by_slice'] = bool
197 flexible_wrapper.slice_by_slice = False
198 return flexible_wrapper
199 else:
200 # For other contracts, wrap with contract execution
201 @wraps(func)
202 def contract_wrapper(image, *args, **kwargs):
203 return contract.execute(self, func, image, *args, **kwargs)
205 return contract_wrapper
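    # Illustrative sketch (hypothetical names): a FLEXIBLE-wrapped function gains a
    # keyword-only `slice_by_slice` toggle that shows up in its signature, e.g.
    #   wrapped = registry.apply_contract_wrapper(gaussian, ProcessingContract.FLEXIBLE)
    #   wrapped(stack)                       # processed as a single 3D volume
    #   wrapped(stack, slice_by_slice=True)  # unstacked, processed per 2D slice, restacked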
207 def _inject_optional_dataclass_params(self, func: Callable) -> Callable:
208 """Inject optional lazy dataclass parameters into function signature.
210 Can be disabled by setting ENABLE_CONFIG_INJECTION = False.
211 """
212 # Configuration flag to enable/disable config injection
213 ENABLE_CONFIG_INJECTION = False # Set to True to re-enable config injection
215 if not ENABLE_CONFIG_INJECTION: 215 ↛ 219: line 215 didn't jump to line 219 because the condition on line 215 was always true
216 return func # Return function unchanged when disabled
218 # Original injection logic (commented out for now but preserved)
219 import inspect
220 from functools import wraps
221 from typing import Optional
223 # Get original signature
224 original_sig = inspect.signature(func)
225 original_params = list(original_sig.parameters.values())
227 # Import existing lazy config types
228 from openhcs.core.config import LazyNapariStreamingConfig, LazyFijiStreamingConfig, LazyStepMaterializationConfig
230 # Define common lazy dataclass parameters to inject
231 dataclass_params = [
232 ('napari_streaming_config', 'Optional[LazyNapariStreamingConfig]', LazyNapariStreamingConfig),
233 ('fiji_streaming_config', 'Optional[LazyFijiStreamingConfig]', LazyFijiStreamingConfig),
234 ('step_materialization_config', 'Optional[LazyStepMaterializationConfig]', LazyStepMaterializationConfig),
235 ]
237 # Check if any parameters need to be added
238 existing_param_names = {p.name for p in original_params}
239 params_to_add = [(name, type_hint, lazy_class) for name, type_hint, lazy_class in dataclass_params
240 if name not in existing_param_names]
242 if not params_to_add:
243 return func # No parameters to add
245 # Create new parameters
246 new_params = original_params.copy()
248 # Find insertion point (before **kwargs if it exists)
249 insert_index = len(new_params)
250 for i, param in enumerate(new_params):
251 if param.kind == inspect.Parameter.VAR_KEYWORD:
252 insert_index = i
253 break
255 # Add dataclass parameters
256 from typing import Optional
257 for param_name, type_hint, lazy_class in params_to_add:
258 new_param = inspect.Parameter(
259 param_name,
260 inspect.Parameter.KEYWORD_ONLY,
261 default=None,
262 annotation=Optional[lazy_class] # Use actual type object, not string
263 )
264 new_params.insert(insert_index, new_param)
265 insert_index += 1
267 # Create enhanced wrapper function
268 @wraps(func)
269 def enhanced_wrapper(*args, **kwargs):
270 # Extract dataclass parameters from kwargs (they're just ignored for now)
271 regular_kwargs = {k: v for k, v in kwargs.items()
272 if k not in [name for name, _, _ in dataclass_params]}
274 # Call original function with regular parameters only
275 return func(*args, **regular_kwargs)
277 # Apply the modified signature
278 new_sig = original_sig.replace(parameters=new_params)
279 enhanced_wrapper.__signature__ = new_sig
281 # Enhance annotations
282 if hasattr(func, '__annotations__'):
283 enhanced_wrapper.__annotations__ = func.__annotations__.copy()
284 else:
285 enhanced_wrapper.__annotations__ = {}
287 # Add type annotations for injected parameters
288 from typing import Optional
289 for param_name, type_hint, lazy_class in params_to_add:
290 enhanced_wrapper.__annotations__[param_name] = Optional[lazy_class]
292 return enhanced_wrapper
294 def _get_function_by_name(self, module_path: str, func_name: str):
295 """Get function object by module path and name."""
296 module = importlib.import_module(module_path)
297 return getattr(module, func_name)
299 # ===== PROCESSING CONTRACT EXECUTION METHODS =====
300 def _execute_slice_by_slice(self, func, image, *args, **kwargs):
301 """Shared slice-by-slice execution logic."""
302 if image.ndim == 3:
303 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type
304 mem = _detect_memory_type(image)
305 slices = unstack_slices(image, mem, 0)
306 results = [func(sl, *args, **kwargs) for sl in slices]
307 return stack_slices(results, mem, 0)
308 return func(image, *args, **kwargs)
310 def _execute_pure_3d(self, func, image, *args, **kwargs):
311 """Execute 3D→3D function directly (no change)."""
312 return func(image, *args, **kwargs)
314 def _execute_pure_2d(self, func, image, *args, **kwargs):
315 """Execute 2D→2D function with unstack/restack wrapper."""
316 from openhcs.core.memory.stack_utils import unstack_slices, stack_slices
317 # Get memory type from the decorated function
318 memory_type = func.output_memory_type
319 slices = unstack_slices(image, memory_type, 0)
320 results = [func(sl, *args, **kwargs) for sl in slices]
321 return stack_slices(results, memory_type, 0)
323 def _execute_flexible(self, func, image, *args, **kwargs):
324 """Execute function that handles both 3D→3D and 2D→2D with toggle."""
325 # Check if slice_by_slice attribute is set on the function
326 slice_by_slice = getattr(func, 'slice_by_slice', False)
327 if slice_by_slice:
328 # Reuse the 2D-only execution logic (unstack -> process -> restack)
329 return self._execute_pure_2d(func, image, *args, **kwargs)
330 else:
331 # Use 3D-only execution logic (no modification)
332 return self._execute_pure_3d(func, image, *args, **kwargs)
334 def _execute_volumetric_to_slice(self, func, image, *args, **kwargs):
335 """Execute 3D→2D function returning slice 3D array."""
336 from openhcs.core.memory.stack_utils import stack_slices
337 # Get memory type from the decorated function
338 memory_type = func.output_memory_type
339 result_2d = func(image, *args, **kwargs)
340 return stack_slices([result_2d], memory_type, 0)
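    # Shape flow per contract (illustrative summary, assuming a (Z, Y, X) input stack):
    #   PURE_3D:             (Z, Y, X) -> func -> (Z, Y, X)
    #   PURE_2D:             (Z, Y, X) -> Z calls on (Y, X) slices -> restacked (Z, Y, X)
    #   FLEXIBLE:            either of the above, chosen by func.slice_by_slice
    #   VOLUMETRIC_TO_SLICE: (Z, Y, X) -> func -> (Y, X) -> restacked (1, Y, X)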
342 # ===== CACHING METHODS =====
343 def _load_or_discover_functions(self) -> Dict[str, FunctionMetadata]:
344 """Load functions from cache or discover them if cache is invalid."""
345 cached_functions = self._load_from_cache()
346 if cached_functions is not None: 346 ↛ 347: line 346 didn't jump to line 347 because the condition on line 346 was never true
347 logger.info(f"✅ Loaded {len(cached_functions)} {self.library_name} functions from cache")
348 return cached_functions
350 logger.info(f"🔍 Cache miss for {self.library_name} - performing full discovery")
351 functions = self.discover_functions()
352 self._save_to_cache(functions)
353 return functions
355 def _load_from_cache(self) -> Optional[Dict[str, FunctionMetadata]]:
356 """Load function metadata from cache with validation."""
357 if not self._cache_path.exists(): 357 ↛ 360: line 357 didn't jump to line 360 because the condition on line 357 was always true
358 return None
360 try:
361 with open(self._cache_path, 'r') as f:
362 cache_data = json.load(f)
363 except json.JSONDecodeError:
364 logger.warning(f"Corrupt cache file {self._cache_path}, rebuilding")
365 self._cache_path.unlink(missing_ok=True)
366 return None
368 if 'functions' not in cache_data:
369 return None
371 cached_version = cache_data.get('library_version', 'unknown')
372 current_version = self.get_library_version()
373 if cached_version != current_version:
374 logger.info(f"{self.library_name} version changed ({cached_version} → {current_version}) - cache invalid")
375 return None
377 cache_timestamp = cache_data.get('timestamp', 0)
378 cache_age_days = (time.time() - cache_timestamp) / (24 * 3600)
379 if cache_age_days > 7:
380 logger.info(f"Cache is {cache_age_days:.1f} days old - rebuilding")
381 return None
383 functions = {}
384 for func_name, cached_data in cache_data['functions'].items():
385 original_name = cached_data.get('original_name', func_name)
386 func = self._get_function_by_name(cached_data['module'], original_name)
387 contract = ProcessingContract[cached_data['contract']]
389 # Apply the same wrappers as during discovery
390 if hasattr(self, 'create_library_adapter'):
391 # External library - apply library adapter + contract wrapper + param injection
392 adapted_func = self.create_library_adapter(func, contract)
393 contract_wrapped_func = self.apply_contract_wrapper(adapted_func, contract)
394 final_func = self._inject_optional_dataclass_params(contract_wrapped_func)
395 else:
396 # OpenHCS - apply contract wrapper + param injection
397 contract_wrapped_func = self.apply_contract_wrapper(func, contract)
398 final_func = self._inject_optional_dataclass_params(contract_wrapped_func)
400 metadata = FunctionMetadata(
401 name=func_name,
402 func=final_func,
403 contract=contract,
404 registry=self,
405 module=cached_data.get('module', ''),
406 doc=cached_data.get('doc', ''),
407 tags=cached_data.get('tags', []),
408 original_name=cached_data.get('original_name', func_name)
409 )
410 functions[func_name] = metadata
412 return functions
414 def _save_to_cache(self, functions: Dict[str, FunctionMetadata]) -> None:
415 """Save function metadata to cache."""
416 cache_data = {
417 'cache_version': '1.0',
418 'library_version': self.get_library_version(),
419 'timestamp': time.time(),
420 'functions': {
421 func_name: {
422 'name': metadata.name,
423 'original_name': metadata.original_name,
424 'module': metadata.module,
425 'contract': metadata.contract.name,
426 'doc': metadata.doc,
427 'tags': metadata.tags
428 }
429 for func_name, metadata in functions.items()
430 }
431 }
433 self._cache_path.parent.mkdir(parents=True, exist_ok=True)
434 with open(self._cache_path, 'w') as f:
435 json.dump(cache_data, f, indent=2)
437 logger.info(f"💾 Saved {len(functions)} {self.library_name} functions to cache")
439 def get_memory_type(self) -> str:
440 """Get the memory type string value for this library."""
441 return self.MEMORY_TYPE
443 def get_module_patterns(self) -> List[str]:
444 """Get module patterns that identify this library (can be overridden by implementations)."""
445 # Default: just the library name
446 return [self.library_name.lower()]
448 def get_display_name(self) -> str:
449 """Get display name for this library (can be overridden by implementations)."""
450 # Default: capitalize library name
451 return self.library_name.title()
453 # ===== FUNCTION DISCOVERY =====
454 def get_modules_to_scan(self) -> List[Tuple[str, Any]]:
455 """
456 Get list of (module_name, module_object) tuples to scan for functions.
457 Uses the MODULES_TO_SCAN class attribute and library object from get_library_object().
459 Returns:
460 List of (name, module) pairs where name is for identification
461 and module is the actual module object to scan.
462 """
463 library = self.get_library_object()
464 modules = []
465 for module_name in self.MODULES_TO_SCAN:
466 if module_name == "":
467 # Empty string means scan the main library namespace
468 module = library
469 modules.append(("main", module))
470 else:
471 module = getattr(library, module_name)
472 modules.append((module_name, module))
473 return modules
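    # Illustrative convention (hypothetical value): MODULES_TO_SCAN = ["", "filters"] would scan
    # the main library namespace plus a `filters` submodule; "" always maps to the ("main", library) pair.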
475 @abstractmethod
476 def get_library_object(self):
477 """Get the main library object to scan for modules. Library-specific implementation."""
478 pass
481class RuntimeTestingRegistryBase(LibraryRegistryBase):
482 """
483 Extended ABC for libraries that require runtime testing.
485 Adds runtime testing methods for libraries that don't have explicit
486 processing contracts and need behavioral classification through testing.
487 """
489 def create_test_arrays(self) -> Tuple[Any, Any]:
490 """
491 Create test arrays appropriate for this library.
493 Returns:
494 Tuple of (test_3d, test_2d) arrays for behavior testing
495 """
496 test_3d = self._create_array((3, 20, 20), self._get_float_dtype())
497 test_2d = self._create_array((20, 20), self._get_float_dtype())
498 return test_3d, test_2d
500 @abstractmethod
501 def _create_array(self, shape: Tuple[int, ...], dtype):
502 """Create array with specified shape and dtype. Library-specific implementation."""
503 pass
505 def _get_float_dtype(self):
506 """Get the appropriate float dtype for this library."""
507 return self.FLOAT_DTYPE
509 # ===== CORE BEHAVIOR CONTRACT =====
510 def classify_function_behavior(self, func: Callable, declared_contract: Optional[ProcessingContract] = None) -> Tuple[ProcessingContract, bool]:
511 """Classify function behavior by testing 3D and 2D inputs, or use declared contract if provided."""
513 # Fast path: If explicit contract is declared, use it directly (skip runtime testing)
514 if declared_contract is not None: 514 ↛ 515: line 514 didn't jump to line 515 because the condition on line 514 was never true
515 return declared_contract, True
516 test_3d, test_2d = self.create_test_arrays()
518 def test_function(test_array):
519 """Test function with array, return (success, result)."""
520 try:
521 result = func(test_array)
522 return True, result
523 except Exception:
524 return False, None
526 works_3d, result_3d = test_function(test_3d)
527 works_2d, _ = test_function(test_2d)
529 # Classification lookup table
530 classification_map = {
531 (True, True): self._classify_dual_support(result_3d),
532 (True, False): ProcessingContract.PURE_3D,
533 (False, True): ProcessingContract.PURE_2D,
534 (False, False): None # Invalid function
535 }
537 contract = classification_map[(works_3d, works_2d)]
538 is_valid = works_3d or works_2d
540 return contract, is_valid
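    # Classification outcomes (illustrative summary of the lookup table above):
    #   works on 3D and 2D  -> FLEXIBLE, or VOLUMETRIC_TO_SLICE if the 3D call returns a 2D result
    #   works on 3D only    -> PURE_3D
    #   works on 2D only    -> PURE_2D
    #   works on neither    -> contract None, is_valid False (function rejected)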
542 def _classify_dual_support(self, result_3d):
543 """Classify functions that work on both 3D and 2D inputs."""
544 if result_3d is not None:
545 # Handle tuple results (some functions return multiple arrays)
546 if isinstance(result_3d, tuple):
547 # Check the first element if it's a tuple
548 first_result = result_3d[0] if len(result_3d) > 0 else None
549 if hasattr(first_result, 'ndim') and first_result.ndim == 2: 549 ↛ 550: line 549 didn't jump to line 550 because the condition on line 549 was never true
550 return ProcessingContract.VOLUMETRIC_TO_SLICE
551 # Handle single array results
552 elif hasattr(result_3d, 'ndim') and result_3d.ndim == 2:
553 return ProcessingContract.VOLUMETRIC_TO_SLICE
554 return ProcessingContract.FLEXIBLE
556 @abstractmethod
557 def _stack_2d_results(self, func, test_3d):
558 """Stack 2D results. Library-specific implementation required."""
559 pass
561 @abstractmethod
562 def _arrays_close(self, arr1, arr2):
563 """Compare arrays. Library-specific implementation required."""
564 pass
566 def create_library_adapter(self, original_func: Callable, contract: ProcessingContract) -> Callable:
567 """Create adapter with library-specific processing only."""
568 import inspect
569 func_name = getattr(original_func, '__name__', 'unknown')
571 # Get original signature to preserve it
572 original_sig = inspect.signature(original_func)
574 def adapter(image, *args, **kwargs):
575 processed_image = self._preprocess_input(image, func_name)
576 result = contract.execute(self, original_func, processed_image, *args, **kwargs)
577 return self._postprocess_output(result, image, func_name)
579 # Apply wraps and preserve signature
580 wrapped_adapter = wraps(original_func)(adapter)
581 wrapped_adapter.__signature__ = original_sig
583 # Preserve and enhance annotations
584 if hasattr(original_func, '__annotations__'): 584 ↛ 587: line 584 didn't jump to line 587 because the condition on line 584 was always true
585 wrapped_adapter.__annotations__ = original_func.__annotations__.copy()
586 else:
587 wrapped_adapter.__annotations__ = {}
589 # Extract type hints from docstring if annotations are missing
590 self._enhance_annotations_from_docstring(wrapped_adapter, original_func)
592 # Set memory type attributes for contract execution compatibility
593 # Only set if registry has a specific memory type (external libraries)
594 if self.MEMORY_TYPE is not None: 594 ↛ 597: line 594 didn't jump to line 597 because the condition on line 594 was always true
595 wrapped_adapter.input_memory_type = self.MEMORY_TYPE
596 wrapped_adapter.output_memory_type = self.MEMORY_TYPE
597 wrapped_adapter.stream_to_napari = False
599 return wrapped_adapter
601 def _enhance_annotations_from_docstring(self, wrapped_func: Callable, original_func: Callable):
602 """Extract type hints from docstring using mathematical simplification approach."""
603 try:
604 from openhcs.textual_tui.widgets.shared.signature_analyzer import SignatureAnalyzer
605 import numpy as np
607 # Unified type extraction with compatibility validation (mathematical simplification)
608 TYPE_PATTERNS = {'ndarray': np.ndarray, 'array': np.ndarray, 'array_like': np.ndarray,
609 'int': int, 'integer': int, 'float': float, 'scalar': float,
610 'bool': bool, 'boolean': bool, 'str': str, 'string': str,
611 'tuple': tuple, 'list': list, 'dict': dict, 'sequence': list}
613 COMPATIBLE_DEFAULTS = {float: (int, float, range), int: (int, float),
614 list: (list, tuple, range), tuple: (list, tuple, range)}
616 param_info = SignatureAnalyzer.analyze(original_func, skip_first_param=False)
618 # Inline type extraction and validation (single-use function inlining rule)
619 for param_name, info in param_info.items():
620 if param_name not in wrapped_func.__annotations__ and info.description:
621 # Inline type extraction with priority patterns
622 desc = info.description.lower().replace(', optional', '').replace(' optional', '').split(' or ')[0].strip()
623 python_type = (str if desc.startswith('{') and '}' in desc
624 else list if any(p in desc for p in ['sequence', 'iterable', 'array of', 'list of'])
625 else next((t for pattern, t in TYPE_PATTERNS.items() if pattern in desc), None))
627 # Inline compatibility check (single-use function inlining rule)
628 if python_type and (info.default_value is None or
629 type(info.default_value) in COMPATIBLE_DEFAULTS.get(python_type, (python_type,))):
630 wrapped_func.__annotations__[param_name] = python_type
631 except Exception:
632 pass
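    # Illustrative mapping (assumed docstring phrasing): a parameter documented as
    # "sigma : scalar or sequence, optional" reduces to "scalar" and maps to float via
    # TYPE_PATTERNS, while "size : int or sequence of ints" maps to int; descriptions
    # that match no pattern, or clash with the default's type, are left unannotated.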
634 @abstractmethod
635 def _preprocess_input(self, image, func_name: str):
636 """Preprocess input image. Library-specific implementation."""
637 pass
639 @abstractmethod
640 def _postprocess_output(self, result, original_image, func_name: str):
641 """Postprocess output result. Library-specific implementation."""
642 pass
644 # ===== BASIC FILTERING =====
645 def should_include_function(self, func: Callable, func_name: str) -> bool:
646 """Single method for all filtering logic (blacklist, signature, etc.)"""
647 # Skip private functions
648 if func_name.startswith('_'): 648 ↛ 649: line 648 didn't jump to line 649 because the condition on line 648 was never true
649 return False
651 # Skip exclusions (check both common and library-specific)
652 exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS)
653 if func_name.lower() in exclusions: 653 ↛ 654: line 653 didn't jump to line 654 because the condition on line 653 was never true
654 return False
656 # Skip classes and types
657 if inspect.isclass(func) or isinstance(func, type):
658 return False
660 # Must be callable
661 if not callable(func):
662 return False
664 # Pure functions must have at least one parameter
665 sig = inspect.signature(func)
666 params = list(sig.parameters.values())
667 if not params: 667 ↛ 668: line 667 didn't jump to line 668 because the condition on line 667 was never true
668 return False
670 # Validate that type hints can be resolved (skip functions with missing dependencies)
671 if not self._validate_type_hints(func, func_name): 671 ↛ 672: line 671 didn't jump to line 672 because the condition on line 671 was never true
672 return False
674 # Library-specific signature validation
675 return self._check_first_parameter(params[0], func_name)
678 def _validate_type_hints(self, func: Callable, func_name: str) -> bool:
679 """
680 Validate that function type hints can be resolved.
682 Returns False if type hints reference missing dependencies (e.g., torch when not installed).
683 This prevents functions with unresolvable type hints from being registered.
684 """
685 try:
686 from typing import get_type_hints
687 # Try to resolve type hints - this will fail if dependencies are missing
688 get_type_hints(func)
689 return True
690 except NameError as e:
691 # Type hint references a missing dependency (e.g., 'torch' not defined)
692 logger.warning(f"Skipping function '{func_name}' due to unresolvable type hints: {e}")
693 return False
694 except Exception:
695 # Other type hint resolution errors - be conservative and allow the function
696 # (this handles edge cases where get_type_hints fails for other reasons)
697 return True
699 @abstractmethod
700 def _check_first_parameter(self, first_param, func_name: str) -> bool:
701 """Check if first parameter meets library-specific criteria. Library-specific implementation."""
702 pass
704 # ===== RUNTIME TESTING IMPLEMENTATION =====
705 def discover_functions(self) -> Dict[str, FunctionMetadata]:
706 """Discover and classify all library functions with runtime testing."""
707 functions = {}
708 modules = self.get_modules_to_scan()
709 logger.info(f"🔍 Starting function discovery for {self.library_name}")
710 logger.info(f"📦 Scanning {len(modules)} modules: {[name for name, _ in modules]}")
712 total_tested = 0
713 total_accepted = 0
715 for module_name, module in modules:
716 logger.info(f" 📦 Analyzing {module_name} ({module})...")
717 module_tested = 0
718 module_accepted = 0
720 for name in dir(module):
721 if name.startswith("_"):
722 continue
724 func = getattr(module, name)
725 full_path = self._get_full_function_path(module, name, module_name)
727 if not self.should_include_function(func, name):
728 rejection_reason = self._get_rejection_reason(func, name)
729 if rejection_reason != "private": 729 ↛ 731: line 729 didn't jump to line 731 because the condition on line 729 was always true
730 logger.info(f" 🚫 Skipping {full_path}: {rejection_reason}")
731 continue
733 module_tested += 1
734 total_tested += 1
736 contract, is_valid = self.classify_function_behavior(func)
737 logger.info(f" 🧪 Testing {full_path}")
738 logger.info(f" Classification: {contract.name if contract else contract}")
740 if not is_valid:
741 logger.info(f" ❌ Rejected: Invalid classification")
742 continue
744 doc_lines = (func.__doc__ or "").splitlines()
745 first_line_doc = doc_lines[0] if doc_lines else ""
746 func_name = self._generate_function_name(name, module_name)
748 # Apply library adapter (preprocessing/postprocessing)
749 adapted_func = self.create_library_adapter(func, contract)
751 # Apply contract wrapper (slice_by_slice for FLEXIBLE)
752 contract_wrapped_func = self.apply_contract_wrapper(adapted_func, contract)
754 # Inject optional dataclass parameters
755 final_func = self._inject_optional_dataclass_params(contract_wrapped_func)
757 metadata = FunctionMetadata(
758 name=func_name,
759 func=final_func,
760 contract=contract,
761 registry=self,
762 module=func.__module__ or "",
763 doc=first_line_doc,
764 tags=self._generate_tags(name),
765 original_name=name
766 )
768 functions[func_name] = metadata
769 module_accepted += 1
770 total_accepted += 1
771 logger.info(f" ✅ Accepted as '{func_name}'")
773 logger.info(f" 📊 Module {module_name}: {module_accepted}/{module_tested} functions accepted")
775 logger.info(f"✅ Discovery complete: {total_accepted}/{total_tested} functions accepted")
776 return functions
780 def _get_full_function_path(self, module, func_name: str, module_name: str) -> str:
781 """Generate full module path for logging."""
782 if module_name == "main": 782 ↛ 783: line 782 didn't jump to line 783 because the condition on line 782 was never true
783 return f"{self.library_name}.{func_name}"
784 else:
785 # Extract clean module path
786 module_str = str(module)
787 if "'" in module_str: 787 ↛ 791line 787 didn't jump to line 791 because the condition on line 787 was always true
788 clean_path = module_str.split("'")[1]
789 return f"{clean_path}.{func_name}"
790 else:
791 return f"{module_name}.{func_name}"
793 def _get_rejection_reason(self, func: Callable, func_name: str) -> str:
794 """Get detailed reason why a function was rejected."""
795 # Check each rejection criteria in order
796 if func_name.startswith('_'): 796 ↛ 797: line 796 didn't jump to line 797 because the condition on line 796 was never true
797 return "private"
799 exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS)
800 if func_name.lower() in exclusions: 800 ↛ 801: line 800 didn't jump to line 801 because the condition on line 800 was never true
801 return "blacklisted"
803 if inspect.isclass(func) or isinstance(func, type):
804 return "is class/type"
806 if not callable(func):
807 return "not callable"
809 try:
810 sig = inspect.signature(func)
811 params = list(sig.parameters.values())
812 if not params: 812 ↛ 813: line 812 didn't jump to line 813 because the condition on line 812 was never true
813 return "no parameters (not pure function)"
814 except (ValueError, TypeError):
815 return "invalid signature"
817 return "unknown"
821 # ===== CUSTOMIZATION HOOKS =====
822 def _generate_function_name(self, name: str, module_name: str) -> str:
823 """Generate function name. Override in subclasses for custom naming."""
824 return name
826 def _generate_tags(self, func_name: str) -> List[str]:
827 """Generate tags using library name."""
828 return [self.library_name]
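# Illustrative sketch (not part of the original module): a minimal concrete registry built
# on RuntimeTestingRegistryBase, roughly the "60-120 lines per library" the module docstring
# refers to. The scipy.ndimage target and all method bodies here are assumptions for
# illustration only, not the real OpenHCS registry for any library.
class ScipyNdimageRegistrySketch(RuntimeTestingRegistryBase):
    MODULES_TO_SCAN = ["ndimage"]   # scanned via getattr on the library object
    MEMORY_TYPE = "numpy"
    FLOAT_DTYPE = np.float32

    def __init__(self):
        super().__init__("scipy")

    def get_library_object(self):
        import scipy.ndimage  # binds `scipy` and ensures the ndimage attribute is populated
        return scipy

    def get_library_version(self) -> str:
        import scipy
        return scipy.__version__

    def is_library_available(self) -> bool:
        try:
            import scipy  # noqa: F401
            return True
        except ImportError:
            return False

    def _create_array(self, shape, dtype):
        # Random test data used only for runtime behavior classification
        return np.random.rand(*shape).astype(dtype)

    def _stack_2d_results(self, func, test_3d):
        return np.stack([func(sl) for sl in test_3d], axis=0)

    def _arrays_close(self, arr1, arr2):
        return np.allclose(arr1, arr2)

    def _preprocess_input(self, image, func_name: str):
        return np.asarray(image, dtype=self.FLOAT_DTYPE)

    def _postprocess_output(self, result, original_image, func_name: str):
        return np.asarray(result)

    def _check_first_parameter(self, first_param, func_name: str) -> bool:
        # Assumed convention: accept functions whose first parameter looks like an image/array
        return first_param.name in ("input", "image", "array")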