Coverage for openhcs/core/components/multiprocessing.py: 24.6%
51 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 18:33 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-01 18:33 +0000
"""
Generic multiprocessing coordinator for configurable axis iteration.

This module provides a generic replacement for the hardcoded Well-based
multiprocessing logic, allowing any component to serve as the multiprocessing axis.
"""
import concurrent.futures
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar

from .framework import ComponentConfiguration
# Module-level logger, following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)

# Type variable for the component enum used as the multiprocessing axis.
T = TypeVar('T', bound=Enum)
@dataclass
class Task(Generic[T]):
    """A single unit of multiprocessing work: one axis value and its context."""

    # Concrete value of the multiprocessing axis (e.g. a well identifier).
    axis_value: str
    # ProcessingContext for this task; typed as Any to avoid a circular import.
    context: Any
class MultiprocessingCoordinator(Generic[T]):
    """
    Generic coordinator for multiprocessing along any component axis.

    This class replaces the hardcoded Well-based multiprocessing logic with
    a configurable system that can use any component as the multiprocessing axis.
    """

    def __init__(self, config: ComponentConfiguration[T]):
        """
        Initialize the coordinator with a component configuration.

        Args:
            config: ComponentConfiguration specifying the multiprocessing axis
        """
        self.config = config
        self.axis = config.multiprocessing_axis
        # Lazy %-style args avoid string formatting when DEBUG is disabled.
        logger.debug("MultiprocessingCoordinator initialized with axis: %s", self.axis.value)

    def create_tasks(
        self,
        orchestrator,
        pipeline_definition: List[Any],
        axis_filter: Optional[List[str]] = None,
    ) -> Dict[str, Task[T]]:
        """
        Create tasks for each value of the multiprocessing axis.

        This method replaces the hardcoded well iteration logic with generic
        component iteration based on the configured multiprocessing axis.

        Args:
            orchestrator: PipelineOrchestrator instance
            pipeline_definition: List of pipeline steps (currently unused here;
                kept for interface symmetry with execute_tasks)
            axis_filter: Optional filter for axis values

        Returns:
            Dictionary mapping axis values to Task objects (empty if no
            axis values are found)
        """
        # The orchestrator accepts the VariableComponents enum member directly.
        axis_values = orchestrator.get_component_keys(self.axis, axis_filter)

        if not axis_values:
            logger.warning("No %s values found for multiprocessing", self.axis.value)
            return {}

        logger.info(
            "Creating tasks for %d %s values: %s",
            len(axis_values), self.axis.value, axis_values,
        )

        # One task (with its own processing context) per axis value.
        tasks = {}
        for axis_value in axis_values:
            context = orchestrator.create_context(axis_value)
            tasks[axis_value] = Task(axis_value=axis_value, context=context)
            logger.debug("Created task for %s: %s", self.axis.value, axis_value)

        return tasks

    def execute_tasks(
        self,
        tasks: Dict[str, Task[T]],
        pipeline_definition: List[Any],
        executor,
        processor_func: Callable,
    ) -> Dict[str, Any]:
        """
        Execute tasks using the provided executor and processor function.

        This method provides a generic interface for task execution that can
        work with any multiprocessing axis.

        Args:
            tasks: Dictionary of tasks to execute
            pipeline_definition: List of pipeline steps, passed to processor_func
            executor: Executor instance (ThreadPoolExecutor or ProcessPoolExecutor)
            processor_func: Callable invoked as
                processor_func(pipeline_definition, task.context)

        Returns:
            Dictionary mapping axis values to execution results; a failed task
            maps to {"status": "error", "error": <message>} instead of raising.
        """
        if not tasks:
            logger.warning("No tasks to execute")
            return {}

        logger.info("Executing %d tasks for %s axis", len(tasks), self.axis.value)

        # Submit everything up front so tasks can run concurrently.
        future_to_axis_value = {}
        for axis_value, task in tasks.items():
            future = executor.submit(processor_func, pipeline_definition, task.context)
            future_to_axis_value[future] = axis_value
            logger.debug("Submitted task for %s: %s", self.axis.value, axis_value)

        # Collect results as they complete; a failing task is recorded as an
        # error entry rather than aborting the whole batch.
        results = {}
        for future in concurrent.futures.as_completed(future_to_axis_value):
            axis_value = future_to_axis_value[future]
            try:
                result = future.result()
                results[axis_value] = result
                logger.debug("Task completed for %s: %s", self.axis.value, axis_value)
            except Exception as e:
                logger.error("Task failed for %s %s: %s", self.axis.value, axis_value, e)
                results[axis_value] = {"status": "error", "error": str(e)}

        logger.info("Completed execution of %d tasks", len(results))
        return results