Coverage for openhcs/core/components/multiprocessing.py: 24.6%

51 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 18:33 +0000

1""" 

2Generic multiprocessing coordinator for configurable axis iteration. 

3 

4This module provides a generic replacement for the hardcoded Well-based 

5multiprocessing logic, allowing any component to serve as the multiprocessing axis. 

6""" 

7 

import concurrent.futures
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar

from .framework import ComponentConfiguration

14 

# Module-level logger, namespaced to this module per standard logging convention.
logger = logging.getLogger(__name__)

# Axis type parameter: any Enum subclass may name the multiprocessing axis
# (e.g. a VariableComponents member such as WELL). Bounding to Enum guarantees
# the axis object has a `.value` attribute, which log messages below rely on.
T = TypeVar('T', bound=Enum)

@dataclass
class Task(Generic[T]):
    """Represents a single multiprocessing task.

    One Task is created per value of the configured multiprocessing axis
    (see MultiprocessingCoordinator.create_tasks).
    """
    # The axis value this task covers (e.g. a well identifier string).
    axis_value: str
    # ProcessingContext for this axis value; typed Any to avoid a circular import.
    context: Any

25 

26 

class MultiprocessingCoordinator(Generic[T]):
    """
    Generic coordinator for multiprocessing along any component axis.

    This class replaces the hardcoded Well-based multiprocessing logic with
    a configurable system that can use any component as the multiprocessing axis.
    """

    def __init__(self, config: ComponentConfiguration[T]):
        """
        Initialize the coordinator with a component configuration.

        Args:
            config: ComponentConfiguration specifying the multiprocessing axis.
        """
        self.config = config
        # The enum member naming the axis; its .value is used in log messages.
        self.axis = config.multiprocessing_axis
        # Lazy %-style args so the message is only formatted if DEBUG is enabled.
        logger.debug("MultiprocessingCoordinator initialized with axis: %s", self.axis.value)

    def create_tasks(
        self,
        orchestrator,
        pipeline_definition: List[Any],
        axis_filter: Optional[List[str]] = None
    ) -> Dict[str, Task[T]]:
        """
        Create tasks for each value of the multiprocessing axis.

        This method replaces the hardcoded well iteration logic with generic
        component iteration based on the configured multiprocessing axis.

        Args:
            orchestrator: PipelineOrchestrator instance.
            pipeline_definition: List of pipeline steps. (Currently unused here;
                kept for interface compatibility with callers — TODO confirm
                whether any caller relies on it being accepted.)
            axis_filter: Optional filter for axis values.

        Returns:
            Dictionary mapping axis values to Task objects. Empty dict when the
            orchestrator reports no values for the configured axis.
        """
        # The orchestrator accepts the axis enum member directly.
        axis_values = orchestrator.get_component_keys(self.axis, axis_filter)

        if not axis_values:
            logger.warning("No %s values found for multiprocessing", self.axis.value)
            return {}

        logger.info(
            "Creating tasks for %d %s values: %s",
            len(axis_values), self.axis.value, axis_values,
        )

        tasks: Dict[str, Task[T]] = {}
        for axis_value in axis_values:
            # One ProcessingContext per axis value, created by the orchestrator.
            context = orchestrator.create_context(axis_value)
            tasks[axis_value] = Task(axis_value=axis_value, context=context)
            logger.debug("Created task for %s: %s", self.axis.value, axis_value)

        return tasks

    def execute_tasks(
        self,
        tasks: Dict[str, Task[T]],
        pipeline_definition: List[Any],
        executor,
        processor_func: Callable
    ) -> Dict[str, Any]:
        """
        Execute tasks using the provided executor and processor function.

        This method provides a generic interface for task execution that can
        work with any multiprocessing axis.

        Args:
            tasks: Dictionary of tasks to execute.
            pipeline_definition: List of pipeline steps, passed through to
                ``processor_func`` for every task.
            executor: Executor instance (ThreadPoolExecutor or ProcessPoolExecutor).
            processor_func: Called as ``processor_func(pipeline_definition,
                task.context)`` for each task.

        Returns:
            Dictionary mapping axis values to execution results. A failed task
            maps to ``{"status": "error", "error": str(exception)}`` instead of
            propagating the exception.
        """
        if not tasks:
            logger.warning("No tasks to execute")
            return {}

        logger.info("Executing %d tasks for %s axis", len(tasks), self.axis.value)

        # Submit every task up front; map each future back to its axis value
        # so results can be keyed correctly as they complete (in any order).
        future_to_axis_value = {}
        for axis_value, task in tasks.items():
            future = executor.submit(processor_func, pipeline_definition, task.context)
            future_to_axis_value[future] = axis_value
            logger.debug("Submitted task for %s: %s", self.axis.value, axis_value)

        # Collect results as they finish. concurrent.futures is imported at
        # module level (hoisted from this method body).
        results: Dict[str, Any] = {}
        for future in concurrent.futures.as_completed(future_to_axis_value):
            axis_value = future_to_axis_value[future]
            try:
                results[axis_value] = future.result()
                logger.debug("Task completed for %s: %s", self.axis.value, axis_value)
            except Exception as e:
                # logger.exception preserves the traceback of the worker failure;
                # a per-task error record keeps one failure from hiding the rest.
                logger.exception("Task failed for %s %s: %s", self.axis.value, axis_value, e)
                results[axis_value] = {"status": "error", "error": str(e)}

        logger.info("Completed execution of %d tasks", len(results))
        return results