Coverage for src/hieraconf/context_manager.py: 38%

219 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-02 21:44 +0000

1""" 

2Generic contextvars-based context management system for lazy configuration. 

3 

4This module provides explicit context scoping using Python's contextvars to enable 

5hierarchical configuration resolution without explicit parameter passing. 

6 

7Key features: 

81. Explicit context scoping with config_context() manager 

92. Config extraction from functions, dataclasses, and objects 

103. Config merging for context hierarchy 

114. Clean separation between UI windows and contexts 

12 

13Key components: 

14- current_temp_global: ContextVar holding current merged global config 

15- config_context(): Context manager for creating context scopes 

16- extract_all_configs(): Extract config instances from a context object 

17- merge_configs(): Merge overrides into base config 

18""" 

19 

20import contextvars 

21import dataclasses 

22import inspect 

23import logging 

24from contextlib import contextmanager 

25from typing import Any, Dict, Union 

26from dataclasses import fields, is_dataclass 

27 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Core contextvar for the current merged global config.
# config_context() sets it for the duration of a ``with`` block and resets it on exit;
# get_current_temp_global() reads it. It is created without a default, so readers must
# pass a fallback to .get() (or handle LookupError) when no value has been set.
current_temp_global = contextvars.ContextVar('current_temp_global')

33 

34 

35def _merge_nested_dataclass(base, override, mask_with_none: bool = False): 

36 """ 

37 Recursively merge nested dataclass fields. 

38 

39 For each field in override: 

40 - If value is None and mask_with_none=False: skip (don't override base) 

41 - If value is None and mask_with_none=True: override with None (mask base) 

42 - If value is dataclass: recursively merge with base's value 

43 - Otherwise: use override value 

44 

45 Args: 

46 base: Base dataclass instance 

47 override: Override dataclass instance 

48 mask_with_none: If True, None values override base values 

49 

50 Returns: 

51 Merged dataclass instance 

52 """ 

53 if not is_dataclass(base) or not is_dataclass(override): 

54 return override 

55 

56 merge_values = {} 

57 for field_info in fields(override): 

58 field_name = field_info.name 

59 override_value = object.__getattribute__(override, field_name) 

60 

61 if override_value is None: 

62 if mask_with_none: 

63 # None overrides base value (masking mode) 

64 merge_values[field_name] = None 

65 else: 

66 # None means "don't override" - keep base value 

67 continue 

68 elif is_dataclass(override_value): 

69 # Recursively merge nested dataclass 

70 base_value = getattr(base, field_name, None) 

71 if base_value is not None and is_dataclass(base_value): 

72 merge_values[field_name] = _merge_nested_dataclass(base_value, override_value, mask_with_none) 

73 else: 

74 merge_values[field_name] = override_value 

75 else: 

76 # Concrete value - use override 

77 merge_values[field_name] = override_value 

78 

79 # Merge with base 

80 if merge_values: 

81 return dataclasses.replace(base, **merge_values) 

82 else: 

83 return base 

84 

85 

