Coverage for openhcs/runtime/zmq_messages.py: 60.4%
214 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""ZMQ Message Type System - enum dispatch and structured messages."""
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class MessageFields:
    """Dictionary keys shared by every ZMQ message payload.

    The ``to_dict``/``from_dict`` methods of the message dataclasses use these
    constants as keys instead of repeating bare string literals.
    """

    # Message discriminator
    TYPE = "type"

    # Execute-request payload
    PLATE_ID = "plate_id"
    PIPELINE_CODE = "pipeline_code"
    CONFIG_PARAMS = "config_params"
    CONFIG_CODE = "config_code"
    PIPELINE_CONFIG_CODE = "pipeline_config_code"
    CLIENT_ADDRESS = "client_address"

    # Execution bookkeeping
    EXECUTION_ID = "execution_id"
    START_TIME = "start_time"
    END_TIME = "end_time"
    ELAPSED = "elapsed"
    STATUS = "status"
    ERROR = "error"
    MESSAGE = "message"

    # Server identity / health (pong) payload
    PORT = "port"
    CONTROL_PORT = "control_port"
    READY = "ready"
    SERVER = "server"
    LOG_FILE_PATH = "log_file_path"
    ACTIVE_EXECUTIONS = "active_executions"
    RUNNING_EXECUTIONS = "running_executions"
    WORKERS = "workers"
    WORKERS_KILLED = "workers_killed"
    UPTIME = "uptime"
    EXECUTIONS = "executions"

    # Results and per-well progress
    WELL_COUNT = "well_count"
    WELLS = "wells"
    RESULTS_SUMMARY = "results_summary"
    WELL_ID = "well_id"
    STEP = "step"
    TIMESTAMP = "timestamp"

    # Viewer acknowledgment payload
    IMAGE_ID = "image_id"
    VIEWER_PORT = "viewer_port"
    VIEWER_TYPE = "viewer_type"

    # ROI / shapes payload
    ROIS = "rois"
    LAYER_NAME = "layer_name"
    SHAPES = "shapes"
    COORDINATES = "coordinates"
    METADATA = "metadata"
class ControlMessageType(Enum):
    """Type tag of a request sent to a server's control socket.

    Every member except PING has an entry in the handler table of
    :meth:`get_handler_method`, naming a server method ``_handle_<value>``.
    """

    PING = "ping"
    EXECUTE = "execute"
    STATUS = "status"
    CANCEL = "cancel"
    SHUTDOWN = "shutdown"
    FORCE_SHUTDOWN = "force_shutdown"

    def get_handler_method(self):
        """Return the name of the server method that handles this message type.

        Raises:
            KeyError: if this type has no handler entry (PING has none). The
                error message names the message type instead of only echoing
                the enum member, so the failure is obvious in logs.
        """
        handlers = {
            ControlMessageType.EXECUTE: "_handle_execute",
            ControlMessageType.STATUS: "_handle_status",
            ControlMessageType.CANCEL: "_handle_cancel",
            ControlMessageType.SHUTDOWN: "_handle_shutdown",
            ControlMessageType.FORCE_SHUTDOWN: "_handle_force_shutdown",
        }
        try:
            return handlers[self]
        except KeyError:
            # Same exception type as before, so existing callers still catch it.
            raise KeyError(
                f"No handler method registered for control message type {self.value!r}"
            ) from None

    def dispatch(self, server, message):
        """Call this type's handler method on *server* and return its result.

        Args:
            server: object expected to define the method named by
                :meth:`get_handler_method`.
            message: passed through unchanged to the handler.

        Raises:
            KeyError: if this type has no handler (see get_handler_method).
            AttributeError: if *server* does not define the handler method.
        """
        return getattr(server, self.get_handler_method())(message)
class ResponseType(Enum):
    """Tags used in server replies.

    ``ExecuteResponse.status`` holds one of these, and ``PongResponse``
    writes ``PONG`` as its message type.
    """

    PONG = "pong"
    ACCEPTED = "accepted"
    OK = "ok"
    ERROR = "error"
    SHUTDOWN_ACK = "shutdown_ack"
class ExecutionStatus(Enum):
    """Lifecycle states of a pipeline execution.

    NOTE(review): both COMPLETE ("complete") and COMPLETED ("completed") are
    defined as distinct members; presumably both spellings are in use by
    different producers — confirm before consolidating.
    """

    QUEUED = "queued"
    RUNNING = "running"
    COMPLETE = "complete"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    ACCEPTED = "accepted"
class SocketType(Enum):
    """ZMQ socket kinds; each member's value is the matching ``zmq`` constant name."""

    PUB = "PUB"
    SUB = "SUB"
    REQ = "REQ"
    REP = "REP"

    @classmethod
    def from_zmq_constant(cls, zmq_const):
        """Translate a raw ``zmq`` socket-type constant into a member.

        Values that match none of PUB/SUB/REQ/REP fall back to ``PUB``.
        """
        import zmq

        # Member values double as attribute names on the zmq module.
        by_constant = {getattr(zmq, member.value): member for member in cls}
        return by_constant.get(zmq_const, cls.PUB)

    def get_display_name(self):
        """Human-readable label: the member's value string."""
        label = self.value
        return label
