Coverage for openhcs/runtime/zmq_messages.py: 60.4%

214 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1"""ZMQ Message Type System - enum dispatch and structured messages.""" 

2 

import logging
from enum import Enum
from dataclasses import dataclass


# Module-level logger. Not referenced by any definition in this file; kept so
# the module follows the project's per-module logger convention.
logger = logging.getLogger(__name__)

8 

9 

class MessageFields:
    """String keys used inside ZMQ message dicts.

    The message classes in this module build and parse their dicts with these
    constants rather than with raw string literals.
    """

    # Keys used by more than one message class
    TYPE = "type"
    STATUS = "status"
    ERROR = "error"
    TIMESTAMP = "timestamp"
    EXECUTION_ID = "execution_id"

    # ExecuteRequest payload
    PLATE_ID = "plate_id"
    PIPELINE_CODE = "pipeline_code"
    CONFIG_PARAMS = "config_params"
    CONFIG_CODE = "config_code"
    PIPELINE_CONFIG_CODE = "pipeline_config_code"
    CLIENT_ADDRESS = "client_address"

    # ExecuteResponse payload
    MESSAGE = "message"

    # PongResponse payload
    PORT = "port"
    CONTROL_PORT = "control_port"
    READY = "ready"
    SERVER = "server"
    LOG_FILE_PATH = "log_file_path"
    ACTIVE_EXECUTIONS = "active_executions"
    RUNNING_EXECUTIONS = "running_executions"
    WORKERS = "workers"
    UPTIME = "uptime"

    # ProgressUpdate payload (also uses STATUS and TIMESTAMP)
    WELL_ID = "well_id"
    STEP = "step"

    # Keys not referenced by the message classes in this module
    START_TIME = "start_time"
    END_TIME = "end_time"
    ELAPSED = "elapsed"
    WORKERS_KILLED = "workers_killed"
    EXECUTIONS = "executions"
    WELL_COUNT = "well_count"
    WELLS = "wells"
    RESULTS_SUMMARY = "results_summary"

    # Acknowledgment message fields (ImageAck)
    IMAGE_ID = "image_id"
    VIEWER_PORT = "viewer_port"
    VIEWER_TYPE = "viewer_type"

    # ROI message fields (ROIMessage / ShapesMessage); COORDINATES and
    # METADATA are not referenced by the classes in this module.
    ROIS = "rois"
    LAYER_NAME = "layer_name"
    SHAPES = "shapes"
    COORDINATES = "coordinates"
    METADATA = "metadata"

52 

53 

class ControlMessageType(Enum):
    """Control-channel request types.

    Each member's value is the wire string stored under the message's
    ``type`` key. Every member except ``PING`` maps to a ``_handle_*`` method
    on the server object passed to :meth:`dispatch`.
    """

    PING = "ping"
    EXECUTE = "execute"
    STATUS = "status"
    CANCEL = "cancel"
    SHUTDOWN = "shutdown"
    FORCE_SHUTDOWN = "force_shutdown"

    def get_handler_method(self):
        """Return the name of the server method that handles this type.

        Raises:
            KeyError: if this type has no handler entry (``PING``). The error
                message names the offending type instead of just echoing the
                enum member.
        """
        handlers = {
            ControlMessageType.EXECUTE: "_handle_execute",
            ControlMessageType.STATUS: "_handle_status",
            ControlMessageType.CANCEL: "_handle_cancel",
            ControlMessageType.SHUTDOWN: "_handle_shutdown",
            ControlMessageType.FORCE_SHUTDOWN: "_handle_force_shutdown",
        }
        try:
            return handlers[self]
        except KeyError:
            # Keep KeyError so existing callers that catch it still work.
            raise KeyError(
                f"No server handler registered for control message type {self.value!r}"
            ) from None

    def dispatch(self, server, message):
        """Call this type's handler on ``server`` with ``message``.

        Returns whatever the handler returns. Raises ``KeyError`` for types
        without a handler, and ``AttributeError`` if ``server`` lacks the
        handler method.
        """
        handler = getattr(server, self.get_handler_method())
        return handler(message)

73 