@contextmanager
def config_context(obj, mask_with_none: bool = False):
    """
    Open a context scope whose config is the base config overlaid with obj's fields.

    The base is the currently active context config when there is one, otherwise
    the result of get_base_global_config(). Every field declared on the base config
    type is looked up on ``obj``; the values found are merged into the base with
    ``dataclasses.replace`` and the result is stored in ``current_temp_global``
    until the ``with`` block exits.

    Args:
        obj: Object with config fields (pipeline_config, step, etc.). With ``None``
            no fields are inspected and the base config is used unchanged.
        mask_with_none: If False (default), None values on ``obj`` are ignored,
            values must pass _is_compatible_config_type() against the field's
            annotation, and values exposing ``to_base_config()`` are converted
            first. If True, every field found on ``obj`` is applied, including
            None, and the compatibility check and conversion are skipped.

    Yields:
        None. The merged config is available through get_current_temp_global().

    Usage:
        with config_context(orchestrator.pipeline_config):  # Pipeline-level context
            # ...
            with config_context(step):  # Step-level context
                # ...
        with config_context(GlobalPipelineConfig(), mask_with_none=True):  # Static defaults
            # ...
    """
    # Nested scopes build on the active context; the outermost one builds on the global base
    active_context = get_current_temp_global()
    base_config = get_base_global_config() if active_context is None else active_context

    def _overlay_on_base(name, candidate, mask):
        # A dataclass candidate is merged field-by-field into the base's dataclass
        # value of the same name; anything else replaces the base value wholesale.
        if is_dataclass(candidate):
            base_value = getattr(base_config, name, None)
            if base_value is not None and is_dataclass(base_value):
                return _merge_nested_dataclass(base_value, candidate, mask_with_none=mask)
        return candidate

    overrides = {}
    if obj is not None:
        from hieraconf.config import get_base_config_type

        base_config_type = get_base_config_type()

        for field_info in fields(base_config_type):
            field_name = field_info.name
            expected_type = field_info.type

            # Any AttributeError raised while reading or converting this field
            # means "obj does not provide it": move on to the next field.
            try:
                if not hasattr(obj, field_name):
                    continue

                # Use object.__getattribute__ to avoid triggering lazy resolution
                value = object.__getattribute__(obj, field_name)

                if mask_with_none:
                    # Masking mode: always take the value, even when it is None
                    overrides[field_name] = _overlay_on_base(field_name, value, True)
                elif value is not None and _is_compatible_config_type(value, expected_type):
                    # Normal mode: convert lazy configs to base configs before merging
                    if hasattr(value, 'to_base_config'):
                        value = value.to_base_config()
                    overrides[field_name] = _overlay_on_base(field_name, value, False)
            except AttributeError:
                continue

    # Build the scoped config; a failed merge degrades to the unmodified base
    merged_config = base_config
    if overrides:
        try:
            merged_config = dataclasses.replace(base_config, **overrides)
            logger.debug(f"Creating config context with {len(overrides)} field overrides from {type(obj).__name__}")
        except Exception as e:
            logger.warning(f"Failed to merge config overrides from {type(obj).__name__}: {e}")
            merged_config = base_config
    else:
        logger.debug(f"Creating config context with no overrides from {type(obj).__name__}")

    token = current_temp_global.set(merged_config)
    try:
        yield
    finally:
        # Restore whatever was active before this scope, even if the body raised
        current_temp_global.reset(token)

190 

191 

192# Removed: extract_config_overrides - no longer needed with field matching approach 

193 

194 

195# UNUSED: Kept for compatibility but no longer used with field matching approach 

def extract_from_function_signature(func) -> Dict[str, Any]:
    """
    Get parameter defaults as config overrides.

    This enables functions to provide config context through their parameter
    defaults. A default of None is a real default and is included.

    Args:
        func: Callable to extract parameter defaults from

    Returns:
        Dict of parameter_name -> default_value for parameters with defaults.
        Empty dict when ``inspect.signature`` cannot inspect ``func``.
    """
    try:
        sig = inspect.signature(func)
    except (ValueError, TypeError) as e:
        logger.debug(f"Could not extract signature from {func}: {e}")
        return {}

    # Identity comparison against the sentinel: "!=" would invoke the default's own
    # comparison method, which may raise or return a non-bool for arbitrary objects.
    overrides = {
        name: param.default
        for name, param in sig.parameters.items()
        if param.default is not inspect.Parameter.empty
    }

    # Not every callable has __name__ (e.g. functools.partial objects)
    func_name = getattr(func, '__name__', repr(func))
    logger.debug(f"Extracted {len(overrides)} overrides from function {func_name}")
    return overrides

223 

224 

def extract_from_dataclass_fields(obj) -> Dict[str, Any]:
    """
    Collect the non-None field values of a dataclass as config overrides.

    None is treated as "no value here", so such fields are left out of the result.

    Args:
        obj: Dataclass instance to extract field values from

    Returns:
        Dict of field_name -> value for non-None fields; empty dict when ``obj``
        is not a dataclass.
    """
    if not is_dataclass(obj):
        return {}

    # Read each field exactly once through normal attribute access, then filter
    name_value_pairs = ((spec.name, getattr(obj, spec.name)) for spec in fields(obj))
    overrides = {name: value for name, value in name_value_pairs if value is not None}

    logger.debug(f"Extracted {len(overrides)} overrides from dataclass {type(obj).__name__}")
    return overrides

250 

251 

def extract_from_object_attributes(obj) -> Dict[str, Any]:
    """
    Flatten the ``*_config`` attributes of an object into one dict of overrides.

    Every attribute listed by ``dir(obj)`` whose name ends in ``_config`` and whose
    value is a dataclass contributes its non-None fields. Attributes are visited in
    ``dir()`` order, so a later config wins when two configs share a field name.

    Args:
        obj: Object to extract config attributes from

    Returns:
        Dict of field_name -> value for all non-None config fields. Any exception
        raised during extraction is logged at debug level and whatever was
        collected up to that point is returned.
    """
    overrides = {}

    try:
        config_attr_names = (name for name in dir(obj) if name.endswith('_config'))
        for attr_name in config_attr_names:
            candidate = getattr(obj, attr_name)
            if candidate is None or not is_dataclass(candidate):
                continue
            # Pull the non-None fields out of this config
            overrides.update(extract_from_dataclass_fields(candidate))

        logger.debug(f"Extracted {len(overrides)} overrides from object {type(obj).__name__}")

    except Exception as e:
        logger.debug(f"Error extracting from object {obj}: {e}")

    return overrides