@dataclass(frozen=True)
class ExecuteRequest:
    """Request asking a server to execute a pipeline for a plate.

    The constructor does not enforce completeness; call :meth:`validate`.
    At least one of ``config_params`` / ``config_code`` must be supplied.
    """

    plate_id: str
    pipeline_code: str
    config_params: Optional[dict] = None
    config_code: Optional[str] = None
    pipeline_config_code: Optional[str] = None
    client_address: Optional[str] = None

    def validate(self):
        """Return an error message describing what is missing, or None if valid.

        Empty strings count as missing for ``plate_id`` and ``pipeline_code``.
        For the config, only ``None`` counts as missing, so an empty
        ``config_params`` dict is accepted.
        """
        if not self.plate_id:
            return "Missing required field: plate_id"
        if not self.pipeline_code:
            return "Missing required field: pipeline_code"
        if self.config_params is None and self.config_code is None:
            return "Missing config: provide either config_params or config_code"
        return None

    def to_dict(self):
        """Serialize to a message dict; optional fields are omitted when None."""
        result = {
            MessageFields.TYPE: ControlMessageType.EXECUTE.value,
            MessageFields.PLATE_ID: self.plate_id,
            MessageFields.PIPELINE_CODE: self.pipeline_code,
        }
        if self.config_params is not None:
            result[MessageFields.CONFIG_PARAMS] = self.config_params
        if self.config_code is not None:
            result[MessageFields.CONFIG_CODE] = self.config_code
        if self.pipeline_config_code is not None:
            result[MessageFields.PIPELINE_CONFIG_CODE] = self.pipeline_config_code
        if self.client_address is not None:
            result[MessageFields.CLIENT_ADDRESS] = self.client_address
        return result

    @classmethod
    def from_dict(cls, data):
        """Build a request from a decoded message dict.

        Optional fields that are absent become None.

        Raises:
            KeyError: if ``plate_id`` or ``pipeline_code`` is absent. This is
                raised here, before :meth:`validate` gets a chance to report it.
        """
        return cls(
            plate_id=data[MessageFields.PLATE_ID],
            pipeline_code=data[MessageFields.PIPELINE_CODE],
            config_params=data.get(MessageFields.CONFIG_PARAMS),
            config_code=data.get(MessageFields.CONFIG_CODE),
            pipeline_config_code=data.get(MessageFields.PIPELINE_CONFIG_CODE),
            client_address=data.get(MessageFields.CLIENT_ADDRESS),
        )