74 

class ResponseType(Enum):
    """Reply codes sent back to clients.

    ``ExecuteResponse`` serializes its status via ``.value``, and
    ``PongResponse`` uses ``PONG.value`` as its message ``type``.
    """

    PONG = "pong"
    ACCEPTED = "accepted"
    OK = "ok"
    ERROR = "error"
    SHUTDOWN_ACK = "shutdown_ack"

81 

82 

class ExecutionStatus(Enum):
    """Execution state strings.

    NOTE(review): ``COMPLETE`` ("complete") and ``COMPLETED`` ("completed")
    have different values, so they are distinct members, not aliases.
    Presumably both spellings appear on the wire; confirm before
    consolidating.
    """

    QUEUED = "queued"
    RUNNING = "running"
    COMPLETE = "complete"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    ACCEPTED = "accepted"

91 

92 

class SocketType(Enum):
    """ZMQ socket kinds, identified by their upper-case names."""

    PUB = "PUB"
    SUB = "SUB"
    REQ = "REQ"
    REP = "REP"

    @classmethod
    def from_zmq_constant(cls, zmq_const):
        """Translate a ``zmq`` socket-type constant into a member.

        Only ``zmq.PUB``/``SUB``/``REQ``/``REP`` have mappings. Any other
        constant falls back to ``PUB``.
        """
        # Imported here, so loading this module does not itself import zmq.
        import zmq

        known = {
            zmq.PUB: cls.PUB,
            zmq.SUB: cls.SUB,
            zmq.REQ: cls.REQ,
            zmq.REP: cls.REP,
        }
        try:
            return known[zmq_const]
        except KeyError:
            return cls.PUB

    def get_display_name(self):
        """Return the human-readable label, which is simply the member value."""
        label = self.value
        return label

106 

107 

@dataclass(frozen=True)
class ExecuteRequest:
    """Request to run ``pipeline_code`` against the plate ``plate_id``.

    ``validate()`` requires ``plate_id`` and ``pipeline_code`` to be truthy and
    at least one of ``config_params`` / ``config_code`` to be set. Optional
    fields are omitted from ``to_dict()`` output when they are ``None``.
    """

    plate_id: str
    pipeline_code: str
    config_params: dict = None
    config_code: str = None
    pipeline_config_code: str = None
    client_address: str = None

    def validate(self):
        """Return a human-readable error for an incomplete request, else ``None``."""
        if not self.plate_id:
            return "Missing required field: plate_id"
        if not self.pipeline_code:
            return "Missing required field: pipeline_code"
        if self.config_params is None and self.config_code is None:
            return "Missing config: provide either config_params or config_code"
        return None

    def to_dict(self):
        """Serialize to a wire dict tagged with the ``execute`` control type."""
        result = {
            MessageFields.TYPE: ControlMessageType.EXECUTE.value,
            MessageFields.PLATE_ID: self.plate_id,
            MessageFields.PIPELINE_CODE: self.pipeline_code,
        }
        optional = (
            (MessageFields.CONFIG_PARAMS, self.config_params),
            (MessageFields.CONFIG_CODE, self.config_code),
            (MessageFields.PIPELINE_CONFIG_CODE, self.pipeline_config_code),
            (MessageFields.CLIENT_ADDRESS, self.client_address),
        )
        result.update((key, value) for key, value in optional if value is not None)
        return result

    @classmethod
    def from_dict(cls, data):
        """Build a request from a decoded wire dict.

        Absent keys become ``None`` instead of raising ``KeyError``. This makes
        a missing ``plate_id`` or ``pipeline_code`` reach ``validate()``, which
        reports the specific missing field. Previously those checks only fired
        for empty values.
        """
        return cls(
            plate_id=data.get(MessageFields.PLATE_ID),
            pipeline_code=data.get(MessageFields.PIPELINE_CODE),
            config_params=data.get(MessageFields.CONFIG_PARAMS),
            config_code=data.get(MessageFields.CONFIG_CODE),
            pipeline_config_code=data.get(MessageFields.PIPELINE_CONFIG_CODE),
            client_address=data.get(MessageFields.CLIENT_ADDRESS),
        )

