Coverage for ezstitcher/core/steps.py: 78%

277 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2025-04-30 13:20 +0000

1""" 

2Step classes for the pipeline architecture. 

3 

4This module contains the Step class and all specialized step implementations for 

5different types of processing operations, including: 

6 

71. Base Step class for general-purpose processing 

82. Step factories (ZFlatStep, FocusStep, CompositeStep) for common operations 

93. Specialized steps (PositionGenerationStep, ImageStitchingStep) for specific tasks 

10 

11For conceptual explanation, see the documentation at: 

12https://ezstitcher.readthedocs.io/en/latest/concepts/step.html 

13""" 

14 

15from typing import Dict, List, Union, Callable, Any, TypeVar, Optional, Sequence, Tuple 

16import logging 

17from pathlib import Path 

18import numpy as np 

19 

20# Import core components 

21from ezstitcher.core.file_system_manager import FileSystemManager 

22from ezstitcher.core.utils import prepare_patterns_and_functions 

23from ezstitcher.core.abstract_step import AbstractStep 

24from ezstitcher.core.image_processor import ImageProcessor as IP 

25from ezstitcher.core.focus_analyzer import FocusAnalyzer 

26# Removed adapt_func_to_stack import 

27 

28 

# Type definitions
# Note: All functions in ProcessingFunc are now expected to accept List[np.ndarray]
# and return List[np.ndarray]. Use utils.stack() to wrap single-image functions.
#
# BUG FIX: the original alias was Callable[[List[np.ndarray], ...], List[np.ndarray]],
# which is invalid typing syntax — an ellipsis may only appear as the *entire*
# argument list, not as one element of it. Callable[..., X] expresses the same
# intent ("any call signature returning List[np.ndarray]") and is accepted by
# the typing module on all supported Python versions.
FunctionType = Callable[..., List[np.ndarray]]
# A function can be a callable or a tuple of (callable, kwargs)
FunctionWithArgs = Union[FunctionType, Tuple[FunctionType, Dict[str, Any]]]
ProcessingFunc = Union[FunctionWithArgs, Dict[str, FunctionWithArgs], List[FunctionWithArgs]]
VariableComponents = List[str]
GroupBy = Optional[str]
WellFilter = Optional[List[str]]
T = TypeVar('T')  # For generic return types

# Configure logging
logger = logging.getLogger(__name__)

43 

44 

45class Step(AbstractStep): 

46 """ 

47 A processing step in a pipeline. 

48 

49 A Step encapsulates a processing operation that can be applied to images. 

50 It mirrors the functionality of process_patterns_with_variable_components 

51 while providing a more object-oriented interface. 

52 

53 Attributes: 

54 func: The processing function(s) to apply 

55 variable_components: Components that vary across files (e.g., 'z_index', 'channel') 

56 group_by: How to group files for processing (e.g., 'channel', 'site') 

57 input_dir: The input directory 

58 output_dir: The output directory 

59 well_filter: Wells to process 

60 processing_args: Additional arguments to pass to the processing function 

61 name: Human-readable name for the step 

62 """ 

63 

64 def __init__( 

65 self, 

66 func: ProcessingFunc, 

67 variable_components: VariableComponents = ['site'], 

68 group_by: GroupBy = None, 

69 input_dir: str = None, 

70 output_dir: str = None, 

71 well_filter: WellFilter = None, 

72 name: str = None 

73 ): 

74 """ 

75 Initialize a processing step. 

76 

77 Args: 

78 func: The processing function(s) to apply. Can be: 

79 - A single callable function 

80 - A tuple of (function, kwargs) 

81 - A list of functions or (function, kwargs) tuples 

82 - A dictionary mapping component values to functions or tuples 

83 variable_components: Components that vary across files 

84 group_by: How to group files for processing 

85 input_dir: The input directory 

86 output_dir: The output directory 

87 well_filter: Wells to process 

88 name: Human-readable name for the step 

89 """ 