@dataclass(frozen=True)
class ExecuteResponse:
    """Reply to an execute request.

    Only ``status`` is required. Optional fields are omitted from the
    :meth:`to_dict` output when they are None.
    """

    status: ResponseType
    execution_id: Optional[str] = None
    message: Optional[str] = None
    error: Optional[str] = None

    def to_dict(self):
        """Serialize to a message dict; ``status`` is emitted as its enum value."""
        result = {MessageFields.STATUS: self.status.value}
        if self.execution_id is not None:
            result[MessageFields.EXECUTION_ID] = self.execution_id
        if self.message is not None:
            result[MessageFields.MESSAGE] = self.message
        if self.error is not None:
            result[MessageFields.ERROR] = self.error
        return result
@dataclass(frozen=True)
class StatusRequest:
    """Status request; ``execution_id`` is optional and omitted from the payload when None."""

    execution_id: Optional[str] = None

    def to_dict(self):
        """Serialize to a message dict tagged with the STATUS control type."""
        result = {MessageFields.TYPE: ControlMessageType.STATUS.value}
        if self.execution_id is not None:
            result[MessageFields.EXECUTION_ID] = self.execution_id
        return result

    @classmethod
    def from_dict(cls, data):
        """Build from a message dict; a missing ``execution_id`` becomes None."""
        return cls(execution_id=data.get(MessageFields.EXECUTION_ID))
@dataclass(frozen=True)
class CancelRequest:
    """Request to cancel the execution identified by ``execution_id``."""

    execution_id: str

    def validate(self):
        """Return an error string when ``execution_id`` is falsy, otherwise None."""
        if self.execution_id:
            return None
        return "Missing execution_id"

    def to_dict(self):
        """Serialize to a message dict tagged with the CANCEL control type."""
        message_type = ControlMessageType.CANCEL.value
        return {
            MessageFields.TYPE: message_type,
            MessageFields.EXECUTION_ID: self.execution_id,
        }

    @classmethod
    def from_dict(cls, data):
        """Build from a message dict; raises KeyError if ``execution_id`` is absent."""
        execution_id = data[MessageFields.EXECUTION_ID]
        return cls(execution_id=execution_id)
@dataclass(frozen=True)
class PongResponse:
    """Reply to a ping.

    Carries the server's ports and readiness plus optional diagnostic fields.
    """

    port: int
    control_port: int
    ready: bool
    server: str
    log_file_path: Optional[str] = None
    active_executions: Optional[int] = None
    running_executions: Optional[list] = None
    workers: Optional[list] = None
    uptime: Optional[float] = None

    def to_dict(self):
        """Serialize to a message dict tagged ``ResponseType.PONG``.

        The required fields are always present. Optional fields are included
        only when they are not None.
        """
        result = {
            MessageFields.TYPE: ResponseType.PONG.value,
            MessageFields.PORT: self.port,
            MessageFields.CONTROL_PORT: self.control_port,
            MessageFields.READY: self.ready,
            MessageFields.SERVER: self.server,
        }
        optional_fields = (
            (MessageFields.LOG_FILE_PATH, self.log_file_path),
            (MessageFields.ACTIVE_EXECUTIONS, self.active_executions),
            (MessageFields.RUNNING_EXECUTIONS, self.running_executions),
            (MessageFields.WORKERS, self.workers),
            (MessageFields.UPTIME, self.uptime),
        )
        for key, value in optional_fields:
            if value is not None:
                result[key] = value
        return result
@dataclass(frozen=True)
class ProgressUpdate:
    """Per-well progress event, serialized with the type tag ``"progress"``."""

    well_id: str
    step: str
    status: str
    timestamp: float

    def to_dict(self):
        """Return the message dict for this progress event."""
        payload = {MessageFields.TYPE: "progress"}
        payload[MessageFields.WELL_ID] = self.well_id
        payload[MessageFields.STEP] = self.step
        payload[MessageFields.STATUS] = self.status
        payload[MessageFields.TIMESTAMP] = self.timestamp
        return payload