282 

283 

def merge_configs(base, overrides: Dict[str, Any]):
    """
    Return a copy of ``base`` with the non-None entries of ``overrides`` applied.

    ``base`` itself is never modified; the merged result is a new instance built
    with ``dataclasses.replace``.

    Args:
        base: Base config instance (base config type)
        overrides: Dict of field_name -> value to override; None values are ignored

    Returns:
        New config instance with overrides applied, or ``base`` itself when either
        argument is falsy, when every override is None, or when the merge fails
        (the failure is logged as a warning).
    """
    if not (base and overrides):
        return base

    try:
        # None means "inherit", so it must never overwrite an existing value
        filtered_overrides = {
            field_name: new_value
            for field_name, new_value in overrides.items()
            if new_value is not None
        }

        if filtered_overrides:
            merged = dataclasses.replace(base, **filtered_overrides)
            logger.debug(f"Merged {len(filtered_overrides)} overrides into {type(base).__name__}")
            return merged

        return base

    except Exception as e:
        logger.warning(f"Failed to merge configs: {e}")
        return base

317 

318 

def get_base_global_config():
    """
    Get the base global config (fallback when no context is set).

    Asks get_current_global_config() for the instance registered for the base
    config type and, when that returns None, builds a default instance by calling
    the base config type with no arguments.

    Returns:
        The registered global config instance, a default instance of the base
        config type, or None when the hieraconf helper modules cannot be imported.
    """
    try:
        from hieraconf.config import get_base_config_type
        from hieraconf.global_config import get_current_global_config

        config_cls = get_base_config_type()

        # Prefer the global config that was set up; otherwise fall back to defaults
        registered = get_current_global_config(config_cls)
        return registered if registered is not None else config_cls()
    except ImportError:
        logger.warning("Could not get base config type")
        return None

345 

346 

def get_current_temp_global():
    """
    Return the config stored in ``current_temp_global``, or None.

    This is the primary interface for resolution functions to access the current
    context. None is returned when no value has been set.

    Returns:
        Current merged global config or None
    """
    # The contextvar has no default of its own, so pass None as the fallback
    active_config = current_temp_global.get(None)
    return active_config

358 

359 

def set_current_temp_global(config):
    """
    Store ``config`` in ``current_temp_global`` (for testing/debugging).

    Normal code should use the config_context() manager instead, which also
    restores the previous value on exit.

    Args:
        config: Global config instance to set as current context

    Returns:
        Token that can be passed to ``current_temp_global.reset()`` to restore the
        previous value
    """
    reset_token = current_temp_global.set(config)
    return reset_token

374 

375 

def clear_current_temp_global():
    """
    Clear current context (for testing/debugging).

    Stores None in ``current_temp_global``, so get_current_temp_global() returns
    None afterwards and resolution falls back to default behavior.
    """
    # ContextVar.set() succeeds whether or not a value was set before, so no
    # LookupError guard is needed here.
    current_temp_global.set(None)

387 

388 

389# Utility functions for debugging and introspection 

390 

def get_context_info() -> Dict[str, Any]:
    """
    Summarize the current context for debugging.

    Returns:
        ``{"active": False}`` when no context is active. Otherwise a dict with
        "active" (True), "type" (class name of the context object), "field_count"
        and "non_none_fields"; both counts are 0 when the context object is not a
        dataclass.
    """
    current = get_current_temp_global()
    if current is None:
        return {"active": False}

    if is_dataclass(current):
        field_specs = fields(current)
        field_count = len(field_specs)
        non_none_fields = sum(
            1 for spec in field_specs if getattr(current, spec.name) is not None
        )
    else:
        field_count = 0
        non_none_fields = 0

    return {
        "active": True,
        "type": type(current).__name__,
        "field_count": field_count,
        "non_none_fields": non_none_fields,
    }

409 

410 

def extract_all_configs_from_context() -> Dict[str, Any]:
    """
    Run extract_all_configs() on the currently active context.

    Returns:
        The dict produced by extract_all_configs() for the active context, or an
        empty dict when no context is active
    """
    active_context = get_current_temp_global()
    return {} if active_context is None else extract_all_configs(active_context)

426 

427 