90 self.func = func 

91 self.variable_components = variable_components 

92 self.group_by = group_by 

93 self._input_dir = input_dir 

94 self._output_dir = output_dir 

95 self.well_filter = well_filter 

96 self._name = name or self._generate_name() 

97 

98 def _generate_name(self) -> str: 

99 """ 

100 Generate a descriptive name based on the function. 

101 

102 Returns: 

103 A human-readable name for the step 

104 """ 

105 # Helper function to get name from function or function tuple 

106 def get_func_name(f): 

107 if isinstance(f, tuple) and len(f) == 2 and callable(f[0]): 

108 return getattr(f[0], '__name__', str(f[0])) 

109 if callable(f): 

110 return getattr(f, '__name__', str(f)) 

111 return str(f) 

112 

113 # Dictionary of functions 

114 if isinstance(self.func, dict): 

115 funcs = ", ".join(f"{k}:{get_func_name(f)}" for k, f in self.func.items()) 

116 return f"ChannelMappedStep({funcs})" 

117 

118 # List of functions 

119 if isinstance(self.func, list): 

120 funcs = ", ".join(get_func_name(f) for f in self.func) 

121 return f"MultiStep({funcs})" 

122 

123 # Single function or function tuple 

124 return f"Step({get_func_name(self.func)})" 

125 

126 def process(self, group: List[Any], context: Optional[Dict[str, Any]] = None) -> Any: 

127 """ 

128 Process a group of images. 

129 

130 This implementation of the AbstractStep.process method adapts the Step class 

131 to work with the AbstractStep interface. It processes a group of images 

132 according to the step's configuration. 

133 

134 Args: 

135 group: Group of images to process 

136 context: Pipeline context for sharing data between steps 

137 

138 Returns: 

139 Processed result (typically an image or list of images) 

140 """ 

141 # For backward compatibility, if this method is called directly with a ProcessingContext 

142 # object as the first argument, treat it as the old-style process method 

143 if hasattr(group, 'orchestrator') and context is None: 

144 return self._process_with_context_object(group) 

145 

146 # Otherwise, process the group of images directly 

147 return self._apply_processing(group) 