@dataclass(frozen=True)
class ImageAck:
    """Acknowledgment message sent by viewers after processing an image.

    Sent via PUSH socket from viewer to shared ack port (7555).
    Used to track real-time queue depth and show progress like '3/10 images processed'.
    """

    image_id: str  # UUID of the processed image
    viewer_port: int  # Port of the viewer that processed it (for routing)
    viewer_type: str  # 'napari' or 'fiji'
    status: str = 'success'  # 'success', 'error', etc.
    timestamp: Optional[float] = None  # When it was processed
    error: Optional[str] = None  # Error message if status='error'

    def to_dict(self):
        """Serialize with type tag ``"image_ack"``.

        ``timestamp`` and ``error`` are omitted when None.
        """
        result = {
            MessageFields.TYPE: "image_ack",
            MessageFields.IMAGE_ID: self.image_id,
            MessageFields.VIEWER_PORT: self.viewer_port,
            MessageFields.VIEWER_TYPE: self.viewer_type,
            MessageFields.STATUS: self.status,
        }
        if self.timestamp is not None:
            result[MessageFields.TIMESTAMP] = self.timestamp
        if self.error is not None:
            result[MessageFields.ERROR] = self.error
        return result

    @classmethod
    def from_dict(cls, data):
        """Rebuild an ack from a message dict.

        ``status`` defaults to ``'success'``; ``timestamp`` and ``error``
        default to None.

        Raises:
            KeyError: if ``image_id``, ``viewer_port`` or ``viewer_type`` is absent.
        """
        return cls(
            image_id=data[MessageFields.IMAGE_ID],
            viewer_port=data[MessageFields.VIEWER_PORT],
            viewer_type=data[MessageFields.VIEWER_TYPE],
            status=data.get(MessageFields.STATUS, 'success'),
            timestamp=data.get(MessageFields.TIMESTAMP),
            error=data.get(MessageFields.ERROR),
        )
@dataclass(frozen=True)
class ROIMessage:
    """ROI payload streamed over ZMQ to viewer servers (Napari/Fiji) for live display."""

    rois: list  # ROI dictionaries (shapes plus metadata)
    layer_name: str = "ROIs"  # Target layer/overlay name in the viewer

    def to_dict(self):
        """Serialize with type tag ``"rois"``."""
        payload = {MessageFields.TYPE: "rois"}
        payload[MessageFields.ROIS] = self.rois
        payload[MessageFields.LAYER_NAME] = self.layer_name
        return payload

    @classmethod
    def from_dict(cls, data):
        """Rebuild from a message dict; ``layer_name`` falls back to ``"ROIs"``.

        Raises:
            KeyError: if the ``rois`` key is missing.
        """
        roi_list = data[MessageFields.ROIS]
        layer = data.get(MessageFields.LAYER_NAME, "ROIs")
        return cls(rois=roi_list, layer_name=layer)
@dataclass(frozen=True)
class ShapesMessage:
    """Napari-specific shapes-layer payload (polygons, ellipses, ...)."""

    shapes: list  # Shape dictionaries: type, coordinates, metadata
    layer_name: str = "ROIs"

    def to_dict(self):
        """Serialize with type tag ``"shapes"``."""
        payload = {MessageFields.TYPE: "shapes"}
        payload[MessageFields.SHAPES] = self.shapes
        payload[MessageFields.LAYER_NAME] = self.layer_name
        return payload

    @classmethod
    def from_dict(cls, data):
        """Rebuild from a message dict; ``layer_name`` falls back to ``"ROIs"``.

        Raises:
            KeyError: if the ``shapes`` key is missing.
        """
        shape_list = data[MessageFields.SHAPES]
        layer = data.get(MessageFields.LAYER_NAME, "ROIs")
        return cls(shapes=shape_list, layer_name=layer)