Coverage for openhcs/processing/backends/analysis/skan_axon_analysis.py: 11.0%
257 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Skan-based axon skeletonization and analysis for OpenHCS.
4This module provides comprehensive axon analysis using the skan library,
5including segmentation, skeletonization, and quantitative skeleton analysis.
6Supports both 2D and 3D analysis modes with multiple output formats.
7"""
9import numpy as np
10import pandas as pd
11from datetime import datetime
12from enum import Enum
13from pathlib import Path
14from typing import Dict, Any, Tuple, Optional, Union, List
15import logging
17# OpenHCS imports
18from openhcs.core.memory.decorators import numpy as numpy_func
19from openhcs.core.pipeline.function_contracts import special_outputs
21logger = logging.getLogger(__name__)
class ThresholdMethod(Enum):
    """Strategies available for separating axon signal from background."""

    # One global cut-off chosen automatically with Otsu's method.
    OTSU = "otsu"
    # One global cut-off supplied by the caller.
    MANUAL = "manual"
    # Locally varying cut-off computed independently for every slice.
    ADAPTIVE = "adaptive"
class OutputMode(Enum):
    """Layouts supported for the generated visualization arrays."""

    # Skeleton mask only, rendered as 0/255.
    SKELETON = "skeleton"
    # Source image with skeleton pixels raised to the image maximum.
    SKELETON_OVERLAY = "skeleton_overlay"
    # Plain copy of the source image.
    ORIGINAL = "original"
    # Source, segmentation and skeleton placed side by side.
    COMPOSITE = "composite"
class AnalysisDimension(Enum):
    """Whether the skeleton is analysed slice by slice or as one volume."""

    # Every Z-slice is treated as an independent 2D skeleton.
    TWO_D = "2d"
    # The whole stack is treated as a single 3D skeleton.
    THREE_D = "3d"
def materialize_axon_analysis(
    axon_analysis_data: Dict[str, Any],
    path: str,
    filemanager,
    **kwargs
) -> str:
    """
    Materialize axon analysis results to disk using filemanager.

    Output files share a common base name, which is ``path`` with a trailing
    ``.pkl`` extension removed (only the extension is stripped; a ``.pkl``
    fragment elsewhere in the path is left untouched):
    - ``<base>.json``: summary metrics and metadata (primary output)
    - ``<base>_branches.csv``: detailed branch data
    - ``<base>_complete.xlsx``: workbook with branch, summary and metadata
      sheets, written unless ``create_excel=False`` is passed

    Args:
        axon_analysis_data: The axon analysis results dictionary; the keys
            'summary', 'metadata' and 'branch_data' are read
        path: Base path for output files (from special output path)
        filemanager: FileManager instance for consistent I/O; its
            ensure_directory, exists, delete and save methods are used
        **kwargs: Additional materialization options (``create_excel``)

    Returns:
        str: Path to the primary output file (JSON summary)
    """
    import json
    from openhcs.constants.constants import Backend

    logger.info(f"🔬 SKAN_MATERIALIZE: Called with path={path}, data_keys={list(axon_analysis_data.keys()) if axon_analysis_data else 'None'}")

    disk_backend = Backend.DISK.value

    def _overwrite(content, target):
        # Any stale file is removed through the file manager before saving.
        if filemanager.exists(target, disk_backend):
            filemanager.delete(target, disk_backend)
        filemanager.save(content, target, disk_backend)

    # Strip only a trailing ".pkl"; str.replace would also rewrite directory
    # names that happen to contain ".pkl".
    base_path = path[:-len('.pkl')] if path.endswith('.pkl') else path
    json_path = f"{base_path}.json"
    csv_path = f"{base_path}_branches.csv"

    # Ensure output directory exists for disk backend
    filemanager.ensure_directory(str(Path(json_path).parent), disk_backend)

    # 1. Summary and metadata as JSON (primary output).
    # default=str stringifies any value json cannot encode natively.
    summary_data = {
        'analysis_type': 'axon_skeleton_analysis',
        'summary': axon_analysis_data['summary'],
        'metadata': axon_analysis_data['metadata'],
    }
    _overwrite(json.dumps(summary_data, indent=2, default=str), json_path)

    # 2. Detailed branch data as CSV
    branch_df = pd.DataFrame(axon_analysis_data['branch_data'])
    _overwrite(branch_df.to_csv(index=False), csv_path)

    # 3. Optional Excel workbook (written with pandas directly, not the file manager)
    if kwargs.get('create_excel', True):
        excel_path = f"{base_path}_complete.xlsx"
        excel_file = Path(excel_path)
        if excel_file.exists():
            excel_file.unlink()

        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            branch_df.to_excel(writer, sheet_name='Branch_Data', index=False)
            pd.DataFrame([axon_analysis_data['summary']]).to_excel(
                writer, sheet_name='Summary', index=False
            )
            pd.DataFrame([axon_analysis_data['metadata']]).to_excel(
                writer, sheet_name='Metadata', index=False
            )

        logger.info(f"Created Excel file: {excel_path}")

    # 4. Log materialization
    logger.info("Materialized axon analysis:")
    logger.info(f" - Summary: {json_path}")
    logger.info(f" - Branch data: {csv_path}")

    return json_path
def materialize_skeleton_visualizations(data: List[np.ndarray], path: str, filemanager) -> str:
    """
    Materialize skeleton visualizations as individual TIFF files.

    Each array in ``data`` is converted to uint16 and saved as
    ``<base>_slice_NNN.tif``; a text summary is written to
    ``<base>_skeleton_summary.txt``. ``<base>`` is ``path`` with a trailing
    ``.pkl`` extension removed. When ``data`` is empty only the summary file
    is written.

    Args:
        data: Visualization arrays, one per slice
        path: Base path for output files (from special output path)
        filemanager: FileManager instance for consistent I/O; its
            ensure_directory, exists, delete and save methods are used

    Returns:
        str: Path to the summary text file
    """
    from openhcs.constants.constants import Backend

    disk_backend = Backend.DISK.value

    def _overwrite(content, target):
        # Same delete-then-save sequence the axon analysis materializer uses.
        if filemanager.exists(target, disk_backend):
            filemanager.delete(target, disk_backend)
        filemanager.save(content, target, disk_backend)

    # Strip only a trailing ".pkl" so both branches derive the summary path
    # the same way and never end up writing to `path` itself.
    base_path = path[:-len('.pkl')] if path.endswith('.pkl') else path
    summary_path = f"{base_path}_skeleton_summary.txt"

    filemanager.ensure_directory(str(Path(summary_path).parent), disk_backend)

    if not data:
        # Create summary file to indicate no visualizations were generated
        summary_content = "No skeleton visualizations generated (return_skeleton_visualizations=False)\n"
        _overwrite(summary_content, summary_path)
        return summary_path

    # Save each visualization as a separate TIFF file
    for i, visualization in enumerate(data):
        viz_filename = f"{base_path}_slice_{i:03d}.tif"

        if visualization.dtype == np.uint16:
            viz_uint16 = visualization
        elif visualization.max() <= 1.0:
            # Treated as normalized data: stretch to the uint16 range. Scaling
            # happens in float64 because multiplying a small integer array by
            # 65535 overflows (and raises under NumPy 2 promotion rules).
            scaled = visualization.astype(np.float64) * 65535
            viz_uint16 = np.clip(scaled, 0, 65535).astype(np.uint16)
        else:
            # Saturate instead of wrapping around for out-of-range values.
            viz_uint16 = np.clip(visualization, 0, 65535).astype(np.uint16)

        _overwrite(viz_uint16, viz_filename)

    summary_content = "".join([
        f"Skeleton visualizations saved: {len(data)} files\n",
        f"Base filename pattern: {base_path}_slice_XXX.tif\n",
        f"Visualization dtype: {data[0].dtype}\n",
        f"Visualization shape: {data[0].shape}\n",
    ])
    _overwrite(summary_content, summary_path)

    return summary_path
@special_outputs(("axon_analysis", materialize_axon_analysis), ("skeleton_visualizations", materialize_skeleton_visualizations))
@numpy_func
def skan_axon_skeletonize_and_analyze(
    image_stack: np.ndarray,
    voxel_spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0),
    threshold_method: ThresholdMethod = ThresholdMethod.OTSU,
    threshold_value: Optional[float] = None,
    min_object_size: int = 100,
    min_branch_length: float = 0.0,
    return_skeleton_visualizations: bool = False,
    skeleton_visualization_mode: OutputMode = OutputMode.SKELETON_OVERLAY,
    analysis_dimension: AnalysisDimension = AnalysisDimension.THREE_D
) -> Tuple[np.ndarray, Dict[str, Any], List[np.ndarray]]:
    """
    Skeletonize axon images and perform comprehensive skeleton analysis.

    Complete workflow: segmentation → noise removal → skeletonization →
    analysis → optional per-slice visualizations.

    Args:
        image_stack: 3D grayscale image to skeletonize (Z, Y, X format)
        voxel_spacing: Physical voxel spacing (z, y, x) in micrometers
        threshold_method: Segmentation method (OTSU, MANUAL, ADAPTIVE)
        threshold_value: Manual threshold value (if threshold_method=MANUAL)
        min_object_size: Minimum object size for noise removal; values <= 0
            skip the noise-removal step
        min_branch_length: Minimum branch length threshold (micrometers);
            values <= 0 disable the filter
        return_skeleton_visualizations: Whether to generate skeleton visualizations as special output
        skeleton_visualization_mode: Type of visualization (SKELETON, SKELETON_OVERLAY, ORIGINAL, COMPOSITE)
        analysis_dimension: Analysis mode (TWO_D or THREE_D)

    Returns:
        Tuple containing:
        - Original image stack: Input image unchanged (Z, Y, X)
        - Axon analysis results: Complete analysis data structure
        - Skeleton visualizations: (Special output) one array per slice if
          return_skeleton_visualizations=True, otherwise an empty list

    Raises:
        ValueError: If the image is not 3D, if MANUAL thresholding is requested
            without a threshold_value, or if analysis_dimension is not a
            recognised mode.
    """
    # Validate input
    if len(image_stack.shape) != 3:
        raise ValueError(f"Expected 3D image, got {len(image_stack.shape)}D")

    if threshold_method == ThresholdMethod.MANUAL and threshold_value is None:
        raise ValueError("threshold_value required when threshold_method=MANUAL")

    logger.info(f"Starting skan axon analysis: {image_stack.shape} image")
    logger.info(f"Parameters: threshold={threshold_method.value}, "
                f"analysis={analysis_dimension.value}, visualizations={return_skeleton_visualizations}")

    # Step 1: Segmentation/Thresholding
    binary_stack = _segment_axons(image_stack, threshold_method, threshold_value)

    # Step 2: Noise removal
    if min_object_size > 0:
        binary_stack = _remove_small_objects(binary_stack, min_object_size)

    # Step 3: Skeletonization
    skeleton_stack = _skeletonize_3d(binary_stack)

    # Step 4: Skeleton analysis
    if analysis_dimension == AnalysisDimension.THREE_D:
        branch_data = _analyze_3d_skeleton(skeleton_stack, voxel_spacing)
        analysis_type = "3D volumetric"
    elif analysis_dimension == AnalysisDimension.TWO_D:
        branch_data = _analyze_2d_slices(skeleton_stack, voxel_spacing)
        analysis_type = "2D slice-by-slice"
    else:
        raise ValueError(f"Invalid analysis_dimension: {analysis_dimension}")

    # Step 5: Filter results.
    # The analysis helpers return a column-less DataFrame for an empty
    # skeleton, so the column lookup must be skipped in that case to avoid a
    # KeyError on 'branch_distance'.
    if min_branch_length > 0 and len(branch_data) > 0:
        branch_data = branch_data[branch_data['branch_distance'] >= min_branch_length]

    # Step 6: Generate one visualization per Z-slice if requested
    skeleton_visualizations = []
    if return_skeleton_visualizations:
        for z in range(image_stack.shape[0]):
            skeleton_visualizations.append(
                _create_output_array_2d(
                    image_stack[z], binary_stack[z], skeleton_stack[z],
                    skeleton_visualization_mode
                )
            )

    # Step 7: Compile comprehensive results
    results = _compile_analysis_results(
        branch_data, skeleton_stack, binary_stack, image_stack,
        voxel_spacing, analysis_type, threshold_method, min_object_size, min_branch_length
    )

    logger.info(f"Analysis complete: {len(branch_data)} branches found")

    # Always return original image, analysis results, and skeleton visualizations
    return image_stack, results, skeleton_visualizations
272# Helper functions for segmentation and preprocessing
def _segment_axons(image_stack, threshold_method, threshold_value):
    """Turn the grayscale stack into a boolean foreground mask.

    OTSU and MANUAL compare the whole stack against a single cut-off;
    ADAPTIVE thresholds every slice on its own with a local threshold and
    leaves slices whose maximum is not positive as background.

    Raises:
        ValueError: If threshold_method is not a known ThresholdMethod.
    """
    from skimage import filters

    if threshold_method == ThresholdMethod.OTSU:
        cutoff = filters.threshold_otsu(image_stack)
        mask = image_stack > cutoff
        logger.debug(f"Otsu threshold: {cutoff}")
        return mask

    if threshold_method == ThresholdMethod.MANUAL:
        # The caller has already checked that a value was supplied.
        mask = image_stack > threshold_value
        logger.debug(f"Manual threshold: {threshold_value}")
        return mask

    if threshold_method == ThresholdMethod.ADAPTIVE:
        mask = np.zeros_like(image_stack, dtype=bool)
        for z, plane in enumerate(image_stack):
            if plane.max() > 0:
                local_cutoff = filters.threshold_local(plane, block_size=51)
                mask[z] = plane > local_cutoff
        logger.debug("Applied adaptive thresholding slice-by-slice")
        return mask

    raise ValueError(f"Unknown threshold_method: {threshold_method}")
def _remove_small_objects(binary_stack, min_size):
    """Drop connected components smaller than min_size from every slice.

    Filtering is done independently per Z-slice, so component size is
    measured within a single plane rather than across the volume. Slices
    without any foreground are left empty.
    """
    from skimage import morphology

    filtered = np.zeros_like(binary_stack)
    pixels_removed = 0

    for z, plane in enumerate(binary_stack):
        if not plane.any():
            continue
        pixels_before = np.sum(plane)
        filtered[z] = morphology.remove_small_objects(plane, min_size=min_size)
        pixels_removed += pixels_before - np.sum(filtered[z])

    logger.debug(f"Removed {pixels_removed} small object pixels (min_size={min_size})")
    return filtered
def _skeletonize_3d(binary_stack):
    """Create a boolean 3D skeleton mask from a binary image.

    Returns:
        Boolean array with the same shape as binary_stack, True on skeleton
        voxels.
    """
    from skimage import morphology

    # Skeletonize the whole volume at once so connections across slices are kept.
    skeleton_stack = morphology.skeletonize(binary_stack)

    # NOTE(review): some scikit-image releases are understood to return a
    # 0/255 uint8 volume for 3D input instead of a boolean mask. Coercing to
    # bool keeps the pixel counts below, and the mask indexing and intensity
    # scaling done by callers, correct either way; it is a no-op for bool.
    skeleton_stack = np.asarray(skeleton_stack).astype(bool, copy=False)

    # Count skeleton pixels for logging
    skeleton_pixels = np.sum(skeleton_stack)
    binary_pixels = np.sum(binary_stack)
    reduction_ratio = skeleton_pixels / binary_pixels if binary_pixels > 0 else 0

    logger.debug(f"Skeletonization: {binary_pixels} → {skeleton_pixels} pixels "
                 f"(reduction: {reduction_ratio:.3f})")

    return skeleton_stack
def _analyze_3d_skeleton(skeleton_stack, voxel_spacing):
    """Summarize the whole stack as one 3D skeleton with skan.

    Returns the branch table produced by skan's ``summarize`` (column names
    joined with '_'), or an empty DataFrame when the skeleton has no voxels.

    Raises:
        ImportError: If skan is not installed.
    """
    try:
        from skan import Skeleton, summarize
    except ImportError:
        raise ImportError("skan library is required for skeleton analysis. "
                          "Install with: pip install skan")

    if skeleton_stack.any():
        # One pass over the full volume keeps branches that cross slices intact.
        branch_table = summarize(
            Skeleton(skeleton_stack, spacing=voxel_spacing), separator='_'
        )
        logger.debug(f"3D analysis: {len(branch_table)} branches found")
        return branch_table

    logger.warning("Empty skeleton - returning empty analysis")
    return pd.DataFrame()
def _analyze_2d_slices(skeleton_stack, voxel_spacing):
    """Summarize every Z-slice as its own 2D skeleton and stack the results.

    Each per-slice branch table gets 'z_slice', 'z_coord' and 'skeleton_id'
    columns before the tables are concatenated. Returns an empty DataFrame
    when no slice yields any branch.

    Raises:
        ImportError: If skan is not installed.
    """
    try:
        from skan import Skeleton, summarize
    except ImportError:
        raise ImportError("skan library is required for skeleton analysis. "
                          "Install with: pip install skan")

    z_step, y_step, x_step = voxel_spacing
    per_slice_tables = []

    for z_idx, plane in enumerate(skeleton_stack):
        if not plane.any():
            continue

        # In-plane spacing only; the Z position is attached as columns below.
        table = summarize(Skeleton(plane, spacing=(y_step, x_step)), separator='_')
        if len(table) == 0:
            continue

        table['z_slice'] = z_idx
        table['z_coord'] = z_idx * z_step
        table['skeleton_id'] = f"slice_{z_idx:03d}"
        per_slice_tables.append(table)

    if not per_slice_tables:
        logger.warning("No skeleton data found in any slice")
        return pd.DataFrame()

    merged = pd.concat(per_slice_tables, ignore_index=True)
    logger.debug(f"2D analysis: {len(merged)} branches across "
                 f"{len(per_slice_tables)} slices")
    return merged
def _create_output_array_2d(slice_image, slice_binary, slice_skeleton, output_mode):
    """Render one slice in the layout selected by output_mode.

    Raises:
        ValueError: If output_mode is not a known OutputMode.
    """
    if output_mode == OutputMode.SKELETON:
        # Skeleton mask as a 0/255 uint8 image.
        return slice_skeleton.astype(np.uint8) * 255

    if output_mode == OutputMode.SKELETON_OVERLAY:
        # Copy of the source with skeleton pixels set to the source maximum.
        overlay = slice_image.copy()
        if slice_skeleton.any():
            overlay[slice_skeleton] = slice_image.max()
        return overlay

    if output_mode == OutputMode.ORIGINAL:
        return slice_image.copy()

    if output_mode == OutputMode.COMPOSITE:
        # Three panels side by side: original | binary | skeleton.
        height, width = slice_image.shape
        out_dtype = slice_image.dtype
        panels = np.zeros((height, width * 3), dtype=out_dtype)

        panels[:, :width] = slice_image

        # Masks are scaled to the source maximum so all panels share a range.
        panels[:, width:2 * width] = (
            slice_binary.astype(np.float32) * slice_image.max()
        ).astype(out_dtype)
        panels[:, 2 * width:3 * width] = (
            slice_skeleton.astype(np.float32) * slice_image.max()
        ).astype(out_dtype)

        return panels

    raise ValueError(f"Unknown output_mode: {output_mode}")
def _create_output_array(image_stack, binary_stack, skeleton_stack, branch_data, output_mode):
    """Render the whole stack in the layout selected by output_mode.

    Legacy volume-wide counterpart of the per-slice renderer, kept for
    compatibility. ``branch_data`` is accepted but not used.

    Raises:
        ValueError: If output_mode is not a known OutputMode.
    """
    if output_mode == OutputMode.SKELETON:
        # Skeleton mask as a 0/255 uint8 volume.
        return skeleton_stack.astype(np.uint8) * 255

    if output_mode == OutputMode.SKELETON_OVERLAY:
        # Copy of the source with skeleton voxels set to the source maximum.
        overlay = image_stack.copy()
        if skeleton_stack.any():
            overlay[skeleton_stack] = image_stack.max()
        return overlay

    if output_mode == OutputMode.ORIGINAL:
        return image_stack.copy()

    if output_mode == OutputMode.COMPOSITE:
        # Three panels along the last axis: original | binary | skeleton.
        depth, height, width = image_stack.shape
        out_dtype = image_stack.dtype
        panels = np.zeros((depth, height, width * 3), dtype=out_dtype)

        panels[:, :, :width] = image_stack

        # Masks are scaled to the source maximum so all panels share a range.
        panels[:, :, width:2 * width] = (
            binary_stack.astype(np.float32) * image_stack.max()
        ).astype(out_dtype)
        panels[:, :, 2 * width:3 * width] = (
            skeleton_stack.astype(np.float32) * image_stack.max()
        ).astype(out_dtype)

        return panels

    raise ValueError(f"Unknown output_mode: {output_mode}")
def _compile_analysis_results(branch_data, skeleton_stack, binary_stack, image_stack,
                              voxel_spacing, analysis_type, threshold_method,
                              min_object_size, min_branch_length):
    """Assemble the result dictionary with 'summary', 'branch_data' and 'metadata'.

    'summary' merges the branch statistics with voxel-count based
    segmentation metrics; 'branch_data' is the branch table as a dict of
    column lists (empty dict when there are no branches); 'metadata' records
    the parameters and image properties used for the run.
    """
    # Branch-level statistics
    combined_summary = dict(
        _compute_summary_metrics(branch_data, skeleton_stack.shape, voxel_spacing)
    )

    # Voxel counts for the segmentation metrics
    n_total = np.prod(image_stack.shape)
    n_binary = np.sum(binary_stack)
    n_skeleton = np.sum(skeleton_stack)

    if n_binary > 0:
        skeleton_fraction = float(n_skeleton / n_binary)
    else:
        skeleton_fraction = 0.0

    combined_summary.update({
        'total_voxels': int(n_total),
        'segmented_voxels': int(n_binary),
        'skeleton_voxels': int(n_skeleton),
        'segmentation_fraction': float(n_binary / n_total),
        'skeleton_fraction': skeleton_fraction,
    })

    if len(branch_data) > 0:
        branch_columns = branch_data.to_dict('list')
    else:
        branch_columns = {}

    metadata = {
        'analysis_type': analysis_type,
        'voxel_spacing': voxel_spacing,
        'threshold_method': threshold_method.value,
        'min_object_size': min_object_size,
        'min_branch_length': min_branch_length,
        'image_shape': image_stack.shape,
        'image_dtype': str(image_stack.dtype),
        'intensity_range': (float(image_stack.min()), float(image_stack.max())),
        'processing_timestamp': datetime.now().isoformat(),
        'skan_version': _get_skan_version(),
    }

    return {
        'summary': combined_summary,
        'branch_data': branch_columns,
        'metadata': metadata,
    }
522def _compute_summary_metrics(branch_data, skeleton_shape, voxel_spacing):
523 """Compute summary statistics from branch data."""
524 if len(branch_data) == 0:
525 return {
526 'total_axon_length': 0.0,
527 'num_branches': 0,
528 'num_junction_points': 0,
529 'num_endpoints': 0,
530 'mean_branch_length': 0.0,
531 'max_branch_length': 0.0,
532 'mean_tortuosity': 0.0,
533 'network_density': 0.0,
534 'branching_ratio': 0.0,
535 'total_volume': float(np.prod(skeleton_shape) * np.prod(voxel_spacing)),
536 }
538 # Basic metrics
539 total_length = branch_data['branch_distance'].sum()
540 num_branches = len(branch_data)
541 mean_length = branch_data['branch_distance'].mean()
542 max_length = branch_data['branch_distance'].max()
544 # Tortuosity (branch_distance / euclidean_distance)
545 tortuosity = branch_data['branch_distance'] / (branch_data['euclidean_distance'] + 1e-8)
546 mean_tortuosity = tortuosity.mean()
548 # Count junction points and endpoints based on branch types
549 # Branch types: 0=endpoint-endpoint, 1=junction-endpoint, 2=junction-junction, 3=cycle
550 junction_branches = branch_data[branch_data['branch_type'].isin([1, 2])]
551 num_junction_points = len(junction_branches['node_id_src'].unique()) if len(junction_branches) > 0 else 0
553 endpoint_branches = branch_data[branch_data['branch_type'].isin([0, 1])]
554 num_endpoints = len(endpoint_branches) * 2 if len(endpoint_branches) > 0 else 0 # Each branch has 2 endpoints
556 # Volume and density
557 total_volume = float(np.prod(skeleton_shape) * np.prod(voxel_spacing))
558 network_density = num_branches / total_volume if total_volume > 0 else 0.0
560 # Branching ratio
561 branching_ratio = num_junction_points / num_endpoints if num_endpoints > 0 else 0.0
563 return {
564 'total_axon_length': float(total_length),
565 'num_branches': int(num_branches),
566 'num_junction_points': int(num_junction_points),
567 'num_endpoints': int(num_endpoints),
568 'mean_branch_length': float(mean_length),
569 'max_branch_length': float(max_length),
570 'mean_tortuosity': float(mean_tortuosity),
571 'network_density': float(network_density),
572 'branching_ratio': float(branching_ratio),
573 'total_volume': total_volume,
574 }
577def _get_skan_version():
578 """Get skan library version."""
579 try:
580 import skan
581 return skan.__version__
582 except (ImportError, AttributeError):
583 return "unknown"