148 

    def _process_with_context_object(self, context: 'ProcessingContext') -> 'ProcessingContext':
        """
        Process the step with the given context object (legacy method).

        This method maintains backward compatibility with the old process method
        that takes a ProcessingContext object. It detects image patterns in the
        input directory, groups them by ``self.group_by``, applies the step's
        function(s) per group, saves the results, and records the saved file
        paths on the context.

        Args:
            context: The processing context; must expose ``orchestrator`` and
                may expose ``well_filter``.

        Returns:
            The updated processing context, with ``context.results`` set to a
            mapping of well -> {component_value: [output file paths]}.

        Raises:
            ValueError: If no input directory is configured on this step.
            AttributeError: If ``context`` has no ``orchestrator``.
        """
        logger.info("Processing step: %s", self.name)

        # Get directories and microscope handler
        input_dir = self.input_dir
        output_dir = self.output_dir
        # Step-level well filter takes precedence over the context's.
        well_filter = self.well_filter or context.well_filter
        orchestrator = context.orchestrator  # Required, will raise AttributeError if missing
        microscope_handler = orchestrator.microscope_handler

        if not input_dir:
            raise ValueError("Input directory must be specified")

        # Find the actual directory containing images.
        # This works whether input_dir is a plate folder or a subfolder.
        actual_input_dir = FileSystemManager.find_image_directory(Path(input_dir))
        logger.debug("Using actual image directory: %s", actual_input_dir)

        # Get filename patterns keyed by well, with ``variable_components``
        # left variable in each pattern.
        patterns_by_well = microscope_handler.auto_detect_patterns(
            actual_input_dir,
            well_filter=well_filter,
            variable_components=self.variable_components
        )

        # Process each well
        results = {}
        for well, patterns in patterns_by_well.items():
            # Defensive re-check: auto_detect_patterns already received the
            # filter, but wells outside it are skipped here as well.
            if well_filter and well not in well_filter:
                continue

            logger.info("Processing well: %s", well)
            well_results = {}

            # Split patterns into groups (by self.group_by) and resolve which
            # function/args apply to each group (handles dict-valued funcs).
            grouped_patterns, component_to_funcs, component_to_args = prepare_patterns_and_functions(
                patterns, self.func, component=self.group_by
            )

            # Process each component group
            for component_value, component_patterns in grouped_patterns.items():
                component_func = component_to_funcs[component_value]
                # NOTE(review): component_args is resolved but never passed to
                # _apply_processing below — per-component kwargs appear to be
                # ignored on this legacy path. Confirm whether this is intended.
                component_args = component_to_args[component_value]
                output_files = []

                # Process each pattern in the group
                for pattern in component_patterns:
                    # Find files matching this pattern
                    matching_files = microscope_handler.parser.path_list_from_pattern(actual_input_dir, pattern)

                    # Load images; unloadable entries (None) are dropped.
                    try:
                        images = [FileSystemManager.load_image(str(Path(actual_input_dir) / filename)) for filename in matching_files]
                        images = [img for img in images if img is not None]
                    except Exception as e:
                        logger.error("Error loading images: %s", str(e))
                        images = []

                    if not images:
                        continue  # Skip if no valid images found

                    # Apply the resolved function for this component group.
                    try:
                        images = self._apply_processing(images, func=component_func)
                    except Exception as e:
                        logger.error("Error applying processing function: %s", str(e))
                        continue

                    # Save images and collect the written file paths.
                    pattern_files = self._save_images(actual_input_dir, output_dir, images, matching_files)
                    if pattern_files:
                        output_files.extend(pattern_files)

                # Store results for this component
                if output_files:
                    well_results[component_value] = output_files

            # Store results for this well (possibly empty)
            results[well] = well_results

        # Store results in context
        context.results = results
        return context

245 

246 

247 

248 def _ensure_2d(self, img): 

249 """Ensure an image is 2D by reducing dimensions if needed.""" 

250 if not isinstance(img, np.ndarray) or img.ndim <= 2: 

251 return img 

252 

253 # Try to squeeze out singleton dimensions first 

254 squeezed = np.squeeze(img) 

255 if squeezed.ndim <= 2: 

256 return squeezed 

257 

258 # If still not 2D, take first slice until it is 

259 result = img 

260 while result.ndim > 2: 

261 result = result[0] 

262 

263 logger.warning("Reduced image dimensions from %dD to 2D", img.ndim) 

264 return result 

265 

266 def _extract_function_and_args( 

267 self, 

268 func_item: FunctionWithArgs 

269 ) -> Tuple[Callable, Dict[str, Any]]: 

270 """Extract function and arguments from a function item. 

271 

272 A function item can be either a callable or a tuple of (callable, kwargs). 

273 

274 Args: 

275 func_item: Function item to extract from 

276 

277 Returns: 

278 Tuple of (function, kwargs) 

279 """ 

280 if isinstance(func_item, tuple) and len(func_item) == 2 and callable(func_item[0]): 

281 # It's a (function, kwargs) tuple 

282 return func_item[0], func_item[1] 

283 if callable(func_item): 

284 # It's just a function, use default args 

285 return func_item, {} 

286 

287 # Invalid function item 

288 logger.warning( 

289 "Invalid function item: %s. Expected callable or (callable, kwargs) tuple.", 

290 str(func_item) 

291 ) 

292 # Return a dummy function that returns the input unchanged 

293 return lambda x, **kwargs: x, {} 

294 

295 def _apply_function_list( 

296 self, 

297 images: List[np.ndarray], 

298 function_list: List[FunctionWithArgs] 

299 ) -> List[np.ndarray]: 