def extract_all_configs(context_obj) -> Dict[str, Any]:
    """
    Extract all config instances from a context object, keyed by class name.

    For a dataclass context the object itself is included, plus the value of every
    field whose annotation (after unwrapping Optional) is a dataclass type and
    whose value is not None. For any other object the work is delegated to
    _extract_from_object_attributes_typed().

    Args:
        context_obj: Object to extract configs from (orchestrator, merged config, etc.)

    Returns:
        Dict mapping config type names to config instances; empty dict for None
    """
    if context_obj is None:
        return {}

    configs = {}
    context_type = type(context_obj)

    # Include the context object itself if it's a dataclass
    if is_dataclass(context_obj):
        configs[context_type.__name__] = context_obj

    if not is_dataclass(context_type):
        # Non-dataclass objects (orchestrators, etc.): scan their attributes instead
        _extract_from_object_attributes_typed(context_obj, configs)
    else:
        # Type-driven extraction: the field annotations say which fields hold configs
        for field_info in fields(context_type):
            field_name = field_info.name

            # Optional[ConfigType] counts as ConfigType; non-dataclass annotations are skipped
            if not is_dataclass(_unwrap_optional_type(field_info.type)):
                continue

            try:
                field_value = getattr(context_obj, field_name)
                if field_value is None:
                    continue

                # Key by the actual instance type, not the annotation, so a subclass
                # stored in a base-annotated field is registered under its own name
                instance_type = type(field_value)
                configs[instance_type.__name__] = field_value

                logger.debug(f"Extracted config {instance_type.__name__} from field {field_name}")

            except AttributeError:
                # Field doesn't exist on instance (shouldn't happen with dataclasses)
                logger.debug(f"Field {field_name} not found on {type(context_obj).__name__}")
                continue

    logger.debug(f"Extracted {len(configs)} configs: {list(configs.keys())}")
    return configs

482 

483 

484def _unwrap_optional_type(field_type): 

485 """ 

486 Unwrap Optional[T] and Union[T, None] types to get the actual type T. 

487 

488 This handles type annotations like Optional[ConfigType] -> ConfigType 

489 """ 

490 # Handle typing.Optional and typing.Union 

491 if hasattr(field_type, '__origin__'): 

492 if field_type.__origin__ is Union: 

493 # Get non-None types from Union 

494 non_none_types = [arg for arg in field_type.__args__ if arg is not type(None)] 

495 if len(non_none_types) == 1: 

496 return non_none_types[0] 

497 

498 return field_type 

499 

500 

501def _extract_from_object_attributes_typed(obj, configs: Dict[str, Any]) -> None: 

502 """ 

503 Type-safe extraction from object attributes for non-dataclass objects. 

504 

505 This is used for orchestrators and other objects that aren't dataclasses 

506 but have config attributes. Uses type checking instead of string matching. 

507 """ 

508 try: 

509 # Get all attributes that are dataclass instances 

510 for attr_name in dir(obj): 

511 if attr_name.startswith('_'): 

512 continue 

513 

514 try: 

515 attr_value = getattr(obj, attr_name) 

516 if attr_value is not None and is_dataclass(attr_value): 

517 configs[type(attr_value).__name__] = attr_value 

518 logger.debug(f"Extracted config {type(attr_value).__name__} from attribute {attr_name}") 

519 

520 except (AttributeError, TypeError): 

521 # Skip attributes that can't be accessed or aren't relevant 

522 continue 

523 

524 except Exception as e: 

525 logger.debug(f"Error in typed attribute extraction: {e}") 

526 

527 

528def _is_compatible_config_type(value, expected_type) -> bool: 

529 """ 

530 Check if value is compatible with expected_type, handling lazy-to-base type mapping. 

531 

532 This handles cases where: 

533 - value is LazyStepMaterializationConfig, expected_type is StepMaterializationConfig 

534 - value is a subclass of the expected type 

535 - value is exactly the expected type 

536 """ 

537 value_type = type(value) 

538 

539 # Direct type match 

540 if value_type == expected_type: 

541 return True 

542 

543 # Check if value_type is a subclass of expected_type 

544 try: 

545 if issubclass(value_type, expected_type): 

546 return True 

547 except TypeError: 

548 # expected_type might not be a class (e.g., Union, Optional) 

549 pass 

550 

551 # Check lazy-to-base type mapping 

552 if hasattr(value, 'to_base_config'): 

553 # This is a lazy config - check if its base type matches expected_type 

554 from hieraconf.lazy_factory import _lazy_type_registry 

555 base_type = _lazy_type_registry.get(value_type) 

556 if base_type == expected_type: 

557 return True 

558 # Also check if base type is subclass of expected type 

559 if base_type and issubclass(base_type, expected_type): 

560 return True 

561 

562 return False