Coverage for openhcs/processing/backends/lib_registry/unified_registry.py: 66.2%
253 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Unified registry base class for external library function registration.
4This module provides a common base class that eliminates ~70% of code duplication
5across library registries (pyclesperanto, scikit-image, cupy, etc.) while enforcing
6consistent behavior and making it impossible to skip dynamic testing or hardcode
7function lists.
9Key Benefits:
10- Eliminates ~1000+ lines of duplicated code
11- Enforces consistent testing and registration patterns
12- Makes adding new libraries trivial (60-120 lines vs 350-400)
13- Centralizes bug fixes and improvements
14- Type-safe abstract interface prevents shortcuts
16Architecture:
17- LibraryRegistryBase: Abstract base class with common functionality
18- ProcessingContract: Unified contract enum across all libraries
19- Dimension error adapter factory for consistent error handling
20- Integrated caching system using existing cache_utils.py patterns
21"""
23import importlib
24import inspect
25import json
26import logging
27import time
28from abc import ABC, abstractmethod
29from dataclasses import dataclass, field
30from enum import Enum, auto
31from functools import wraps
32from pathlib import Path
33from typing import Any, Callable, Dict, List, Optional, Tuple, Set
35import numpy as np
37from openhcs.core.utils import optional_import
38from openhcs.core.xdg_paths import get_cache_file_path
39from openhcs.core.memory.stack_utils import unstack_slices, stack_slices
41logger = logging.getLogger(__name__)
class ProcessingContract(Enum):
    """
    Unified contract classification with direct method execution.

    Each member's value is the name of the registry method implementing
    that execution strategy, so a contract can dispatch itself.
    """
    PURE_3D = "_execute_pure_3d"
    PURE_2D = "_execute_pure_2d"
    FLEXIBLE = "_execute_flexible"
    VOLUMETRIC_TO_SLICE = "_execute_volumetric_to_slice"

    def execute(self, registry, func, image, *args, **kwargs):
        """Look up the registry method named by this contract and invoke it."""
        handler = getattr(registry, self.value)
        return handler(func, image, *args, **kwargs)
@dataclass(frozen=True)
class FunctionMetadata:
    """Clean metadata with no library-specific leakage.

    Immutable record describing one discovered library function and the
    ProcessingContract it was classified under. Instances are built during
    discovery and round-tripped through the JSON cache (the function object
    itself is re-resolved from ``module`` + ``original_name`` on load).
    """

    # Core fields only
    name: str                     # registered (possibly renamed) function name
    func: Callable                # the live callable; not serialized to cache
    contract: ProcessingContract  # classified execution contract
    module: str = ""              # dotted module path, used to re-import on cache load
    doc: str = ""                 # first line of the original docstring
    tags: List[str] = field(default_factory=list)  # search/filter tags
    original_name: str = ""  # Original function name for cache reconstruction
class LibraryRegistryBase(ABC):
    """
    Clean abstraction with essential contracts only.

    Enforces only essential behavior contracts, not library-specific details.
    Each registry implements the contract its own way while returning unified
    ProcessingContract and FunctionMetadata types.

    Essential contracts:
    - Test function behavior to determine: 3D→3D, 2D→2D only, 3D→2D, etc.
    - Create adapters based on contract classification
    - Filter functions using consolidated logic
    - Provide library identification and discovery
    """

    # Common exclusions across all libraries (I/O, visualization, meta helpers)
    COMMON_EXCLUSIONS = {
        'imread', 'imsave', 'load', 'save', 'read', 'write',
        'show', 'imshow', 'plot', 'display', 'view', 'visualize',
        'info', 'help', 'version', 'test', 'benchmark'
    }

    # Maximum age of the on-disk metadata cache before a full rediscovery is
    # forced. Subclasses may override (was a hardcoded 7 inside _load_from_cache).
    CACHE_MAX_AGE_DAYS: float = 7.0

    # Abstract class attributes - each implementation must define these
    MODULES_TO_SCAN: List[str]
    MEMORY_TYPE: str   # Memory type string value (e.g., "pyclesperanto", "cupy", "numpy")
    FLOAT_DTYPE: Any   # Library-specific float32 type (np.float32, cp.float32, etc.)

    def __init__(self, library_name: str):
        """
        Initialize registry for a specific library.

        Args:
            library_name: Name of the library (e.g., "pyclesperanto", "skimage")
        """
        self.library_name = library_name
        self._cache_path = get_cache_file_path(f"{library_name}_function_metadata.json")

    # ===== LIBRARY IDENTIFICATION =====

    @abstractmethod
    def get_library_version(self) -> str:
        """Get library version for cache validation."""
        pass

    @abstractmethod
    def is_library_available(self) -> bool:
        """Check if the library is available for import."""
        pass

    def get_memory_type(self) -> str:
        """Get the memory type string value for this library."""
        return self.MEMORY_TYPE

    # ===== FUNCTION DISCOVERY =====

    def get_modules_to_scan(self) -> List[Tuple[str, Any]]:
        """
        Get list of (module_name, module_object) tuples to scan for functions.
        Uses the MODULES_TO_SCAN class attribute and library object from get_library_object().

        Returns:
            List of (name, module) pairs where name is for identification
            and module is the actual module object to scan.
        """
        library = self.get_library_object()
        modules = []
        for module_name in self.MODULES_TO_SCAN:
            if module_name == "":
                # Empty string means scan the main library namespace
                modules.append(("main", library))
            else:
                modules.append((module_name, getattr(library, module_name)))
        return modules

    @abstractmethod
    def get_library_object(self):
        """Get the main library object to scan for modules. Library-specific implementation."""
        pass

    def create_test_arrays(self) -> Tuple[Any, Any]:
        """
        Create test arrays appropriate for this library.

        Returns:
            Tuple of (test_3d, test_2d) arrays for behavior testing
        """
        test_3d = self._create_array((3, 20, 20), self._get_float_dtype())
        test_2d = self._create_array((20, 20), self._get_float_dtype())
        return test_3d, test_2d

    @abstractmethod
    def _create_array(self, shape: Tuple[int, ...], dtype):
        """Create array with specified shape and dtype. Library-specific implementation."""
        pass

    def _get_float_dtype(self):
        """Get the appropriate float dtype for this library."""
        return self.FLOAT_DTYPE

    # ===== CORE BEHAVIOR CONTRACT =====

    def classify_function_behavior(self, func: Callable) -> Tuple[Optional[ProcessingContract], bool]:
        """
        Classify function behavior by testing 3D and 2D inputs.

        Returns:
            (contract, is_valid) — contract is None when the function failed
            on both test inputs (is_valid is then False).
        """
        test_3d, test_2d = self.create_test_arrays()

        def probe(test_array):
            """Call func on test_array, returning (success, result)."""
            try:
                return True, func(test_array)
            except Exception:
                # BUG FIX: was a bare `except:`, which also swallowed
                # SystemExit/KeyboardInterrupt during discovery.
                return False, None

        works_3d, result_3d = probe(test_3d)
        works_2d, _ = probe(test_2d)

        # Explicit branches; the previous dict-based lookup evaluated
        # _classify_dual_support eagerly for every outcome.
        if works_3d and works_2d:
            contract = self._classify_dual_support(result_3d)
        elif works_3d:
            contract = ProcessingContract.PURE_3D
        elif works_2d:
            contract = ProcessingContract.PURE_2D
        else:
            contract = None  # Invalid function

        return contract, works_3d or works_2d

    def _classify_dual_support(self, result_3d) -> ProcessingContract:
        """Classify functions that work on both 3D and 2D inputs."""
        # getattr guards against non-array results (tuples, scalars) that
        # carry no .ndim attribute.
        if getattr(result_3d, 'ndim', None) == 2:
            return ProcessingContract.VOLUMETRIC_TO_SLICE
        return ProcessingContract.FLEXIBLE

    @abstractmethod
    def _stack_2d_results(self, func, test_3d):
        """Stack 2D results. Library-specific implementation required."""
        pass

    @abstractmethod
    def _arrays_close(self, arr1, arr2):
        """Compare arrays. Library-specific implementation required."""
        pass

    def create_library_adapter(self, original_func: Callable, contract: ProcessingContract) -> Callable:
        """Create adapter based on contract classification."""
        func_name = getattr(original_func, '__name__', 'unknown')

        @wraps(original_func)
        def unified_adapter(image, *args, slice_by_slice: bool = False, **kwargs):
            # Library-specific preprocessing
            processed_image = self._preprocess_input(image, func_name)

            # BUG FIX: forward the slice_by_slice toggle to FLEXIBLE contracts.
            # Previously the keyword was accepted here but silently dropped, so
            # _execute_flexible always ran in whole-volume mode.
            if contract is ProcessingContract.FLEXIBLE:
                kwargs['slice_by_slice'] = slice_by_slice

            # Contract-based execution
            result = contract.execute(self, original_func, processed_image, *args, **kwargs)

            # Library-specific postprocessing
            return self._postprocess_output(result, image, func_name)

        return unified_adapter

    @abstractmethod
    def _preprocess_input(self, image, func_name: str):
        """Preprocess input image. Library-specific implementation."""
        pass

    @abstractmethod
    def _postprocess_output(self, result, original_image, func_name: str):
        """Postprocess output result. Library-specific implementation."""
        pass

    # ===== BASIC FILTERING =====

    def should_include_function(self, func: Callable, func_name: str) -> bool:
        """Single method for all filtering logic (blacklist, signature, etc.)"""
        # Skip private functions
        if func_name.startswith('_'):
            return False

        # Skip exclusions (check both common and library-specific)
        exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS)
        if func_name.lower() in exclusions:
            return False

        # Skip classes and types
        if inspect.isclass(func) or isinstance(func, type):
            return False

        # Must be callable
        if not callable(func):
            return False

        # Pure functions must have at least one parameter.
        # BUG FIX: inspect.signature can raise for builtins/C extensions;
        # _get_rejection_reason already guarded this call, but here it crashed.
        try:
            sig = inspect.signature(func)
        except (ValueError, TypeError):
            return False
        params = list(sig.parameters.values())
        if not params:
            return False

        # Library-specific signature validation
        return self._check_first_parameter(params[0], func_name)

    @abstractmethod
    def _check_first_parameter(self, first_param, func_name: str) -> bool:
        """Check if first parameter meets library-specific criteria. Library-specific implementation."""
        pass

    # ===== SHARED IMPLEMENTATION LOGIC =====

    def discover_functions(self) -> Dict[str, FunctionMetadata]:
        """Discover and classify all library functions with detailed logging."""
        functions = {}
        modules = self.get_modules_to_scan()
        logger.info(f"🔍 Starting function discovery for {self.library_name}")
        logger.info(f"📦 Scanning {len(modules)} modules: {[name for name, _ in modules]}")

        total_tested = 0
        total_accepted = 0

        for module_name, module in modules:
            logger.info(f"  📦 Analyzing {module_name} ({module})...")
            module_tested = 0
            module_accepted = 0

            for name in dir(module):
                if name.startswith("_"):
                    continue

                func = getattr(module, name)
                full_path = self._get_full_function_path(module, name, module_name)

                if not self.should_include_function(func, name):
                    rejection_reason = self._get_rejection_reason(func, name)
                    if rejection_reason != "private":
                        logger.info(f"  🚫 Skipping {full_path}: {rejection_reason}")
                    continue

                module_tested += 1
                total_tested += 1

                contract, is_valid = self.classify_function_behavior(func)
                logger.info(f"  🧪 Testing {full_path}")
                logger.info(f"     Classification: {contract.name if contract else contract}")

                if not is_valid:
                    logger.info(f"    ❌ Rejected: Invalid classification")
                    continue

                # Only the first docstring line goes into metadata.
                doc_lines = (func.__doc__ or "").splitlines()
                first_line_doc = doc_lines[0] if doc_lines else ""
                func_name = self._generate_function_name(name, module_name)

                metadata = FunctionMetadata(
                    name=func_name,
                    func=func,
                    contract=contract,
                    module=func.__module__ or "",
                    doc=first_line_doc,
                    tags=self._generate_tags(name),
                    original_name=name
                )

                functions[func_name] = metadata
                module_accepted += 1
                total_accepted += 1
                logger.info(f"    ✅ Accepted as '{func_name}'")

            logger.info(f"  📊 Module {module_name}: {module_accepted}/{module_tested} functions accepted")

        logger.info(f"✅ Discovery complete: {total_accepted}/{total_tested} functions accepted")
        return functions

    def _get_full_function_path(self, module, func_name: str, module_name: str) -> str:
        """Generate full module path for logging."""
        if module_name == "main":
            return f"{self.library_name}.{func_name}"
        # Use the module's real dotted name instead of parsing its repr
        # (the old str(module).split("'") approach broke on odd reprs).
        clean_path = getattr(module, '__name__', None)
        if clean_path:
            return f"{clean_path}.{func_name}"
        return f"{module_name}.{func_name}"

    def _get_rejection_reason(self, func: Callable, func_name: str) -> str:
        """Get detailed reason why a function was rejected."""
        # Check each rejection criteria in the same order as should_include_function
        if func_name.startswith('_'):
            return "private"

        exclusions = getattr(self.__class__, 'EXCLUSIONS', self.COMMON_EXCLUSIONS)
        if func_name.lower() in exclusions:
            return "blacklisted"

        if inspect.isclass(func) or isinstance(func, type):
            return "is class/type"

        if not callable(func):
            return "not callable"

        try:
            sig = inspect.signature(func)
            params = list(sig.parameters.values())
            if not params:
                return "no parameters (not pure function)"
        except (ValueError, TypeError):
            return "invalid signature"

        return "unknown"

    # ===== CACHING METHODS =====

    def _load_or_discover_functions(self) -> Dict[str, FunctionMetadata]:
        """Load functions from cache or discover them if cache is invalid."""
        cached_functions = self._load_from_cache()
        if cached_functions is not None:
            logger.info(f"✅ Loaded {len(cached_functions)} {self.library_name} functions from cache")
            return cached_functions

        logger.info(f"🔍 Cache miss for {self.library_name} - performing full discovery")
        functions = self.discover_functions()
        self._save_to_cache(functions)
        return functions

    def _load_from_cache(self) -> Optional[Dict[str, FunctionMetadata]]:
        """
        Load function metadata from cache with validation.

        Returns None (forcing rediscovery) when the cache is missing, corrupt,
        stale, written by another library version, or references functions
        that no longer resolve.
        """
        if not self._cache_path.exists():
            return None

        # BUG FIX: a corrupt or unreadable cache file previously crashed
        # registration instead of triggering a rebuild.
        try:
            with open(self._cache_path, 'r') as f:
                cache_data = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            logger.warning(f"Invalid {self.library_name} cache ({e}) - rebuilding")
            return None

        if 'functions' not in cache_data:
            return None

        cached_version = cache_data.get('library_version', 'unknown')
        current_version = self.get_library_version()
        if cached_version != current_version:
            logger.info(f"{self.library_name} version changed ({cached_version} → {current_version}) - cache invalid")
            return None

        cache_timestamp = cache_data.get('timestamp', 0)
        cache_age_days = (time.time() - cache_timestamp) / (24 * 3600)
        if cache_age_days > self.CACHE_MAX_AGE_DAYS:
            logger.info(f"Cache is {cache_age_days:.1f} days old - rebuilding")
            return None

        functions = {}
        for func_name, cached_data in cache_data['functions'].items():
            original_name = cached_data.get('original_name', func_name)
            func = self._get_function_by_name(cached_data.get('module', ''), original_name)
            if func is None:
                # A cached entry no longer resolves — treat the whole cache
                # as stale rather than registering a partial function set.
                return None
            try:
                contract = ProcessingContract[cached_data['contract']]
            except KeyError:
                # Unknown/missing contract name (e.g. enum renamed) — rebuild.
                return None

            functions[func_name] = FunctionMetadata(
                name=func_name,
                func=func,
                contract=contract,
                module=cached_data.get('module', ''),
                doc=cached_data.get('doc', ''),
                tags=cached_data.get('tags', []),
                original_name=original_name
            )

        return functions

    def register_functions_direct(self):
        """Register functions directly with OpenHCS function registry using shared logic."""
        # Imported lazily to avoid a circular import at module load time.
        from openhcs.processing.func_registry import _apply_unified_decoration, _register_function
        from openhcs.constants import MemoryType

        functions = self._load_or_discover_functions()
        registered_count = 0

        for func_name, metadata in functions.items():
            adapted = self.create_library_adapter(metadata.func, metadata.contract)
            memory_type_enum = MemoryType(self.get_memory_type())
            wrapper_func = _apply_unified_decoration(
                original_func=adapted,
                func_name=metadata.name,
                memory_type=memory_type_enum,
                create_wrapper=True
            )

            _register_function(wrapper_func, self.get_memory_type())
            registered_count += 1

        logger.info(f"Registered {registered_count} {self.library_name} functions")
        return registered_count

    # ===== SHARED ADAPTER LOGIC =====

    def _execute_slice_by_slice(self, func, image, *args, **kwargs):
        """Shared slice-by-slice execution logic (auto-detects memory type)."""
        if image.ndim == 3:
            from openhcs.core.memory.stack_utils import unstack_slices, stack_slices, _detect_memory_type
            mem = _detect_memory_type(image)
            slices = unstack_slices(image, mem, 0)
            results = [func(sl, *args, **kwargs) for sl in slices]
            return stack_slices(results, mem, 0)
        return func(image, *args, **kwargs)

    # ===== PROCESSING CONTRACT EXECUTION METHODS =====

    def _execute_pure_3d(self, func, image, *args, **kwargs):
        """Execute 3D→3D function directly (no change)."""
        return func(image, *args, **kwargs)

    def _execute_pure_2d(self, func, image, *args, **kwargs):
        """Execute 2D→2D function with unstack/restack wrapper."""
        slices = unstack_slices(image, self.MEMORY_TYPE, 0)
        results = [func(sl, *args, **kwargs) for sl in slices]
        return stack_slices(results, self.MEMORY_TYPE, 0)

    def _execute_flexible(self, func, image, *args, slice_by_slice: bool = False, **kwargs):
        """Execute function that handles both 3D→3D and 2D→2D with toggle."""
        if slice_by_slice:
            return self._execute_pure_2d(func, image, *args, **kwargs)
        else:
            return self._execute_pure_3d(func, image, *args, **kwargs)

    def _execute_volumetric_to_slice(self, func, image, *args, **kwargs):
        """Execute 3D→2D function, returning a single-slice 3D array."""
        result_2d = func(image, *args, **kwargs)
        return stack_slices([result_2d], self.MEMORY_TYPE, 0)

    # ===== CUSTOMIZATION HOOKS =====

    def _generate_function_name(self, name: str, module_name: str) -> str:
        """Generate function name. Override in subclasses for custom naming."""
        return name

    def _generate_tags(self, func_name: str) -> List[str]:
        """Generate tags. Override in subclasses for custom tags."""
        return func_name.lower().replace("_", " ").split()

    def _promote_2d_to_3d(self, result):
        """Promote 2D results to 3D using library-specific expansion method."""
        # BUG FIX: the tuple case must be checked first — tuples have no .ndim,
        # so the original `result.ndim` access raised AttributeError before the
        # tuple branch could ever run. getattr also tolerates scalar results.
        if isinstance(result, tuple):
            if result and getattr(result[0], 'ndim', None) == 2:
                return (self._expand_2d_to_3d(result[0]), *result[1:])
            return result
        if getattr(result, 'ndim', None) == 2:
            return self._expand_2d_to_3d(result)
        return result

    @abstractmethod
    def _expand_2d_to_3d(self, array_2d):
        """Expand 2D array to 3D. Library-specific implementation required."""
        pass

    def _save_to_cache(self, functions: Dict[str, FunctionMetadata]) -> None:
        """Save function metadata to cache (func objects are not serialized)."""
        cache_data = {
            'cache_version': '1.0',
            'library_version': self.get_library_version(),
            'timestamp': time.time(),
            'functions': {
                func_name: {
                    'name': metadata.name,
                    'original_name': metadata.original_name,
                    'module': metadata.module,
                    'contract': metadata.contract.name,
                    'doc': metadata.doc,
                    'tags': metadata.tags
                }
                for func_name, metadata in functions.items()
            }
        }

        self._cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self._cache_path, 'w') as f:
            json.dump(cache_data, f, indent=2)

    def _get_function_by_name(self, module_path: str, func_name: str) -> Optional[Callable]:
        """Reconstruct function object from module path and function name.

        Returns None when the module or attribute no longer exists, honoring
        the Optional return type (previously stale-cache errors propagated).
        """
        try:
            module = importlib.import_module(module_path)
            return getattr(module, func_name)
        except (ImportError, AttributeError, ValueError) as e:
            logger.warning(f"Could not reconstruct {module_path}.{func_name}: {e}")
            return None