300 """Apply a list of functions sequentially to images. 

301 

302 Args: 

303 images: List of images to process 

304 function_list: List of functions to apply (can include tuples of (function, kwargs)) 

305 

306 Returns: 

307 List of processed images 

308 """ 

309 processed_images = images 

310 

311 for func_item in function_list: 

312 # Extract function and args 

313 func, func_args = self._extract_function_and_args(func_item) 

314 

315 # Apply the function 

316 result = self._apply_single_function(processed_images, func, func_args) 

317 processed_images = [self._ensure_2d(img) for img in result] 

318 

319 return processed_images 

320 

321 

322 

323 def _apply_single_function( 

324 self, 

325 images: List[np.ndarray], 

326 func: Callable, 

327 args: Dict[str, Any] 

328 ) -> List[np.ndarray]: 

329 """Apply a single processing function with specific args. 

330 

331 Args: 

332 images: List of images to process 

333 func: Processing function to apply 

334 args: Arguments to pass to the function 

335 

336 Returns: 

337 List of processed images 

338 """ 

339 try: 

340 result = func(images, **args) 

341 

342 # Handle different return types 

343 if isinstance(result, list): 

344 return result 

345 if isinstance(result, np.ndarray): 

346 func_name = getattr(func, '__name__', 'unknown') 

347 

348 # Check if this is a 3D array (stack of images) 

349 if result.ndim >= 3: 

350 # Convert 3D+ array to list of 2D arrays 

351 logger.debug( 

352 "Function %s returned a 3D array. Converting to list of 2D arrays.", 

353 func_name 

354 ) 

355 return [result[i] for i in range(result.shape[0])] 

356 

357 # It's a single 2D image 

358 logger.warning( 

359 "Function %s returned a single image instead of a list. Wrapping it.", 

360 func_name 

361 ) 

362 return [result] 

363 

364 # Unexpected return type 

365 func_name = getattr(func, '__name__', 'unknown') 

366 result_type = type(result).__name__ 

367 logger.error( 

368 "Function %s returned an unexpected type (%s). Returning original images.", 

369 func_name, 

370 result_type 

371 ) 

372 return images 

373 

374 except Exception as e: 

375 func_name = getattr(func, '__name__', str(func)) 

376 logger.exception( 

377 "Error applying processing function %s: %s", 

378 func_name, 

379 e 

380 ) 

381 return images 

382 

383 def _apply_processing( 

384 self, 

385 images: List[np.ndarray], 

386 func: Optional[ProcessingFunc] = None 

387 ) -> List[np.ndarray]: 

388 """Apply processing function(s) to a stack (list) of images. 

389 

390 Note: This method only handles single functions or lists of functions. 

391 Dictionary mapping of functions to component values is handled by 

392 prepare_patterns_and_functions before this method is called. 

393 

394 Functions can be specified in several ways: 

395 - A single callable function 

396 - A tuple of (function, kwargs) 

397 - A list of functions or (function, kwargs) tuples 

398 

399 Args: 

400 images: List of images (numpy arrays) to process. 

401 func: Processing function(s) to apply. Defaults to self.func. 

402 

403 Returns: 

404 List of processed images, or the original list if an error occurs. 

405 """ 

406 # Handle empty input 

407 if not images: 

408 return [] 

409 

410 # Get processing function 

411 processing_func = func if func is not None else self.func 

412 

413 try: 

414 # Case 1: List of functions or function tuples 

415 if isinstance(processing_func, list): 

416 return self._apply_function_list(images, processing_func) 

417 

418 # Case 2: Single function or function tuple 

419 is_callable = callable(processing_func) 

420 is_func_tuple = isinstance(processing_func, tuple) and len(processing_func) == 2 

421 

422 if is_callable or is_func_tuple: 

423 func, args = self._extract_function_and_args(processing_func) 

