Coverage for openhcs/processing/backends/analysis/skan_axon_analysis.py: 11.1%
261 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2Skan-based axon skeletonization and analysis for OpenHCS.
4This module provides comprehensive axon analysis using the skan library,
5including segmentation, skeletonization, and quantitative skeleton analysis.
6Supports both 2D and 3D analysis modes with multiple output formats.
7"""
9import numpy as np
10import pandas as pd
11from datetime import datetime
12from enum import Enum
13from pathlib import Path
14from typing import Dict, Any, Tuple, Optional, List
15import logging
17# OpenHCS imports
18from openhcs.core.memory.decorators import numpy as numpy_func
19from openhcs.core.pipeline.function_contracts import special_outputs
# Module-level logger; every info/debug/warning message emitted by the
# functions in this module goes through it.
logger = logging.getLogger(__name__)
class ThresholdMethod(Enum):
    """Segmentation methods for axon detection.

    The string value of the chosen member is recorded as ``threshold_method``
    in the analysis metadata.
    """
    OTSU = "otsu"  # one global Otsu threshold computed over the whole stack
    MANUAL = "manual"  # fixed user-supplied threshold_value
    ADAPTIVE = "adaptive"  # local threshold (block_size=51) computed per Z-slice
class OutputMode(Enum):
    """Output array format options for skeleton visualizations."""
    SKELETON = "skeleton"  # skeleton mask only, as uint8 values 0/255
    SKELETON_OVERLAY = "skeleton_overlay"  # input image with skeleton pixels set to the image maximum
    ORIGINAL = "original"  # unchanged copy of the input image
    COMPOSITE = "composite"  # original | binary | skeleton panels side by side along X
class AnalysisDimension(Enum):
    """Analysis dimension modes for skeleton analysis."""
    TWO_D = "2d"  # each Z-slice analysed as an independent 2D skeleton with (y, x) spacing
    THREE_D = "3d"  # whole stack analysed as one 3D skeleton, keeping connections across Z
def materialize_axon_analysis(
    axon_analysis_data: Dict[str, Any],
    path: str,
    filemanager,
    backend: str,
    **kwargs
) -> str:
    """
    Materialize axon analysis results using filemanager.

    Output names are derived from ``path`` with a trailing ``.pkl`` removed
    (called ``<base>`` below):
    - ``<base>.json``: summary metrics and metadata (primary output)
    - ``<base>_branches.csv``: detailed per-branch data
    - ``<base>_complete.xlsx``: optional multi-sheet workbook, written only for
      the disk backend and only when ``kwargs['create_excel']`` is truthy
      (default ``True``). It is skipped with a warning if the Excel engine
      (openpyxl) is not installed.

    Any existing file at a target path is deleted before it is rewritten.

    Args:
        axon_analysis_data: Analysis results dictionary with ``'summary'``,
            ``'metadata'`` and ``'branch_data'`` entries.
        path: Base path for output files (from special output path).
        filemanager: FileManager instance for consistent I/O.
        backend: Backend to use for materialization.
        **kwargs: Additional materialization options (``create_excel``).

    Returns:
        str: Path to the primary output file (JSON summary).
    """
    logger.info(f"🔬 SKAN_MATERIALIZE: Called with path={path}, backend={backend}, data_keys={list(axon_analysis_data.keys()) if axon_analysis_data else 'None'}")
    import json
    from openhcs.constants.constants import Backend

    # Only a trailing '.pkl' is stripped. str.replace would also rewrite any
    # directory or file name that merely contains '.pkl' in the middle.
    base_path = path[:-len('.pkl')] if path.endswith('.pkl') else path
    json_path = f"{base_path}.json"
    csv_path = f"{base_path}_branches.csv"

    # Ensure output directory exists for disk backend
    if backend == Backend.DISK.value:
        filemanager.ensure_directory(str(Path(json_path).parent), backend)

    # 1. Save summary and metadata as JSON (primary output).
    # default=str stringifies anything json cannot encode natively.
    summary_data = {
        'analysis_type': 'axon_skeleton_analysis',
        'summary': axon_analysis_data['summary'],
        'metadata': axon_analysis_data['metadata']
    }
    json_content = json.dumps(summary_data, indent=2, default=str)
    if filemanager.exists(json_path, backend):
        filemanager.delete(json_path, backend)
    filemanager.save(json_content, json_path, backend)

    # 2. Save detailed branch data as CSV
    branch_df = pd.DataFrame(axon_analysis_data['branch_data'])
    csv_content = branch_df.to_csv(index=False)
    if filemanager.exists(csv_path, backend):
        filemanager.delete(csv_path, backend)
    filemanager.save(csv_content, csv_path, backend)

    # 3. Optional Excel workbook. pandas writes it directly to a real file
    # path, so it is only produced for the disk backend.
    if kwargs.get('create_excel', True) and backend == Backend.DISK.value:
        excel_path = f"{base_path}_complete.xlsx"
        if Path(excel_path).exists():
            Path(excel_path).unlink()
        try:
            writer = pd.ExcelWriter(excel_path, engine='openpyxl')
        except ImportError as err:
            # The workbook only duplicates the JSON/CSV outputs written above;
            # a missing optional engine must not fail the whole materialization.
            logger.warning(f"Skipping Excel export, openpyxl is not available: {err}")
        else:
            with writer:
                branch_df.to_excel(writer, sheet_name='Branch_Data', index=False)
                summary_df = pd.DataFrame([axon_analysis_data['summary']])
                summary_df.to_excel(writer, sheet_name='Summary', index=False)
                metadata_df = pd.DataFrame([axon_analysis_data['metadata']])
                metadata_df.to_excel(writer, sheet_name='Metadata', index=False)
            logger.info(f"Created Excel file: {excel_path}")

    # 4. Log materialization
    logger.info("Materialized axon analysis:")
    logger.info(f" - Summary: {json_path}")
    logger.info(f" - Branch data: {csv_path}")

    return json_path
def _visualization_to_uint16(visualization: np.ndarray) -> np.ndarray:
    """
    Convert one visualization array to uint16 for TIFF output.

    uint16 input is returned unchanged. Any other input is first promoted to
    float64. If its maximum is <= 1.0 it is treated as normalized [0, 1] data
    and scaled by 65535. The result is then clipped to [0, 65535] before the
    cast, so out-of-range values saturate instead of wrapping around.
    """
    if visualization.dtype == np.uint16:
        return visualization
    # Promote before scaling. Under NumPy 2 scalar promotion rules,
    # `uint8_array * 65535` raises OverflowError because 65535 does not fit
    # in uint8. An all-zero uint8 skeleton view has max 0 and would hit that.
    values = visualization.astype(np.float64)
    if values.size and values.max() <= 1.0:
        values = values * 65535
    return np.clip(values, 0, 65535).astype(np.uint16)


def materialize_skeleton_visualizations(data: List[np.ndarray], path: str, filemanager, backend: str) -> str:
    """
    Materialize skeleton visualizations as individual uint16 TIFF files.

    Each array in ``data`` is written to ``<base>_slice_XXX.tif`` (XXX is the
    zero-padded list index). A plain-text summary is written to
    ``<base>_skeleton_summary.txt``, where ``<base>`` is ``path`` with a
    trailing ``.pkl`` removed. All writes go through ``filemanager`` using
    ``backend``.

    Args:
        data: Visualization arrays, one per slice (may be empty).
        path: Special-output path used to derive the output file names.
        filemanager: FileManager instance used for all writes.
        backend: Backend to save to.

    Returns:
        str: Path of the summary text file.
    """
    # Only a trailing '.pkl' is stripped. str.replace would also rewrite
    # directory names containing '.pkl', and for a path without '.pkl' it
    # would make the summary overwrite ``path`` itself.
    base_path = path[:-len('.pkl')] if path.endswith('.pkl') else path
    summary_path = f"{base_path}_skeleton_summary.txt"

    if not data:
        # Write a marker file so it is visible that no visualizations were produced
        summary_content = "No skeleton visualizations generated (return_skeleton_visualizations=False)\n"
        filemanager.save(summary_content, summary_path, backend)
        return summary_path

    from openhcs.constants.constants import Backend

    # Same directory handling as the other materializer in this module
    if backend == Backend.DISK.value:
        filemanager.ensure_directory(str(Path(summary_path).parent), backend)

    # Save each visualization as a separate TIFF file (uint16 to match input images)
    for i, visualization in enumerate(data):
        viz_filename = f"{base_path}_slice_{i:03d}.tif"
        filemanager.save(_visualization_to_uint16(visualization), viz_filename, backend)

    summary_content = f"Skeleton visualizations saved: {len(data)} files\n"
    summary_content += f"Base filename pattern: {base_path}_slice_XXX.tif\n"
    summary_content += f"Visualization dtype: {data[0].dtype}\n"
    summary_content += f"Visualization shape: {data[0].shape}\n"

    filemanager.save(summary_content, summary_path, backend)

    return summary_path
@special_outputs(("axon_analysis", materialize_axon_analysis), ("skeleton_visualizations", materialize_skeleton_visualizations))
@numpy_func
def skan_axon_skeletonize_and_analyze(
    image_stack: np.ndarray,
    voxel_spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0),
    threshold_method: ThresholdMethod = ThresholdMethod.OTSU,
    threshold_value: Optional[float] = None,
    min_object_size: int = 100,
    min_branch_length: float = 0.0,
    return_skeleton_visualizations: bool = False,
    skeleton_visualization_mode: OutputMode = OutputMode.SKELETON_OVERLAY,
    analysis_dimension: AnalysisDimension = AnalysisDimension.THREE_D
) -> Tuple[np.ndarray, Dict[str, Any], List[np.ndarray]]:
    """
    Segment, skeletonize and quantify axons in a 3D grayscale stack.

    Pipeline: threshold -> optional small-object removal -> skeletonize ->
    skan branch analysis (3D or per-slice 2D) -> optional branch-length
    filter -> optional per-slice visualizations -> results dictionary.

    Args:
        image_stack: 3D grayscale image (Z, Y, X).
        voxel_spacing: Physical voxel size (z, y, x) in micrometers.
        threshold_method: ThresholdMethod used for segmentation.
        threshold_value: Threshold for ThresholdMethod.MANUAL (required then).
        min_object_size: Objects smaller than this are removed; 0 disables it.
        min_branch_length: Branches with branch_distance below this are
            dropped when it is > 0.
        return_skeleton_visualizations: Build one visualization per Z-slice.
        skeleton_visualization_mode: OutputMode for those visualizations.
        analysis_dimension: AnalysisDimension.THREE_D or AnalysisDimension.TWO_D.

    Returns:
        (image_stack unchanged, analysis results dict, list of per-slice
        visualizations, which is empty unless requested).

    Raises:
        ValueError: If the input is not 3D, if MANUAL thresholding has no
            threshold_value, or if analysis_dimension is not recognized.
    """
    n_dims = len(image_stack.shape)
    if n_dims != 3:
        raise ValueError(f"Expected 3D image, got {n_dims}D")
    if threshold_method == ThresholdMethod.MANUAL and threshold_value is None:
        raise ValueError("threshold_value required when threshold_method=MANUAL")

    logger.info(f"Starting skan axon analysis: {image_stack.shape} image")
    logger.info(f"Parameters: threshold={threshold_method.value}, "
                f"analysis={analysis_dimension.value}, visualizations={return_skeleton_visualizations}")

    # Foreground mask, optionally cleaned of small specks, then thinned
    binary_stack = _segment_axons(image_stack, threshold_method, threshold_value)
    if min_object_size > 0:
        binary_stack = _remove_small_objects(binary_stack, min_object_size)
    skeleton_stack = _skeletonize_3d(binary_stack)

    # Pick the analyzer for the requested dimensionality
    if analysis_dimension == AnalysisDimension.THREE_D:
        analyzer, analysis_type = _analyze_3d_skeleton, "3D volumetric"
    elif analysis_dimension == AnalysisDimension.TWO_D:
        analyzer, analysis_type = _analyze_2d_slices, "2D slice-by-slice"
    else:
        raise ValueError(f"Invalid analysis_dimension: {analysis_dimension}")
    branch_data = analyzer(skeleton_stack, voxel_spacing)

    # Drop short branches (the empty table is left untouched)
    if min_branch_length > 0 and len(branch_data) > 0:
        long_enough = branch_data['branch_distance'] >= min_branch_length
        branch_data = branch_data[long_enough]

    # One 2D visualization per Z-slice, only when requested
    if return_skeleton_visualizations:
        skeleton_visualizations = [
            _create_output_array_2d(image_plane, binary_plane, skeleton_plane, skeleton_visualization_mode)
            for image_plane, binary_plane, skeleton_plane in zip(image_stack, binary_stack, skeleton_stack)
        ]
    else:
        skeleton_visualizations = []

    results = _compile_analysis_results(
        branch_data, skeleton_stack, binary_stack, image_stack,
        voxel_spacing, analysis_type, threshold_method, min_object_size, min_branch_length
    )

    logger.info(f"Analysis complete: {len(branch_data)} branches found")

    # The input stack is passed through unchanged as the main output
    return image_stack, results, skeleton_visualizations
275# Helper functions for segmentation and preprocessing
def _segment_axons(image_stack, threshold_method, threshold_value):
    """
    Turn a grayscale stack into a boolean foreground mask.

    Args:
        image_stack: 3D grayscale stack, indexed as stack[z].
        threshold_method: ThresholdMethod selecting the strategy.
        threshold_value: Cut-off used by ThresholdMethod.MANUAL (the public
            entry point rejects MANUAL without a value).

    Returns:
        Boolean array, True where a pixel is strictly above the threshold.

    Raises:
        ValueError: If threshold_method is not a known ThresholdMethod.
    """
    from skimage import filters

    if threshold_method == ThresholdMethod.OTSU:
        # One Otsu level shared by every slice of the stack
        otsu_level = filters.threshold_otsu(image_stack)
        mask = image_stack > otsu_level
        logger.debug(f"Otsu threshold: {otsu_level}")
        return mask

    if threshold_method == ThresholdMethod.MANUAL:
        mask = image_stack > threshold_value
        logger.debug(f"Manual threshold: {threshold_value}")
        return mask

    if threshold_method == ThresholdMethod.ADAPTIVE:
        mask = np.zeros_like(image_stack, dtype=bool)
        for z, plane in enumerate(image_stack):
            # Planes without any positive intensity stay entirely background
            if plane.max() > 0:
                mask[z] = plane > filters.threshold_local(plane, block_size=51)
        logger.debug("Applied adaptive thresholding slice-by-slice")
        return mask

    raise ValueError(f"Unknown threshold_method: {threshold_method}")
def _remove_small_objects(binary_stack, min_size):
    """
    Remove connected components smaller than ``min_size`` pixels, slice by slice.

    Components are labelled independently within each Z-slice. An object that
    is small in every slice is therefore removed even if it is connected to
    other pixels across slices. 3D connectivity is not considered here.

    Args:
        binary_stack: Foreground mask indexed as stack[z].
        min_size: Minimum component size (pixels per slice) to keep.

    Returns:
        New mask of the same shape and dtype; empty slices stay empty.
    """
    from skimage import morphology

    cleaned_stack = np.zeros_like(binary_stack)
    removed_count = 0

    for z, plane in enumerate(binary_stack):
        if not plane.any():
            continue
        pixels_before = np.sum(plane)
        cleaned_stack[z] = morphology.remove_small_objects(plane, min_size=min_size)
        removed_count += pixels_before - np.sum(cleaned_stack[z])

    logger.debug(f"Removed {removed_count} small object pixels (min_size={min_size})")
    return cleaned_stack
def _skeletonize_3d(binary_stack):
    """
    Thin a foreground mask to a skeleton with skimage ``morphology.skeletonize``.

    The whole array is passed in one call (not slice by slice). A debug message
    reports the foreground and skeleton pixel counts and their ratio.

    Args:
        binary_stack: Foreground mask.

    Returns:
        The skeleton array returned by ``morphology.skeletonize``.
    """
    from skimage import morphology

    skeleton_stack = morphology.skeletonize(binary_stack)

    foreground_count = np.sum(binary_stack)
    skeleton_count = np.sum(skeleton_stack)
    if foreground_count > 0:
        ratio = skeleton_count / foreground_count
    else:
        ratio = 0

    logger.debug(f"Skeletonization: {foreground_count} → {skeleton_count} pixels "
                 f"(reduction: {ratio:.3f})")

    return skeleton_stack
def _create_empty_branch_dataframe(include_2d_columns: bool = False):
    """
    Create an empty DataFrame with the expected skan branch data schema.

    This keeps the DataFrame structure consistent when no branches are found,
    so later column lookups (e.g. filtering on ``branch_distance``) do not
    raise KeyError.

    Args:
        include_2d_columns: If True, also add the per-slice columns
            (``z_slice``, ``z_coord``) that slice-by-slice analysis attaches.
            ``skeleton_id`` is already part of the core schema and is not
            added again, because duplicate labels would make
            ``df['skeleton_id']`` return a DataFrame instead of a Series.

    Returns:
        Empty DataFrame with unique, ordered column labels.
    """
    # Core columns from skan.summarize(separator='_')
    columns = [
        'skeleton_id',
        'node_id_src',
        'node_id_dst',
        'branch_distance',
        'branch_type',
        'mean_pixel_value',
        'stdev_pixel_value',
        'image_coord_src_0',
        'image_coord_src_1',
        'image_coord_src_2',
        'image_coord_dst_0',
        'image_coord_dst_1',
        'image_coord_dst_2',
        'coord_src_0',
        'coord_src_1',
        'coord_src_2',
        'coord_dst_0',
        'coord_dst_1',
        'coord_dst_2',
        'euclidean_distance',
    ]

    if include_2d_columns:
        columns.extend(['z_slice', 'z_coord'])

    return pd.DataFrame(columns=columns)
def _analyze_3d_skeleton(skeleton_stack, voxel_spacing):
    """
    Summarize the whole skeleton stack as a single 3D skan network.

    Building one Skeleton for the entire volume keeps branches that run
    across Z-slices connected.

    Args:
        skeleton_stack: 3D skeleton mask.
        voxel_spacing: Physical spacing passed to skan as ``spacing``.

    Returns:
        The DataFrame from ``skan.summarize``. If the skeleton has no pixels,
        an empty DataFrame with the expected columns is returned instead.

    Raises:
        ImportError: If skan is not installed.
    """
    try:
        from skan import Skeleton, summarize
    except ImportError:
        raise ImportError("skan library is required for skeleton analysis. "
                          "Install with: pip install skan")

    if skeleton_stack.any():
        skeleton_graph = Skeleton(skeleton_stack, spacing=voxel_spacing)
        branch_table = summarize(skeleton_graph, separator='_')
        logger.debug(f"3D analysis: {len(branch_table)} branches found")
        return branch_table

    logger.warning("Empty skeleton - returning empty analysis")
    return _create_empty_branch_dataframe()
def _analyze_2d_slices(skeleton_stack, voxel_spacing):
    """
    Summarize every non-empty Z-slice as its own 2D skan network.

    Each slice is analysed with (y, x) spacing only. Each slice's rows are
    tagged with ``z_slice`` (index) and ``z_coord`` (index * z spacing), and
    ``skeleton_id`` is overwritten with a ``slice_XXX`` label.

    Args:
        skeleton_stack: Skeleton mask indexed as stack[z].
        voxel_spacing: (z, y, x) physical spacing.

    Returns:
        Concatenated branch table for all slices (fresh RangeIndex). If no
        slice yields branches, an empty DataFrame with the 2D schema is
        returned.

    Raises:
        ImportError: If skan is not installed.
    """
    try:
        from skan import Skeleton, summarize
    except ImportError:
        raise ImportError("skan library is required for skeleton analysis. "
                          "Install with: pip install skan")

    z_spacing, y_spacing, x_spacing = voxel_spacing
    per_slice_tables = []

    for z_idx, plane in enumerate(skeleton_stack):
        if not plane.any():
            continue
        table = summarize(Skeleton(plane, spacing=(y_spacing, x_spacing)), separator='_')
        if len(table) == 0:
            continue
        table['z_slice'] = z_idx
        table['z_coord'] = z_idx * z_spacing
        table['skeleton_id'] = f"slice_{z_idx:03d}"
        per_slice_tables.append(table)

    if not per_slice_tables:
        logger.warning("No skeleton data found in any slice")
        return _create_empty_branch_dataframe(include_2d_columns=True)

    combined_data = pd.concat(per_slice_tables, ignore_index=True)
    logger.debug(f"2D analysis: {len(combined_data)} branches across "
                 f"{len(per_slice_tables)} slices")
    return combined_data
def _create_output_array_2d(slice_image, slice_binary, slice_skeleton, output_mode):
    """
    Build one 2D visualization for a single Z-slice.

    Args:
        slice_image: 2D grayscale slice.
        slice_binary: Foreground mask for the slice.
        slice_skeleton: Skeleton mask for the slice (used as a boolean index).
        output_mode: OutputMode selecting the rendering.

    Returns:
        SKELETON: skeleton as uint8 0/255.
        SKELETON_OVERLAY: copy of the slice with skeleton pixels set to the
            slice maximum.
        ORIGINAL: copy of the slice.
        COMPOSITE: (Y, 3*X) array in the slice dtype with original | binary |
            skeleton panels, where the masks are scaled to the slice maximum.

    Raises:
        ValueError: For an unrecognized output_mode.
    """
    if output_mode == OutputMode.SKELETON:
        return slice_skeleton.astype(np.uint8) * 255

    if output_mode == OutputMode.SKELETON_OVERLAY:
        overlay = slice_image.copy()
        if slice_skeleton.any():
            overlay[slice_skeleton] = slice_image.max()
        return overlay

    if output_mode == OutputMode.ORIGINAL:
        return slice_image.copy()

    if output_mode == OutputMode.COMPOSITE:
        height, width = slice_image.shape
        composite = np.zeros((height, width * 3), dtype=slice_image.dtype)
        peak = slice_image.max()
        panels = (
            slice_image,
            (slice_binary.astype(np.float32) * peak).astype(slice_image.dtype),
            (slice_skeleton.astype(np.float32) * peak).astype(slice_image.dtype),
        )
        for panel_idx, panel in enumerate(panels):
            composite[:, panel_idx * width:(panel_idx + 1) * width] = panel
        return composite

    raise ValueError(f"Unknown output_mode: {output_mode}")
def _create_output_array(image_stack, binary_stack, skeleton_stack, branch_data, output_mode):
    """
    Build a whole-stack visualization (legacy, kept for compatibility).

    Same renderings as the per-slice variant, applied to the full 3D stack.
    For overlay and composite, the maximum is taken over the whole stack.
    ``branch_data`` is accepted for interface compatibility and is not used.

    Args:
        image_stack: 3D grayscale stack (Z, Y, X).
        binary_stack: Foreground mask of the same shape.
        skeleton_stack: Skeleton mask of the same shape.
        branch_data: Unused.
        output_mode: OutputMode selecting the rendering.

    Returns:
        The rendered array; COMPOSITE has shape (Z, Y, 3*X).

    Raises:
        ValueError: For an unrecognized output_mode.
    """
    if output_mode == OutputMode.SKELETON:
        return skeleton_stack.astype(np.uint8) * 255

    if output_mode == OutputMode.SKELETON_OVERLAY:
        overlay = image_stack.copy()
        if skeleton_stack.any():
            overlay[skeleton_stack] = image_stack.max()
        return overlay

    if output_mode == OutputMode.ORIGINAL:
        return image_stack.copy()

    if output_mode == OutputMode.COMPOSITE:
        depth, height, width = image_stack.shape
        composite = np.zeros((depth, height, width * 3), dtype=image_stack.dtype)
        peak = image_stack.max()
        panels = (
            image_stack,
            (binary_stack.astype(np.float32) * peak).astype(image_stack.dtype),
            (skeleton_stack.astype(np.float32) * peak).astype(image_stack.dtype),
        )
        for panel_idx, panel in enumerate(panels):
            composite[:, :, panel_idx * width:(panel_idx + 1) * width] = panel
        return composite

    raise ValueError(f"Unknown output_mode: {output_mode}")
def _compile_analysis_results(branch_data, skeleton_stack, binary_stack, image_stack,
                              voxel_spacing, analysis_type, threshold_method,
                              min_object_size, min_branch_length):
    """
    Bundle metrics, per-branch data and run metadata into one results dict.

    Returns:
        Dict with three entries:
        - 'summary': metrics from _compute_summary_metrics, followed by voxel
          counts and the segmentation and skeleton fractions.
        - 'branch_data': column -> list mapping of branch_data, or {} when
          there are no branches.
        - 'metadata': analysis parameters, image shape/dtype/intensity range,
          a local processing timestamp and the skan version.
    """
    summary = dict(_compute_summary_metrics(branch_data, skeleton_stack.shape, voxel_spacing))

    voxel_count = np.prod(image_stack.shape)
    foreground_count = np.sum(binary_stack)
    skeleton_count = np.sum(skeleton_stack)

    summary['total_voxels'] = int(voxel_count)
    summary['segmented_voxels'] = int(foreground_count)
    summary['skeleton_voxels'] = int(skeleton_count)
    summary['segmentation_fraction'] = float(foreground_count / voxel_count)
    if foreground_count > 0:
        summary['skeleton_fraction'] = float(skeleton_count / foreground_count)
    else:
        summary['skeleton_fraction'] = 0.0

    serialized_branches = branch_data.to_dict('list') if len(branch_data) > 0 else {}

    metadata = {
        'analysis_type': analysis_type,
        'voxel_spacing': voxel_spacing,
        'threshold_method': threshold_method.value,
        'min_object_size': min_object_size,
        'min_branch_length': min_branch_length,
        'image_shape': image_stack.shape,
        'image_dtype': str(image_stack.dtype),
        'intensity_range': (float(image_stack.min()), float(image_stack.max())),
        'processing_timestamp': datetime.now().isoformat(),
        'skan_version': _get_skan_version(),
    }

    return {
        'summary': summary,
        'branch_data': serialized_branches,
        'metadata': metadata,
    }
def _compute_summary_metrics(branch_data, skeleton_shape, voxel_spacing):
    """
    Compute network-level summary statistics from a skan branch table.

    Args:
        branch_data: Branch DataFrame as produced by ``skan.summarize``. It may
            carry a ``z_slice`` column from slice-by-slice analysis. Reads
            ``branch_distance``, ``euclidean_distance``, ``branch_type``,
            ``node_id_src`` and ``node_id_dst``.
        skeleton_shape: Shape of the analyzed skeleton array.
        voxel_spacing: Physical size of one voxel along each axis.

    Returns:
        Dict of metrics. When ``branch_data`` is empty, every metric except
        ``total_volume`` is zero.

    Notes:
        skan branch types: 0 = endpoint-endpoint, 1 = junction-endpoint,
        2 = junction-junction, 3 = isolated cycle.
        Endpoints are counted per branch type: two for type 0, one for type 1.
        Junction nodes are taken from both ends of type-2 branches, and from
        the end of a type-1 branch that is shared with another branch
        (endpoints belong to exactly one branch). After length filtering, a
        junction whose other branches were all removed may not be counted.
    """
    from collections import Counter

    total_volume = float(np.prod(skeleton_shape) * np.prod(voxel_spacing))

    if len(branch_data) == 0:
        return {
            'total_axon_length': 0.0,
            'num_branches': 0,
            'num_junction_points': 0,
            'num_endpoints': 0,
            'mean_branch_length': 0.0,
            'max_branch_length': 0.0,
            'mean_tortuosity': 0.0,
            'network_density': 0.0,
            'branching_ratio': 0.0,
            'total_volume': total_volume,
        }

    distances = branch_data['branch_distance']
    euclidean = branch_data['euclidean_distance']
    branch_types = branch_data['branch_type'].to_numpy()
    num_branches = len(branch_data)

    # Tortuosity = path length / end-to-end distance. It is undefined for
    # closed loops (zero end-to-end distance), which would otherwise add
    # ~1e8 outliers that dominate the mean.
    open_branches = euclidean > 0
    if open_branches.any():
        mean_tortuosity = float((distances[open_branches] / euclidean[open_branches]).mean())
    else:
        mean_tortuosity = 0.0

    # Type 0 branches end in two endpoints, type 1 branches in one.
    num_endpoints = 2 * int(np.sum(branch_types == 0)) + int(np.sum(branch_types == 1))

    # Node ids are only unique within one skan Skeleton. Slice-by-slice
    # analysis builds one Skeleton per slice, so nodes are keyed by slice too.
    if 'z_slice' in branch_data.columns:
        slice_keys = branch_data['z_slice'].tolist()
    else:
        slice_keys = [None] * num_branches
    src_nodes = list(zip(slice_keys, branch_data['node_id_src'].tolist()))
    dst_nodes = list(zip(slice_keys, branch_data['node_id_dst'].tolist()))
    node_degree = Counter(src_nodes + dst_nodes)

    junction_nodes = set()
    for branch_type, src, dst in zip(branch_types, src_nodes, dst_nodes):
        if branch_type == 2:
            junction_nodes.update((src, dst))
        elif branch_type == 1:
            # Exactly one end is a junction: the one shared with other branches.
            junction_nodes.update(node for node in (src, dst) if node_degree[node] > 1)
    num_junction_points = len(junction_nodes)

    network_density = num_branches / total_volume if total_volume > 0 else 0.0
    branching_ratio = num_junction_points / num_endpoints if num_endpoints > 0 else 0.0

    return {
        'total_axon_length': float(distances.sum()),
        'num_branches': int(num_branches),
        'num_junction_points': int(num_junction_points),
        'num_endpoints': int(num_endpoints),
        'mean_branch_length': float(distances.mean()),
        'max_branch_length': float(distances.max()),
        'mean_tortuosity': mean_tortuosity,
        'network_density': float(network_density),
        'branching_ratio': float(branching_ratio),
        'total_volume': total_volume,
    }
def _get_skan_version():
    """Return ``skan.__version__``, or "unknown" if skan is missing or has no version attribute."""
    try:
        import skan
    except ImportError:
        return "unknown"
    return getattr(skan, "__version__", "unknown")