143 

144 

@dataclass(frozen=True)
class ExecuteResponse:
    """Server reply to an execute request.

    ``status`` is serialized through its ``.value``. ``execution_id``,
    ``message`` and ``error`` appear in the dict only when not ``None``.
    """

    status: ResponseType
    execution_id: str = None
    message: str = None
    error: str = None

    def to_dict(self):
        """Serialize to a wire dict keyed by ``MessageFields`` names."""
        payload = {MessageFields.STATUS: self.status.value}
        for key, value in (
            (MessageFields.EXECUTION_ID, self.execution_id),
            (MessageFields.MESSAGE, self.message),
            (MessageFields.ERROR, self.error),
        ):
            if value is None:
                continue
            payload[key] = value
        return payload

161 

162 

@dataclass(frozen=True)
class StatusRequest:
    """Status query; ``execution_id`` narrows it to one execution when given."""

    execution_id: str = None

    def to_dict(self):
        """Serialize to a wire dict tagged with the ``status`` control type."""
        message = {MessageFields.TYPE: ControlMessageType.STATUS.value}
        if self.execution_id is None:
            return message
        message[MessageFields.EXECUTION_ID] = self.execution_id
        return message

    @classmethod
    def from_dict(cls, data):
        """Build from a wire dict; a missing ``execution_id`` becomes ``None``."""
        wanted = data.get(MessageFields.EXECUTION_ID)
        return cls(execution_id=wanted)

176 

177 

@dataclass(frozen=True)
class CancelRequest:
    """Request to cancel the execution identified by ``execution_id``."""

    execution_id: str

    def validate(self):
        """Return ``"Missing execution_id"`` when no id is set, else ``None``."""
        return "Missing execution_id" if not self.execution_id else None

    def to_dict(self):
        """Serialize to a wire dict tagged with the ``cancel`` control type."""
        return {
            MessageFields.TYPE: ControlMessageType.CANCEL.value,
            MessageFields.EXECUTION_ID: self.execution_id,
        }

    @classmethod
    def from_dict(cls, data):
        """Build a request from a decoded wire dict.

        A missing ``execution_id`` key yields ``execution_id=None`` instead of
        raising ``KeyError``. ``validate()`` then reports it the same way it
        reports an empty id.
        """
        return cls(execution_id=data.get(MessageFields.EXECUTION_ID))

191 

192 

@dataclass(frozen=True)
class PongResponse:
    """Reply to a ping describing the server.

    ``port``, ``control_port``, ``ready`` and ``server`` are always serialized.
    The remaining fields are included only when not ``None``.
    """

    port: int
    control_port: int
    ready: bool
    server: str
    log_file_path: str = None
    active_executions: int = None
    running_executions: list = None
    workers: list = None
    uptime: float = None

    def to_dict(self):
        """Serialize to a wire dict whose ``type`` is ``ResponseType.PONG.value``."""
        payload = {
            MessageFields.TYPE: ResponseType.PONG.value,
            MessageFields.PORT: self.port,
            MessageFields.CONTROL_PORT: self.control_port,
            MessageFields.READY: self.ready,
            MessageFields.SERVER: self.server,
        }
        optional_fields = (
            (MessageFields.LOG_FILE_PATH, self.log_file_path),
            (MessageFields.ACTIVE_EXECUTIONS, self.active_executions),
            (MessageFields.RUNNING_EXECUTIONS, self.running_executions),
            (MessageFields.WORKERS, self.workers),
            (MessageFields.UPTIME, self.uptime),
        )
        for key, value in optional_fields:
            if value is not None:
                payload[key] = value
        return payload

219 

220 

@dataclass(frozen=True)
class ProgressUpdate:
    """Per-well, per-step progress notification (message type ``"progress"``)."""

    well_id: str
    step: str
    status: str
    timestamp: float

    def to_dict(self):
        """Serialize all four fields plus the ``"progress"`` type tag."""
        fields = MessageFields
        return {
            fields.TYPE: "progress",
            fields.WELL_ID: self.well_id,
            fields.STEP: self.step,
            fields.STATUS: self.status,
            fields.TIMESTAMP: self.timestamp,
        }