424 return self._apply_single_function(images, func, args) 

425 

426 # Case 3: Invalid function 

427 logger.warning("No valid processing function provided. Returning original images.") 

428 return images 

429 

430 except Exception as e: 

431 # Try to get function name, but handle the case where processing_func might be a tuple 

432 if isinstance(processing_func, tuple) and callable(processing_func[0]): 

433 func_name = getattr(processing_func[0], '__name__', str(processing_func[0])) 

434 else: 

435 func_name = getattr(processing_func, '__name__', str(processing_func)) 

436 

437 logger.exception("Error applying processing function %s: %s", func_name, e) 

438 return images 

439 

440 def _save_images(self, input_dir, output_dir, images, filenames): 

441 """Save processed images. 

442 

443 Args: 

444 input_dir: Input directory 

445 output_dir: Output directory 

446 images: Images to save 

447 filenames: Filenames to use 

448 

449 Returns: 

450 list: Paths to saved images 

451 """ 

452 if not output_dir or not images or not filenames: 

453 return [] 

454 

455 try: 

456 # Ensure output directory exists 

457 FileSystemManager.ensure_directory(output_dir) 

458 

459 # Clean up old files if working in place 

460 if input_dir == output_dir: 

461 for filename in filenames: 

462 FileSystemManager.delete_file(Path(output_dir) / filename) 

463 

464 # Initialize output files list 

465 output_files = [] 

466 

467 # Convert to list if it's a single image 

468 if isinstance(images, np.ndarray): 

469 images = [images] 

470 filenames = [filenames[0]] 

471 

472 # Save each image 

473 for i, img in enumerate(images): 

474 if i < len(filenames): 

475 output_path = Path(output_dir) / filenames[i] 

476 FileSystemManager.save_image(str(output_path), img) 

477 output_files.append(str(output_path)) 

478 

479 return output_files 

480 

481 except Exception as e: 

482 logger.error("Error saving images: %s", str(e)) 

483 return [] 

484 

    @property
    def name(self) -> str:
        """The human-readable name of this step."""
        return self._name

    @name.setter
    def name(self, value: str):
        """Set the name of this step."""
        self._name = value

    @property
    def input_dir(self) -> str:
        """The input directory for this step.

        NOTE(review): presumably backs an abstract property on AbstractStep —
        confirm before changing to a plain attribute.
        """
        return self._input_dir

    @input_dir.setter
    def input_dir(self, value: str):
        """Set the input directory for this step."""
        self._input_dir = value

    @property
    def output_dir(self) -> str:
        """The output directory for this step."""
        return self._output_dir

    @output_dir.setter
    def output_dir(self, value: str):
        """Set the output directory for this step."""
        self._output_dir = value

514 

515 def __repr__(self) -> str: 

516 """ 

517 String representation of the step. 

518 

519 Returns: 

520 A human-readable representation of the step 

521 """ 

522 components = ", ".join(self.variable_components) 

523 output_dir_str = f"→ {str(self.output_dir)}" if self.output_dir else "" 

524 return f"{self.name} [components={components}, group_by={self.group_by}] {output_dir_str}" 

525 

526 

527 

528 

