Coverage for ezstitcher/core/stitcher.py: 85%
174 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-04-30 13:20 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-04-30 13:20 +0000
1"""
2Stitcher module for ezstitcher.
4This module contains the Stitcher class for handling image stitching operations.
5"""
7import re
8import os
9import logging
10from pathlib import Path
11from typing import List, Optional, Union
12from scipy.ndimage import shift as subpixel_shift
14import numpy as np
15import pandas as pd
16from ashlar import fileseries, reg
18from ezstitcher.core.config import StitcherConfig
19from ezstitcher.core.file_system_manager import FileSystemManager
20from ezstitcher.core.image_processor import create_linear_weight_mask
21from ezstitcher.core.microscope_interfaces import FilenameParser
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class Stitcher:
    """
    Class for handling image stitching operations.

    Tile positions are computed with Ashlar and written to a ``;``-separated
    text file; :meth:`assemble_image` reads that file back and blends the
    tiles onto a single canvas using a linear weight mask and sub-pixel shifts.
    """

    def __init__(self, config: Optional[StitcherConfig] = None, filename_parser: Optional[FilenameParser] = None):
        """
        Initialize the Stitcher.

        Args:
            config (StitcherConfig): Configuration for stitching; a default
                ``StitcherConfig()`` is used when omitted (or falsy).
            filename_parser (FilenameParser): Parser for microscopy filenames.
                The position-generation methods call its
                ``path_list_from_pattern``, so it must be provided before
                using them.
        """
        self.config = config or StitcherConfig()
        self.fs_manager = FileSystemManager()
        self.filename_parser = filename_parser

    def generate_positions_df(self, image_dir, image_pattern, positions, grid_size_x, grid_size_y):
        """
        Build a positions DataFrame for the tiles matching a filename pattern.

        Files are taken in the order returned by
        ``filename_parser.path_list_from_pattern`` and paired, index by index,
        with ``positions`` and with raster grid cells (column index varying
        fastest). When written by :meth:`save_positions_df` each row reads::

            file: <filename>; grid: (col, row); position: (x, y)

        Args:
            image_dir (str or Path): Directory containing the tiles.
            image_pattern (str): Filename pattern with an ``{iii}``-style placeholder.
            positions (sequence of (x, y)): One position per file, in file order.
            grid_size_x (int): Number of tiles horizontally.
            grid_size_y (int): Number of tiles vertically.

        Returns:
            pandas.DataFrame: Columns ``file``, ``grid`` and ``position``, each
            holding a pre-formatted text fragment.

        Raises:
            ValueError: If the number of files differs from the number of
                positions or from ``grid_size_x * grid_size_y``.
        """
        all_files = self.filename_parser.path_list_from_pattern(image_dir, image_pattern)
        if len(all_files) != len(positions):
            raise ValueError(f"File/position count mismatch: {len(all_files)}≠{len(positions)}")

        if grid_size_x * grid_size_y != len(all_files):
            raise ValueError(f"Grid size mismatch: {grid_size_x}×{grid_size_y}≠{len(all_files)}")

        # Raster order: the column index varies fastest within each row.
        grid_cells = [(col, row) for row in range(grid_size_y) for col in range(grid_size_x)]

        # The checks above guarantee all three sequences have equal length.
        # `!s` keeps the exact str() rendering (e.g. of numpy scalars) and also
        # accepts Path-like filenames.
        data_rows = [
            {
                "file": f"file: {fname!s}",
                "grid": f" grid: ({col}, {row})",
                "position": f" position: ({x!s}, {y!s})",
            }
            for fname, (x, y), (col, row) in zip(all_files, positions, grid_cells)
        ]
        return pd.DataFrame(data_rows)

    def generate_positions(self, image_dir: Union[str, Path],
                           image_pattern: str,
                           positions_path: Union[str, Path],
                           grid_size_x: int,
                           grid_size_y: int) -> bool:
        """
        Generate positions for stitching using Ashlar.

        Public entry point; delegates to :meth:`_generate_positions_ashlar`.

        Args:
            image_dir (str or Path): Directory containing images
            image_pattern (str): Pattern with '{iii}' placeholder
            positions_path (str or Path): Path to save positions CSV
            grid_size_x (int): Number of tiles horizontally
            grid_size_y (int): Number of tiles vertically

        Returns:
            bool: True if successful, False otherwise (errors are logged,
            not raised).
        """
        return self._generate_positions_ashlar(image_dir, image_pattern, positions_path, grid_size_x, grid_size_y)

    def _resolve_tiff_pattern(self, image_dir: Path, image_pattern: str) -> str:
        """Return *image_pattern*, or its ``.tiff`` variant when only that variant matches files."""
        if not image_pattern.endswith('.tif'):
            return image_pattern
        if self.filename_parser.path_list_from_pattern(image_dir, image_pattern):
            return image_pattern
        tiff_pattern = image_pattern[:-4] + '.tiff'
        if self.filename_parser.path_list_from_pattern(image_dir, tiff_pattern):
            return tiff_pattern
        return image_pattern

    def _generate_positions_ashlar(self, image_dir: Union[str, Path],
                                   image_pattern: str,
                                   positions_path: Union[str, Path],
                                   grid_size_x: int,
                                   grid_size_y: int) -> bool:
        """
        Generate positions for stitching using Ashlar.

        Registers the tiles with Ashlar's ``EdgeAligner`` and writes one line
        per tile (see :meth:`generate_positions_df`) to ``positions_path``.

        Args:
            image_dir (str or Path): Directory containing images
            image_pattern (str): Pattern with '{iii}' placeholder
            positions_path (str or Path): Path to save positions CSV
            grid_size_x (int): Number of tiles horizontally
            grid_size_y (int): Number of tiles vertically

        Returns:
            bool: True only if positions were computed AND saved; False
            otherwise (the cause is logged).
        """
        try:
            image_dir = Path(image_dir)
            positions_path = Path(positions_path)

            # Stitching parameters; tile_overlap is configured as a percentage.
            overlap = self.config.tile_overlap / 100.0
            max_shift = self.config.max_shift
            pixel_size = self.config.pixel_size

            # Ashlar's FileSeriesReader expects a "{series}" placeholder.
            ashlar_pattern = image_pattern.replace("{iii}", "{series}")
            logger.info("Using pattern: %s for ashlar", ashlar_pattern)

            # Fall back to ".tiff" when a ".tif" pattern matches no files.
            resolved_pattern = self._resolve_tiff_pattern(image_dir, image_pattern)
            if resolved_pattern != image_pattern:
                image_pattern = resolved_pattern
                ashlar_pattern = image_pattern.replace("{iii}", "{series}")
                logger.info("Updated pattern to: %s for ashlar", ashlar_pattern)

            # Every grid cell must have exactly one file.
            files = self.filename_parser.path_list_from_pattern(image_dir, image_pattern)
            if len(files) != grid_size_x * grid_size_y:
                raise ValueError(f"Grid size mismatch: {grid_size_x}×{grid_size_y}≠{len(files)}")

            fs_reader = fileseries.FileSeriesReader(
                path=str(image_dir),
                pattern=ashlar_pattern,
                overlap=overlap,  # single overlap fraction for both axes
                width=grid_size_x,
                height=grid_size_y,
                layout="raster",
                direction="horizontal",
                pixel_size=pixel_size,
            )

            aligner = reg.EdgeAligner(
                fs_reader,
                channel=0,  # channel used for registration in multi-channel data
                filter_sigma=0,
                verbose=True,
                max_shift=max_shift,
            )
            aligner.run()

            # NOTE: Ashlar's Mosaic does not accept a 'num_workers' argument.
            mosaic = reg.Mosaic(
                aligner,
                aligner.mosaic_shape,
                verbose=True,
                flip_mosaic_y=False,
            )

            # Ashlar stores positions as (y, x); emit them as (x, y).
            positions = [(x, y) for y, x in mosaic.aligner.positions]

            # Map any "{series}" placeholder back to "{iii}" for the filename parser.
            original_pattern = image_pattern.replace("{series}", "{iii}")

            positions_df = self.generate_positions_df(
                str(image_dir), original_pattern, positions, grid_size_x, grid_size_y
            )

            # save_positions_df signals failure through its return value
            # instead of raising, so it must be checked before reporting success.
            if not self.save_positions_df(positions_df, positions_path):
                logger.error("Failed to save positions to %s", positions_path)
                return False

            logger.info("Saved positions to %s", positions_path)
            return True

        except Exception as e:  # bool-returning API boundary: log with traceback
            logger.exception("Error in generate_positions_ashlar: %s", e)
            return False

    @staticmethod
    def parse_positions_csv(csv_path):
        """
        Parse a positions file with lines of the form:
            file: <filename>; grid: (col, row); position: (x, y)

        Blank lines and lines lacking either a ``file:`` or a ``position:``
        field are skipped.

        Args:
            csv_path (str or Path): Path to the CSV file

        Returns:
            list: List of tuples (filename, x_float, y_float)
        """
        entries = []
        with open(csv_path, 'r', encoding='utf-8') as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                # Example line:
                # file: some_image.tif; grid: (0, 0); position: (123.45, 67.89)
                file_match = re.search(r'file:\s*([^;]+);', line)
                pos_match = re.search(r'position:\s*\(([^,]+),\s*([^)]+)\)', line)
                if file_match and pos_match:
                    fname = file_match.group(1).strip()
                    x_val = float(pos_match.group(1).strip())
                    y_val = float(pos_match.group(2).strip())
                    entries.append((fname, x_val, y_val))
        return entries

    @staticmethod
    def save_positions_df(df, positions_path):
        """
        Save a positions DataFrame as ``;``-separated text without header or index.

        Creates the parent directory if needed.

        Args:
            df (pandas.DataFrame): DataFrame to save
            positions_path (str or Path): Path to save the CSV file

        Returns:
            bool: True if successful, False otherwise (the error is logged).
        """
        try:
            Path(positions_path).parent.mkdir(parents=True, exist_ok=True)
            df.to_csv(positions_path, index=False, sep=";", header=False)
            return True
        except Exception as e:  # bool-returning API: log with traceback
            logger.exception("Error saving positions CSV: %s", e)
            return False

    @staticmethod
    def _cast_to_dtype(values: np.ndarray, dtype) -> np.ndarray:
        """Convert a float canvas to *dtype*, rounding (integers) and clipping to the dtype's range."""
        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            # astype() truncates toward zero; round first so float round-off
            # from blending (e.g. 99.99999) maps to 100 rather than 99.
            values = np.rint(values)
        else:
            info = np.finfo(dtype)
        # Clip to the full representable range (not 0) so signed data keeps
        # its negative values.
        return np.clip(values, info.min, info.max).astype(dtype)

    def assemble_image(self, positions_path: Union[str, Path],
                       images_dir: Union[str, Path],
                       output_path: Union[str, Path],
                       override_names: Optional[List[str]] = None) -> bool:
        """
        Assemble a stitched image using subpixel positions from a CSV file.

        Each tile is multiplied by a linear weight mask, shifted by the
        fractional part of its offset, and accumulated at the integer part;
        the result is the weight-normalized average, cast back to the dtype
        of the first tile. Tiles that fail to load or whose shape/dtype differ
        from the first tile are logged and skipped.

        Args:
            positions_path (str or Path): Path to the CSV with subpixel positions
            images_dir (str or Path): Directory containing image tiles
            output_path (str or Path): Path to save final stitched image
            override_names (list): Optional list of filenames to use instead of
                those in the CSV; must have one entry per CSV entry.

        Returns:
            bool: True if successful, False otherwise (the cause is logged).
        """
        try:
            margin_ratio = self.config.margin_ratio

            output_path = Path(output_path)
            output_dir = output_path.parent
            self.fs_manager.ensure_directory(output_dir)
            logger.info("Ensured output directory exists: %s", output_dir)

            pos_entries = self.parse_positions_csv(positions_path)
            if not pos_entries:
                logger.error("No valid entries found in %s", positions_path)
                return False

            if override_names is not None:
                if len(override_names) != len(pos_entries):
                    raise ValueError(f"Override names/positions mismatch: {len(override_names)}≠{len(pos_entries)}")
                pos_entries = [(name, x, y) for name, (_, x, y) in zip(override_names, pos_entries)]

            # Fail fast before any heavy work if a tile is missing.
            images_dir = Path(images_dir)
            for fname, _, _ in pos_entries:
                if not (images_dir / fname).exists():
                    logger.error("Missing image: %s in %s", fname, images_dir)
                    return False

            # The first tile defines the expected shape and dtype.
            first_tile = self.fs_manager.load_image(images_dir / pos_entries[0][0])
            if first_tile is None:
                logger.error("Failed to load first tile: %s", pos_entries[0][0])
                return False
            if len(first_tile.shape) != 2:
                logger.error("Expected a 2-D tile, got shape %s: %s", first_tile.shape, pos_entries[0][0])
                return False

            tile_h, tile_w = first_tile.shape
            dtype = first_tile.dtype

            # Bounding box of all tiles in position coordinates.
            x_vals = [x_f for _, x_f, _ in pos_entries]
            y_vals = [y_f for _, _, y_f in pos_entries]
            min_x = min(x_vals)
            max_x = max(x_vals) + tile_w
            min_y = min(y_vals)
            max_y = max(y_vals) + tile_h

            final_w = int(np.ceil(max_x - min_x))
            final_h = int(np.ceil(max_y - min_y))
            logger.info("Final canvas size: %d x %d", final_h, final_w)

            # Weighted sum of pixels and of weights; divided at the end.
            acc = np.zeros((final_h, final_w), dtype=np.float32)
            weight_acc = np.zeros((final_h, final_w), dtype=np.float32)

            base_mask = create_linear_weight_mask(tile_h, tile_w, margin_ratio=margin_ratio)

            for i, (fname, x_f, y_f) in enumerate(pos_entries):
                logger.info("Placing tile %d/%d: %s at (%.2f, %.2f)", i + 1, len(pos_entries), fname, x_f, y_f)

                tile_img = self.fs_manager.load_image(images_dir / fname)
                if tile_img is None:
                    logger.error("Failed to load tile: %s", fname)
                    continue
                if tile_img.shape != (tile_h, tile_w):
                    logger.error("Tile shape mismatch: %s vs %dx%d", tile_img.shape, tile_h, tile_w)
                    continue
                if tile_img.dtype != dtype:
                    logger.error("Tile dtype mismatch: %s vs %s", tile_img.dtype, dtype)
                    continue

                weighted_tile = tile_img.astype(np.float32) * base_mask

                # Split the canvas offset into integer placement + fractional shift.
                shift_x = x_f - min_x
                shift_y = y_f - min_y
                int_x = int(np.floor(shift_x))
                int_y = int(np.floor(shift_y))
                frac_x = shift_x - int_x
                frac_y = shift_y - int_y

                # The mask is shifted identically so the final normalization
                # divides each pixel by exactly the weight it received.
                shifted_tile = subpixel_shift(weighted_tile, shift=(frac_y, frac_x), order=1, mode='constant', cval=0)
                shifted_mask = subpixel_shift(base_mask, shift=(frac_y, frac_x), order=1, mode='constant', cval=0)

                acc[int_y:int_y + tile_h, int_x:int_x + tile_w] += shifted_tile
                weight_acc[int_y:int_y + tile_h, int_x:int_x + tile_w] += shifted_mask

            # Uncovered pixels (zero weight) come out as 0.
            safe_weight = np.where(weight_acc == 0, 1, weight_acc)
            blended = self._cast_to_dtype(acc / safe_weight, dtype)

            logger.info("Saving stitched image to %s", output_path)
            self.fs_manager.save_image(output_path, blended)

            return True

        except Exception as e:  # bool-returning API boundary: log with traceback
            logger.exception("Error in assemble_image: %s", e)
            return False