Coverage for ezstitcher/core/opera_phenix_xml_parser.py: 58%

203 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2025-04-30 13:20 +0000

1""" 

2Opera Phenix XML parser for ezstitcher. 

3 

4This module provides a class for parsing Opera Phenix Index.xml files. 

5""" 

6 

7import logging 

8import xml.etree.ElementTree as ET 

9from pathlib import Path 

10from typing import Dict, Optional, Union, Any, Tuple 

11import re 

12import numpy as np 

13 

14logger = logging.getLogger(__name__) 

15 

16 

class OperaPhenixXmlParser:
    """Parser for Opera Phenix Index.xml files.

    The file is parsed once at construction time.  All ``get_*`` accessors
    read from the in-memory element tree (``self.root``) and qualify tag
    names with the XML namespace taken from the root element's tag.
    """

    # Fallback (grid_size_x, grid_size_y) used when the grid cannot be
    # derived from the XML.
    DEFAULT_GRID_SIZE: Tuple[int, int] = (2, 2)

    # Fallback pixel size in micrometers used when no resolution is found.
    DEFAULT_PIXEL_SIZE_UM: float = 0.65

    # Positions are snapped to multiples of this value before being counted
    # as unique, so tiny floating point differences do not create extra
    # grid rows/columns.
    _POSITION_EPSILON: float = 1e-10

    def __init__(self, xml_path: Union[str, Path]):
        """
        Initialize the parser with the path to the Index.xml file.

        The file is parsed immediately; any parsing error is logged and
        re-raised.

        Args:
            xml_path: Path to the Index.xml file
        """
        self.xml_path = Path(xml_path)
        self.tree = None
        self.root = None
        self.namespace = ""
        self._parse_xml()

    def _parse_xml(self):
        """Parse the XML file and extract the namespace.

        Sets ``self.tree``, ``self.root`` and ``self.namespace``.  The
        namespace keeps its surrounding braces (``"{uri}"``) so it can be
        prepended directly to tag names; it is ``""`` when the root tag is
        not namespace-qualified.

        Raises:
            Exception: whatever ``ET.parse`` raised, after logging it.
        """
        try:
            self.tree = ET.parse(self.xml_path)
            self.root = self.tree.getroot()

            # ElementTree spells qualified tags as "{namespace-uri}local".
            match = re.match(r'{.*}', self.root.tag)
            self.namespace = match.group(0) if match else ""

            logger.info("Parsed Opera Phenix XML file: %s", self.xml_path)
            logger.debug("XML namespace: %s", self.namespace)
        except Exception as e:
            logger.error("Error parsing Opera Phenix XML file %s: %s", self.xml_path, e)
            raise

    def get_plate_info(self) -> Dict[str, Any]:
        """
        Extract plate information from the XML.

        Only the first ``Plate`` element found anywhere in the document is
        used.

        Returns:
            Dict with the keys ``plate_id``, ``measurement_id``,
            ``plate_type``, ``rows``, ``columns`` and ``wells`` (the ``id``
            attributes of the plate's direct ``Well`` children).  Missing or
            empty row/column counts become 0.  An empty dict is returned
            when the XML is not loaded or has no ``Plate`` element.

        Raises:
            ValueError: if ``PlateRows`` or ``PlateColumns`` holds
                non-numeric text.
        """
        if self.root is None:
            return {}

        plate_elem = self.root.find(f".//{self.namespace}Plate")
        if plate_elem is None:
            logger.warning("No Plate element found in XML")
            return {}

        plate_info = {
            'plate_id': self._get_element_text(plate_elem, 'PlateID'),
            'measurement_id': self._get_element_text(plate_elem, 'MeasurementID'),
            'plate_type': self._get_element_text(plate_elem, 'PlateTypeName'),
            'rows': int(self._get_element_text(plate_elem, 'PlateRows') or 0),
            'columns': int(self._get_element_text(plate_elem, 'PlateColumns') or 0),
        }

        # Wells are referenced from the plate by their "id" attribute.
        well_elems = plate_elem.findall(f"{self.namespace}Well")
        plate_info['wells'] = [well.get('id') for well in well_elems if well.get('id')]

        logger.debug("Plate info: %s", plate_info)
        return plate_info

    def get_grid_size(self) -> Tuple[int, int]:
        """
        Determine the grid size (number of fields per well) by analyzing image positions.

        Images are grouped by well (Row + Col), channel and plane.  The first
        group (in document order) that contains more than one image is used:
        the grid width is the number of distinct PositionX values in that
        group and the grid height is the number of distinct PositionY values.

        Images lacking any of Row, Col, ChannelID, PlaneID, PositionX,
        PositionY or FieldID are ignored, as are images whose position or
        field ID cannot be parsed as a number.

        Returns:
            Tuple of (grid_size_x, grid_size_y); ``DEFAULT_GRID_SIZE`` when
            the XML is not loaded, has no images, or no group has more than
            one usable image.
        """
        if self.root is None:
            logger.error("XML not parsed, cannot determine grid size")
            return self.DEFAULT_GRID_SIZE

        image_elements = self.root.findall(f".//{self.namespace}Image")
        if not image_elements:
            logger.warning("No Image elements found in XML")
            return self.DEFAULT_GRID_SIZE

        # group key -> list of (x, y) positions, in document order
        image_groups: Dict[str, list] = {}

        for image in image_elements:
            row = self._get_element_text(image, 'Row')
            col = self._get_element_text(image, 'Col')
            channel = self._get_element_text(image, 'ChannelID')
            plane = self._get_element_text(image, 'PlaneID')
            if not (row and col and channel and plane):
                continue

            group_key = f"R{row}C{col}_CH{channel}_P{plane}"

            pos_x = self._get_element_text(image, 'PositionX')
            pos_y = self._get_element_text(image, 'PositionY')
            field = self._get_element_text(image, 'FieldID')
            if not (pos_x and pos_y and field):
                continue

            try:
                x_value = float(pos_x)
                y_value = float(pos_y)
                # The field ID is not used for the grid itself, but an image
                # with an unparsable field ID is treated as invalid.
                int(field)
            except (ValueError, TypeError):
                logger.warning("Could not parse position values for image in group %s", group_key)
                continue

            image_groups.setdefault(group_key, []).append((x_value, y_value))

        epsilon = self._POSITION_EPSILON
        for group_key, group_positions in image_groups.items():
            if len(group_positions) < 2:
                continue

            logger.debug("Using image group %s with %d fields to determine grid size",
                         group_key, len(group_positions))

            x_positions = np.array([x for x, _ in group_positions])
            y_positions = np.array([y for _, y in group_positions])

            # Snap to the epsilon grid before counting distinct values.
            num_x_positions = len(np.unique(np.round(x_positions / epsilon) * epsilon))
            num_y_positions = len(np.unique(np.round(y_positions / epsilon) * epsilon))

            # A group with at least two positions always yields at least one
            # unique value per axis, so the result can be returned directly.
            logger.info("Determined grid size from positions: %dx%d",
                        num_x_positions, num_y_positions)
            return (num_x_positions, num_y_positions)

        logger.warning("Could not determine grid size from XML, using default 2x2")
        return self.DEFAULT_GRID_SIZE

    def get_pixel_size(self) -> float:
        """
        Extract pixel size from the XML.

        The first ``ImageResolutionX`` element in the document is tried
        first, then the first ``ImageResolutionY``.  The value is multiplied
        by 1e6, i.e. it is assumed to be expressed in meters; the element's
        ``Unit`` attribute is not checked.

        Returns:
            Pixel size in micrometers (μm); ``DEFAULT_PIXEL_SIZE_UM`` when
            the XML is not loaded or no parsable resolution is found.
        """
        if self.root is None:
            logger.warning("XML not parsed, using default pixel size")
            return self.DEFAULT_PIXEL_SIZE_UM

        for tag in ("ImageResolutionX", "ImageResolutionY"):
            resolution = self.root.find(f".//{self.namespace}{tag}")
            if resolution is None or not resolution.text:
                continue
            try:
                # Convert from meters to micrometers
                pixel_size = float(resolution.text) * 1e6
            except (ValueError, TypeError):
                logger.warning("Could not parse pixel size from %s", tag)
                continue
            logger.info("Found pixel size from %s: %.4f μm", tag, pixel_size)
            return pixel_size

        logger.warning("Pixel size not found in XML, using default value of 0.65 μm")
        return self.DEFAULT_PIXEL_SIZE_UM

    def get_image_info(self) -> Dict[str, Dict[str, Any]]:
        """
        Extract image information from the XML.

        Only ``Image`` elements carrying a ``Version`` attribute are
        considered, and only those with a non-empty ``id`` child element.

        Returns:
            Dictionary mapping image IDs to dictionaries with the keys
            ``url``, ``row``, ``col``, ``field_id``, ``plane_id``,
            ``channel_id`` (integers, 0 when missing or empty) and
            ``position_x``, ``position_y``, ``position_z`` (raw element text,
            ``None`` when missing).  Empty when the XML is not loaded or no
            matching image exists.

        Raises:
            ValueError: if one of the integer fields holds non-numeric text.
        """
        if self.root is None:
            return {}

        image_elems = self.root.findall(f".//{self.namespace}Image[@Version]")
        if not image_elems:
            logger.warning("No Image elements with Version attribute found in XML")
            return {}

        image_info = {}
        for image in image_elems:
            image_id = self._get_element_text(image, 'id')
            if not image_id:
                continue
            image_info[image_id] = {
                'url': self._get_element_text(image, 'URL'),
                'row': int(self._get_element_text(image, 'Row') or 0),
                'col': int(self._get_element_text(image, 'Col') or 0),
                'field_id': int(self._get_element_text(image, 'FieldID') or 0),
                'plane_id': int(self._get_element_text(image, 'PlaneID') or 0),
                'channel_id': int(self._get_element_text(image, 'ChannelID') or 0),
                'position_x': self._get_element_text(image, 'PositionX'),
                'position_y': self._get_element_text(image, 'PositionY'),
                'position_z': self._get_element_text(image, 'PositionZ'),
            }

        logger.debug("Found %d images in XML", len(image_info))
        return image_info

    def get_well_positions(self) -> Dict[str, Tuple[int, int]]:
        """
        Extract well positions from the XML.

        Reads the ``Well`` elements that are direct children of a ``Wells``
        element.  Wells missing a non-empty ``id``, ``Row`` or ``Col`` child
        are skipped.

        Returns:
            Dictionary mapping well IDs to (row, column) tuples; empty when
            the XML is not loaded or contains no such wells.

        Raises:
            ValueError: if ``Row`` or ``Col`` holds non-numeric text.
        """
        if self.root is None:
            return {}

        well_elems = self.root.findall(f".//{self.namespace}Wells/{self.namespace}Well")
        if not well_elems:
            logger.warning("No Well elements found in XML")
            return {}

        well_positions = {}
        for well in well_elems:
            well_id = self._get_element_text(well, 'id')
            row = self._get_element_text(well, 'Row')
            col = self._get_element_text(well, 'Col')

            if well_id and row and col:
                well_positions[well_id] = (int(row), int(col))

        logger.debug("Well positions: %s", well_positions)
        return well_positions

    def _get_element_text(self, parent_elem, tag_name: str) -> Optional[str]:
        """Return the text of the direct child ``tag_name`` of ``parent_elem``.

        The tag name is qualified with the parser's namespace.  Returns
        ``None`` when the child does not exist (or has no text).
        """
        elem = parent_elem.find(f"{self.namespace}{tag_name}")
        return elem.text if elem is not None else None

    def _get_element_attribute(self, parent_elem, tag_name: str, attr_name: str) -> Optional[str]:
        """Return attribute ``attr_name`` of the direct child ``tag_name``.

        The tag name is qualified with the parser's namespace.  Returns
        ``None`` when the child or the attribute does not exist.
        """
        elem = parent_elem.find(f"{self.namespace}{tag_name}")
        return elem.get(attr_name) if elem is not None else None

    def get_field_positions(self) -> Dict[int, Tuple[float, float]]:
        """
        Extract field IDs and their X,Y positions from the Index.xml file.

        Every ``Image`` element with ``FieldID``, ``PositionX`` and
        ``PositionY`` children is inspected.  The first position seen for a
        field ID wins; entries whose values cannot be parsed are skipped.

        Returns:
            dict: Mapping of field IDs to (x, y) position tuples; empty when
            the XML is not loaded.
        """
        field_positions: Dict[int, Tuple[float, float]] = {}

        # Same guard as the other accessors: no tree, no positions.
        if self.root is None:
            return field_positions

        for image in self.root.findall(f".//{self.namespace}Image"):
            field_id_elem = image.find(f"{self.namespace}FieldID")
            pos_x_elem = image.find(f"{self.namespace}PositionX")
            pos_y_elem = image.find(f"{self.namespace}PositionY")

            if field_id_elem is None or pos_x_elem is None or pos_y_elem is None:
                continue

            try:
                field_id = int(field_id_elem.text)
                pos_x = float(pos_x_elem.text)
                pos_y = float(pos_y_elem.text)
            except (ValueError, TypeError):
                # Skip entries with invalid (or empty) data
                continue

            # Only keep the first position recorded for a field ID
            field_positions.setdefault(field_id, (pos_x, pos_y))

        return field_positions

    def sort_fields_by_position(self, positions: Dict[int, Tuple[float, float]]) -> list:
        """
        Sort fields based on their positions in a raster pattern starting from the top.
        All rows go left-to-right in a consistent raster scan pattern.

        The largest Y value is treated as the top row.  Coordinates are
        matched by exact equality (no tolerance is applied).  If several
        fields share the same (x, y) position, only the one that comes last
        in ``positions`` is kept.  The resulting grid is also written to the
        log at INFO level, with 0 marking empty cells.

        Args:
            positions: Dictionary mapping field IDs to (x, y) position tuples

        Returns:
            list: Field IDs sorted in raster pattern order starting from the top
        """
        if not positions:
            return []

        # Distinct coordinates: columns ascending, rows descending (top first)
        x_coords = sorted({x for x, _ in positions.values()})
        y_coords = sorted({y for _, y in positions.values()}, reverse=True)

        # Coordinate -> grid index lookups (avoids a linear search per field)
        x_index = {x: idx for idx, x in enumerate(x_coords)}
        y_index = {y: idx for idx, y in enumerate(y_coords)}

        grid = {}
        for field_id, (x, y) in positions.items():
            grid[(x_index[x], y_index[y])] = field_id

        n_cols = len(x_coords)
        n_rows = len(y_coords)

        # Log the grid to help diagnose field mapping issues
        logger.info("Field position grid:")
        for y_idx in range(n_rows):
            row_str = "".join(
                f"{grid.get((x_idx, y_idx), 0):3d} " for x_idx in range(n_cols)
            )
            logger.info(row_str)

        # Raster order: top row first, every row left to right
        return [
            grid[(x_idx, y_idx)]
            for y_idx in range(n_rows)
            for x_idx in range(n_cols)
            if (x_idx, y_idx) in grid
        ]

    def get_field_id_mapping(self) -> Dict[int, int]:
        """
        Generate a mapping from original field IDs to new field IDs based on position data.

        New IDs start at 1 and follow the raster order produced by
        ``sort_fields_by_position``.

        Returns:
            dict: Mapping of original field IDs to new field IDs
        """
        field_positions = self.get_field_positions()
        sorted_field_ids = self.sort_fields_by_position(field_positions)
        return {field_id: i + 1 for i, field_id in enumerate(sorted_field_ids)}

    def remap_field_id(self, field_id: int, mapping: Optional[Dict[int, int]] = None) -> int:
        """
        Remap a field ID using the position-based mapping.

        Args:
            field_id: Original field ID
            mapping: Mapping to use. If None, generates a new mapping by
                re-reading the positions from the XML tree on every call, so
                callers remapping many IDs should pass a mapping obtained
                once from ``get_field_id_mapping``.

        Returns:
            int: New field ID, or original if not in mapping
        """
        if mapping is None:
            mapping = self.get_field_id_mapping()

        return mapping.get(field_id, field_id)