class PositionGenerationStep(Step):
    """
    A specialized Step for generating positions.

    This step takes processed reference images and generates position files
    for stitching. It stores the positions file in the context for later use.
    """

    def __init__(
        self,
        name: str = "Position Generation",
        input_dir: Optional[Path] = None,
        output_dir: Optional[Path] = None  # Output directory for positions files
    ):
        """
        Initialize a position generation step.

        Args:
            name: Name of the step
            input_dir: Input directory
            output_dir: Output directory (for positions files)
        """
        super().__init__(
            func=None,  # No processing function needed
            name=name,
            input_dir=input_dir,
            output_dir=output_dir
        )

    def process(self, group: List[Any], context: Optional[Dict[str, Any]] = None) -> Any:
        """
        Generate positions for stitching and store them in the context.

        This implementation adapts the specialized step to the AbstractStep interface.

        Args:
            group: Group of images to process (not used in this step)
            context: Pipeline context for sharing data between steps

        Returns:
            The processed result
        """
        # For backward compatibility, if this method is called with a ProcessingContext
        # object as the first argument, treat it as the old-style process method
        if hasattr(group, 'orchestrator') and context is None:
            return self._process_with_context_object(group)

        # This step doesn't process images directly, so return the group unchanged
        return group

    def _process_with_context_object(self, context):
        """
        Legacy method for backward compatibility.

        Resolves the positions output directory, delegates position generation
        to the orchestrator, and records the results on the context.
        """
        logger.info("Processing step: %s", self.name)

        # BUG FIX: the original compared directories with ``is`` (object
        # identity). Two equal Path objects are almost never the same object,
        # so the fallback "<input>_positions" directory was effectively never
        # chosen, and when both dirs were None the identity check passed and
        # crashed on ``None.parent``. Compare by value and guard against None.
        if self.input_dir is not None and self.output_dir == self.input_dir:
            self.output_dir = self.input_dir.parent / f"{self.input_dir.name}_positions"
            logger.info(
                "Input and output directories are the same, using default positions directory: %s",
                self.output_dir
            )

        # Get required objects from context
        well = context.well_filter[0] if context.well_filter else None
        orchestrator = context.orchestrator  # Required, will raise AttributeError if missing
        input_dir = self.input_dir or context.input_dir
        positions_dir = self.output_dir or context.output_dir

        # Call the generate_positions method
        positions_file, reference_pattern = orchestrator.generate_positions(well, input_dir, positions_dir)

        # Store in context for downstream steps (e.g. ImageStitchingStep)
        context.positions_dir = positions_dir
        context.reference_pattern = reference_pattern
        return context

602 

603 