231 

232 

@dataclass(frozen=True)
class ImageAck:
    """Acknowledgment a viewer sends after it has processed an image.

    Per the original design notes, viewers send it over a PUSH socket to a
    shared ack port (7555). It is used to track queue depth and show progress
    such as '3/10 images processed'.
    """

    image_id: str  # identifier of the processed image (described as a UUID)
    viewer_port: int  # port of the viewer that handled it (for routing)
    viewer_type: str  # viewer kind, e.g. 'napari' or 'fiji'
    status: str = 'success'  # outcome string; defaults to 'success'
    timestamp: float = None  # when it was processed; omitted from dict if None
    error: str = None  # error text; omitted from dict if None

    def to_dict(self):
        """Serialize to a wire dict tagged ``"image_ack"``.

        ``timestamp`` and ``error`` are included only when set.
        """
        payload = {
            MessageFields.TYPE: "image_ack",
            MessageFields.IMAGE_ID: self.image_id,
            MessageFields.VIEWER_PORT: self.viewer_port,
            MessageFields.VIEWER_TYPE: self.viewer_type,
            MessageFields.STATUS: self.status,
        }
        extras = (
            (MessageFields.TIMESTAMP, self.timestamp),
            (MessageFields.ERROR, self.error),
        )
        payload.update((key, val) for key, val in extras if val is not None)
        return payload

    @classmethod
    def from_dict(cls, data):
        """Rebuild from a wire dict.

        ``image_id``, ``viewer_port`` and ``viewer_type`` are required
        (``KeyError`` if absent). ``status`` defaults to ``'success'``.
        """
        image_id = data[MessageFields.IMAGE_ID]
        viewer_port = data[MessageFields.VIEWER_PORT]
        viewer_type = data[MessageFields.VIEWER_TYPE]
        return cls(
            image_id=image_id,
            viewer_port=viewer_port,
            viewer_type=viewer_type,
            status=data.get(MessageFields.STATUS, 'success'),
            timestamp=data.get(MessageFields.TIMESTAMP),
            error=data.get(MessageFields.ERROR),
        )

271 

@dataclass(frozen=True)
class ROIMessage:
    """Message carrying ROIs to a viewer server (Napari/Fiji) for display.

    ``rois`` is a list of ROI dictionaries. ``layer_name`` names the target
    layer or overlay and defaults to ``"ROIs"``.
    """

    rois: list
    layer_name: str = "ROIs"

    def to_dict(self):
        """Serialize to a wire dict tagged ``"rois"``."""
        payload = {MessageFields.TYPE: "rois"}
        payload[MessageFields.ROIS] = self.rois
        payload[MessageFields.LAYER_NAME] = self.layer_name
        return payload

    @classmethod
    def from_dict(cls, data):
        """Rebuild from a wire dict; ``rois`` is required, layer name defaults to ``"ROIs"``."""
        roi_list = data[MessageFields.ROIS]
        layer = data.get(MessageFields.LAYER_NAME, "ROIs")
        return cls(rois=roi_list, layer_name=layer)

294 

295 

@dataclass(frozen=True)
class ShapesMessage:
    """Napari shapes-layer message (polygon/ellipse shapes).

    ``shapes`` is a list of shape dictionaries. ``layer_name`` defaults to
    ``"ROIs"``.
    """

    shapes: list
    layer_name: str = "ROIs"

    def to_dict(self):
        """Serialize to a wire dict tagged ``"shapes"``."""
        payload = {MessageFields.TYPE: "shapes"}
        payload[MessageFields.SHAPES] = self.shapes
        payload[MessageFields.LAYER_NAME] = self.layer_name
        return payload

    @classmethod
    def from_dict(cls, data):
        """Rebuild from a wire dict; ``shapes`` is required, layer name defaults to ``"ROIs"``."""
        shape_list = data[MessageFields.SHAPES]
        layer = data.get(MessageFields.LAYER_NAME, "ROIs")
        return cls(shapes=shape_list, layer_name=layer)

318 

319