class ImageStitchingStep(Step):
    """
    A step that stitches images using position files.

    If input_dir is not specified, it will use the pipeline's input directory by default.
    If positions_dir is not specified, it will try to find a directory with "positions" in its name.
    """

    def __init__(self, name=None, input_dir=None, positions_dir=None, output_dir=None, **kwargs):
        """
        Initialize an ImageStitchingStep.

        Args:
            name (str, optional): Name of the step
            input_dir (str, optional): Directory containing images to stitch.
                If not specified, uses the pipeline's input directory.
            positions_dir (str, optional): Directory containing position files.
                If not specified, tries to find a directory with "positions" in its name.
            output_dir (str, optional): Directory to save stitched images
            **kwargs: Additional arguments for the step
        """
        super().__init__(
            func=None,  # ImageStitchingStep doesn't use the standard func mechanism
            name=name or "Image Stitching",
            input_dir=input_dir,
            output_dir=output_dir,
            variable_components=[],  # Empty list for variable_components
            **kwargs
        )
        self.positions_dir = positions_dir

    def process(self, group: List[Any], context: Optional[Dict[str, Any]] = None) -> Any:
        """
        Stitch images using the positions file.

        This implementation adapts the specialized step to the AbstractStep interface.

        Args:
            group: Group of images to process (not used in this step)
            context: Pipeline context for sharing data between steps

        Returns:
            The processed result
        """
        # For backward compatibility, if this method is called with a ProcessingContext
        # object as the first argument, treat it as the old-style process method
        if hasattr(group, 'orchestrator') and context is None:
            return self._process_with_context_object(group)

        # This step doesn't process images directly, so return the group unchanged
        return group

    def _process_with_context_object(self, context):
        """
        Legacy method for backward compatibility.

        Resolves the positions directory (explicit -> context -> plate-parent
        convention -> recursive search), then delegates stitching to the
        orchestrator.

        Args:
            context: Processing context containing orchestrator and other metadata

        Returns:
            Updated context

        Raises:
            ValueError: If the context has no orchestrator, no well filter,
                or no positions directory can be located.
        """
        logger.info("Processing step: %s", self.name)

        # Get orchestrator from context
        orchestrator = getattr(context, 'orchestrator', None)
        if not orchestrator:
            raise ValueError("ImageStitchingStep requires an orchestrator in the context")

        # Only the first well in the filter is stitched per invocation.
        well = context.well_filter[0] if context.well_filter else None
        if not well:
            raise ValueError("ImageStitchingStep requires a well filter in the context")

        # If positions_dir is not specified, try to get it from context or find it.
        # NOTE(review): this mutates self.positions_dir, so the directory
        # resolved on the first call is reused for all later calls/wells —
        # confirm that is intended for multi-well pipelines.
        if not self.positions_dir:
            # First try to get from context (set by PositionGenerationStep)
            self.positions_dir = getattr(context, 'positions_dir', None)

            # If still not found, try the "<plate>_positions" convention at
            # the parent level of the plate.
            if not self.positions_dir and orchestrator:
                plate_name = orchestrator.plate_path.name
                parent_positions_dir = orchestrator.plate_path.parent / f"{plate_name}_positions"
                if parent_positions_dir.exists():
                    self.positions_dir = parent_positions_dir
                    logger.info(f"Using positions directory at parent level: {self.positions_dir}")
                else:
                    # Last resort: search recursively for any directory whose
                    # name contains "positions".
                    self.positions_dir = FileSystemManager.find_directory_substring_recursive(
                        Path(self.input_dir).parent, "positions")

        # If still not found, raise an error
        if not self.positions_dir:
            raise ValueError(f"No positions directory found for well {well}")

        # Call the stitch_images method; the positions file is assumed to be
        # named "<well>.csv" inside the positions directory.
        orchestrator.stitch_images(
            well=well,
            input_dir=self.input_dir,
            output_dir=self.output_dir or context.output_dir,
            positions_file=Path(self.positions_dir) / f"{well}.csv"
        )

        return context

708 

709 

710 

711 

def group_patterns_by(patterns, component, microscope_handler=None):
    """
    Group patterns by the specified component.

    Args:
        patterns (list): Patterns to group
        component (str): Filename component to group on (e.g. 'channel', 'site')
        microscope_handler: Handler whose ``parser.parse_filename`` extracts
            component values from a pattern. Effectively required; the None
            default is kept only for signature compatibility.

    Returns:
        dict: Dictionary mapping component values to lists of patterns

    Raises:
        ValueError: If microscope_handler is not provided.
    """
    # BUG FIX: the original accepted microscope_handler=None but dereferenced
    # it unconditionally, failing with an opaque AttributeError. Fail fast
    # with an explicit message instead.
    if microscope_handler is None:
        raise ValueError("group_patterns_by requires a microscope_handler")

    grouped_patterns = {}
    for pattern in patterns:
        # Extract the component value from the pattern
        component_value = microscope_handler.parser.parse_filename(pattern)[component]
        grouped_patterns.setdefault(component_value, []).append(pattern)
    return grouped_patterns

729 

730 

class ZFlatStep(Step):
    """
    Specialized step for Z-stack flattening.

    Collapses a Z-stack into a single plane using the chosen projection
    method, pre-configuring variable_components=['z_index'] and group_by=None.
    """

    # Shorthand -> full ImageProcessor projection method name.
    PROJECTION_METHODS = {
        "max": "max_projection",
        "mean": "mean_projection",
        "median": "median_projection",
        "min": "min_projection",
        "std": "std_projection",
        "sum": "sum_projection"
    }

    def __init__(
        self,
        method: str = "max",
        **kwargs
    ):
        """
        Initialize a Z-stack flattening step.

        Args:
            method: Projection method. Options: "max", "mean", "median",
                "min", "std", "sum" (full method names are also accepted).
            **kwargs: Additional arguments passed to the parent Step class:
                input_dir: Input directory
                output_dir: Output directory
                well_filter: Wells to process

        Raises:
            ValueError: If ``method`` is neither a known shorthand nor a
                known full projection method name.
        """
        is_shorthand = method in self.PROJECTION_METHODS
        is_full_name = method in self.PROJECTION_METHODS.values()
        if not (is_shorthand or is_full_name):
            raise ValueError(f"Unknown projection method: {method}. "
                             f"Options are: {', '.join(self.PROJECTION_METHODS.keys())}")

        self.method = method
        # Shorthands resolve through the table; full names pass through.
        resolved_method = self.PROJECTION_METHODS.get(method, method)

        super().__init__(
            func=(IP.create_projection, {'method': resolved_method}),
            variable_components=['z_index'],
            group_by=None,
            name=f"{method.capitalize()} Projection",
            **kwargs
        )

780 

781 

class FocusStep(Step):
    """
    Specialized step for focus-based Z-stack processing.

    Selects the best-focused plane of a Z-stack via FocusAnalyzer,
    pre-configuring variable_components=['z_index'] and group_by=None.
    """

    def __init__(
        self,
        focus_options: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """
        Initialize a focus step.

        Args:
            focus_options: Dictionary of focus analyzer options:
                - metric: Focus metric. Options: "combined", "normalized_variance",
                  "laplacian", "tenengrad", "fft" (default: "combined")
            **kwargs: Additional arguments passed to the parent Step class:
                input_dir: Input directory
                output_dir: Output directory
                well_filter: Wells to process
        """
        options = focus_options or {'metric': 'combined'}
        metric = options.get('metric', 'combined')

        def process_func(images):
            # FocusAnalyzer returns (best_image, index, scores); only the
            # image itself is propagated.
            best_image, _, _ = FocusAnalyzer.select_best_focus(images, metric=metric)
            return best_image

        super().__init__(
            func=(process_func, {}),
            variable_components=['z_index'],
            group_by=None,
            name=f"Best Focus ({metric})",
            **kwargs
        )

823 

824 

class CompositeStep(Step):
    """
    Specialized step for creating composite images from multiple channels.

    Blends the channels of each site into one composite image using the
    given weights, pre-configuring variable_components=['channel'] and
    group_by=None.
    """

    def __init__(
        self,
        weights: Optional[List[float]] = None,
        **kwargs
    ):
        """
        Initialize a channel compositing step.

        Args:
            weights: List of weights for each channel. If None, equal weights are used.
            **kwargs: Additional arguments passed to the parent Step class:
                input_dir: Input directory
                output_dir: Output directory
                well_filter: Wells to process
        """
        composite_func = (IP.create_composite, {'weights': weights})
        super().__init__(
            func=composite_func,
            variable_components=['channel'],
            group_by=None,
            name="Channel Composite",
            **kwargs
        )

856 

857 

class NormStep(Step):
    """
    Specialized step for image normalization.

    Rescales intensities via percentile-based normalization, wiring
    IP.stack_percentile_normalize with the configured percentile bounds.
    """

    def __init__(
        self,
        low_percentile: float = 0.1,
        high_percentile: float = 99.9,
        **kwargs
    ):
        """
        Initialize a normalization step.

        Args:
            low_percentile: Low percentile for normalization (0-100)
            high_percentile: High percentile for normalization (0-100)
            **kwargs: Additional arguments passed to the parent Step class:
                input_dir: Input directory
                output_dir: Output directory
                well_filter: Wells to process
                variable_components: Components that vary across files (default: ['site'])
                group_by: How to group files for processing (default: None)
        """
        percentile_bounds = {
            'low_percentile': low_percentile,
            'high_percentile': high_percentile
        }
        super().__init__(
            func=(IP.stack_percentile_normalize, percentile_bounds),
            name="Percentile Normalization",
            **kwargs
        )