Coverage for openhcs/processing/backends/analysis/hmm_axon_torbi.py: 5.2%
255 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1# Torbi-accelerated version of hmm_axon.py
2"""
3OpenHCS-compatible neurite tracing using alvahmm RRS algorithm with torbi GPU acceleration.
5This version uses torbi for GPU-accelerated Viterbi decoding while maintaining
6the same API as the original CPU version.
7"""
9import numpy as np
10import networkx as nx
11import skimage
12import math
13from enum import Enum
14from typing import Tuple, Dict, List, Optional, Any
15from skimage.feature import canny, blob_dog as local_max
16from skimage.filters import median, threshold_li
17from skimage.morphology import remove_small_objects, skeletonize
18from openhcs.core.memory.decorators import torch as torch_func
19from openhcs.core.pipeline.function_contracts import special_outputs
21# Import torch using the established optional import pattern
22from openhcs.core.utils import optional_import
23torch = optional_import("torch")
25# Import torbi for GPU-accelerated Viterbi decoding
26torbi = optional_import("torbi")
def materialize_hmm_analysis(
    hmm_analysis_data: Dict[str, Any],
    path: str,
    filemanager,
    **kwargs
) -> str:
    """
    Write HMM neurite tracing results to disk and return the JSON path.

    Up to three files are produced, all derived from ``path`` with any
    ``.pkl`` text removed:
        - ``<base>.json``: summary metrics and metadata (always written)
        - ``<base>_graph.graphml``: the NetworkX graph (non-empty graphs only)
        - ``<base>_edges.csv``: one row per edge (graphs with edges only)

    Args:
        hmm_analysis_data: Results dictionary; the 'summary', 'metadata'
            and 'graph' keys are read
        path: Base path for output files (from special output path)
        filemanager: FileManager instance used to create the directory and
            to save the JSON and CSV content
        **kwargs: Additional materialization options (not used here)

    Returns:
        str: Path to the primary output file (JSON summary)
    """
    import json
    import networkx as nx
    from pathlib import Path
    from openhcs.constants.constants import Backend

    disk_backend = Backend.DISK.value

    # Derive every output location from the same stem
    stem = path.replace('.pkl', '')
    json_path = f"{stem}.json"
    graphml_path = f"{stem}_graph.graphml"
    csv_path = f"{stem}_edges.csv"

    filemanager.ensure_directory(str(Path(json_path).parent), disk_backend)

    # 1. JSON summary (primary output); default=str stringifies anything
    #    json cannot encode natively
    payload = {
        'analysis_type': 'hmm_neurite_tracing_torbi',
        'summary': hmm_analysis_data['summary'],
        'metadata': hmm_analysis_data['metadata']
    }
    filemanager.save(json.dumps(payload, indent=2, default=str), json_path, disk_backend)

    graph = hmm_analysis_data['graph']
    if not graph or graph.number_of_nodes() <= 0:
        # Nothing to export beyond the summary
        return json_path

    # 2. GraphML is written by NetworkX directly to the path, not via filemanager
    nx.write_graphml(graph, graphml_path)

    # 3. Edge table as CSV
    if graph.number_of_edges() > 0:
        import pandas as pd
        rows = [
            {
                'source_x': u[0], 'source_y': u[1],
                'target_x': v[0], 'target_y': v[1],
                **attrs  # Include any edge attributes
            }
            for u, v, attrs in graph.edges(data=True)
        ]
        csv_content = pd.DataFrame(rows).to_csv(index=False)
        filemanager.save(csv_content, csv_path, disk_backend)

    return json_path
def materialize_trace_visualizations(data: List[np.ndarray], path: str, filemanager) -> str:
    """
    Materialize trace visualizations as individual TIFF files.

    Each array in ``data`` is converted to uint16 and saved as
    ``<base>_slice_NNN.tif``; a text summary is written alongside.

    Args:
        data: List of visualization arrays (one per slice); may be empty
        path: Base path for output files; any '.pkl' text is replaced
        filemanager: FileManager instance used for all saves

    Returns:
        str: Path to the text summary file
    """
    # Imported once up front; the trailing summary save previously relied on
    # the binding created inside the loop body.
    from openhcs.constants.constants import Backend

    if not data:
        # Create empty summary file to indicate no visualizations were generated
        summary_path = path.replace('.pkl', '_trace_summary.txt')
        summary_content = "No trace visualizations generated (return_trace_visualizations=False)\n"
        filemanager.save(summary_content, summary_path, Backend.DISK.value)
        return summary_path

    # Generate output file paths based on the input path
    base_path = path.replace('.pkl', '')

    # Save each visualization as a separate TIFF file
    for i, visualization in enumerate(data):
        viz_filename = f"{base_path}_slice_{i:03d}.tif"

        # Convert visualization to uint16 (to match input images)
        if visualization.dtype == np.uint16:
            viz_uint16 = visualization
        elif visualization.max() <= 1.0:
            # Scale in float64: multiplying a narrow integer mask (e.g. the
            # uint8 TRACE_ONLY output) by 65535 directly raises OverflowError
            # under NumPy 2 promotion rules.
            viz_uint16 = (visualization.astype(np.float64) * 65535).astype(np.uint16)
        else:
            viz_uint16 = visualization.astype(np.uint16)

        filemanager.save(viz_uint16, viz_filename, Backend.DISK.value)

    # Summary describes the arrays as received (before uint16 conversion)
    summary_path = f"{base_path}_trace_summary.txt"
    summary_content = f"Trace visualizations saved: {len(data)} files\n"
    summary_content += f"Base filename pattern: {base_path}_slice_XXX.tif\n"
    summary_content += f"Visualization dtype: {data[0].dtype}\n"
    summary_content += f"Visualization shape: {data[0].shape}\n"

    filemanager.save(summary_content, summary_path, Backend.DISK.value)

    return summary_path
144# Import alvahmm - use torbi version from GitHub dependency
145from alva_machinery.markov import aChain_torbi as alva_MCMC_torbi
146from alva_machinery.branching import aWay as alva_branch
class SeedingMethod(Enum):
    """Strategies for choosing the seed points used in neurite tracing."""

    # Paper's original method - random seeds across entire image
    RANDOM = "random"
    # Enhanced method - seeds on detected blob structures
    BLOB_DETECTION = "blob"
    # Alternative - seeds on Canny edge detection
    CANNY_EDGES = "canny"
    # Alternative - seeds on detected growth cones
    GROWTH_CONES = "growth_cones"
class VisualizationMode(Enum):
    """Rendering options for the per-slice trace visualization."""

    # Return zeros array (no visualization)
    NONE = "none"
    # Show only traced neurites (binary mask)
    TRACE_ONLY = "trace"
    # Show original image with traced neurites overlaid
    OVERLAY = "overlay"
class OutputMode(Enum):
    """
    Output visualization modes.

    NOTE(review): nothing in the visible part of this module reads this enum;
    the member descriptions below restate the original inline comments.
    """

    # Binary mask of traced neurites only
    TRACE_ONLY = "trace_only"
    # Original image with traces overlaid
    OVERLAY = "overlay"
    # Return original image unchanged
    NONE = "none"
def normalize(img, percentile=99.9):
    """
    Scale an image by one of its percentiles and clip the result to [0, 100].

    Args:
        img: Input image array
        percentile: Percentile of ``img`` used as the scaling reference
            (default: 99.9)

    Returns:
        New array in which the reference percentile maps to 1.0, with values
        clipped to the range [0, 100]
    """
    percentile_value = np.percentile(img, percentile)

    if percentile_value == 0:
        # Dividing by a zero reference would turn zero pixels into NaN (0/0).
        # Use the limiting values instead: positive pixels saturate at the
        # upper clip bound, everything else becomes 0.
        return np.where(np.asarray(img) > 0, 100.0, 0.0)

    # Upper bound of 100 leaves headroom above the reference percentile
    return np.clip(img / percentile_value, 0, 100)
def boundary_masking_canny(image):
    """
    Build an integer Canny edge mask with a cleared two-pixel frame.

    Args:
        image: Input image passed to the Canny edge detector

    Returns:
        int64 array: nonzero where an edge was detected, with the outermost
        two rows and columns forced to 0
    """
    edges = canny(image)

    # Discard detections on the outermost two rows/columns
    frame = (np.s_[:, :2], np.s_[:, -2:], np.s_[:2, :], np.s_[-2:, :])
    for border in frame:
        edges[border] = False

    return np.array(edges, dtype=np.int64)
def boundary_masking_threshold(image, threshold=threshold_li, min_size=2):
    """
    Build an integer foreground mask by thresholding, with a cleared frame.

    Args:
        image: Input image
        threshold: Callable that takes the image and returns the threshold
            value (default: threshold_li)
        min_size: Accepted for backward compatibility; not used by this
            function

    Returns:
        int64 array: 1 where ``image`` exceeds the threshold, with the
        outermost two rows and columns forced to 0
    """
    threshed = threshold(image)
    bool_image = image > threshed

    # Discard foreground on the outermost two rows/columns
    bool_image[:, :2] = False
    bool_image[:, -2:] = False
    bool_image[:2, :] = False
    bool_image[-2:, :] = False

    # NOTE(review): a skeletonized copy used to be computed here and then
    # discarded; the returned mask was never skeletonized. The dead call is
    # removed so the returned values are unchanged.
    return np.array(bool_image, dtype=np.int64)
def boundary_masking_blob(image, min_sigma=1, max_sigma=2, threshold=0.02):
    """
    Build a mask marking the centers of blobs detected in a median-filtered image.

    Args:
        image: Input image
        min_sigma: Minimum sigma for blob detection (None falls back to 1)
        max_sigma: Maximum sigma for blob detection (None falls back to 2)
        threshold: Detection threshold (None falls back to 0.02)

    Returns:
        Array with the same shape and dtype as ``image``: 1 at detected blob
        centers, 0 elsewhere
    """
    # Callers may pass None explicitly; restore the defaults in that case
    if min_sigma is None:
        min_sigma = 1
    if max_sigma is None:
        max_sigma = 2
    if threshold is None:
        threshold = 0.02

    image_median = median(image)
    galaxy = local_max(image_median, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold)

    # The first two columns of each detection are used as (row, col)
    yy = galaxy[:, 0].astype(np.int64)
    xx = galaxy[:, 1].astype(np.int64)

    # zeros_like rather than `image * 0`: NaN/Inf pixels multiplied by zero
    # stay NaN and would leak into the mask
    boundary_mask = np.zeros_like(image)
    boundary_mask[yy, xx] = 1
    return boundary_mask
def random_seed_by_edge_map(edge_map):
    """
    Draw seed coordinates from the nonzero pixels of an edge/blob map.

    As many seeds as there are nonzero pixels are drawn with
    ``np.random.choice``, so a pixel may be picked more than once. If the map
    has no nonzero pixel, 100 random seeds are generated instead.

    Args:
        edge_map: 2D array whose nonzero entries mark candidate seed pixels

    Returns:
        seed_xx, seed_yy: Arrays of seed column and row coordinates
    """
    rows, cols = edge_map.nonzero()
    candidate_count = len(cols)

    if candidate_count == 0:
        # No edges detected, fall back to random seeding
        return generate_random_seeds(edge_map.shape, num_seeds=100)

    picks = np.random.choice(candidate_count, candidate_count)
    return cols[picks], rows[picks]
def generate_random_seeds(image_shape: Tuple[int, int], num_seeds: int = 100):
    """
    Draw uniformly random seed coordinates over the whole image (paper's original method).

    Args:
        image_shape: (height, width) of the image
        num_seeds: Number of seeds to draw

    Returns:
        seed_xx, seed_yy: Integer arrays of length ``num_seeds`` with
        x in [0, width) and y in [0, height)
    """
    n_rows, n_cols = image_shape
    # x is drawn before y; keep this order so the random stream is consumed identically
    xs = np.random.randint(0, n_cols, num_seeds)
    ys = np.random.randint(0, n_rows, num_seeds)
    return xs, ys
def generate_seeds_by_method(
    image: np.ndarray,
    method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Dispatch seed generation to the strategy selected by ``method``.

    Args:
        image: Input image for seed generation
        method: Seeding method to use
        num_seeds: Number of seeds (used by the random method only)
        min_sigma: Min sigma (used by blob detection only)
        max_sigma: Max sigma (used by blob detection only)
        threshold: Threshold (used by blob detection only)

    Returns:
        seed_xx, seed_yy: Arrays of seed coordinates

    Raises:
        ValueError: If ``method`` is not a known seeding method
    """
    # Paper's original method - pure random seeding
    if method == SeedingMethod.RANDOM:
        return generate_random_seeds(image.shape, num_seeds)

    # Enhanced method - seeds on detected blobs
    if method == SeedingMethod.BLOB_DETECTION:
        blob_map = boundary_masking_blob(image, min_sigma, max_sigma, threshold)
        return random_seed_by_edge_map(blob_map)

    # Alternative - seeds on Canny edges
    if method == SeedingMethod.CANNY_EDGES:
        canny_map = boundary_masking_canny(image)
        return random_seed_by_edge_map(canny_map)

    # Alternative - seeds on growth cones
    if method == SeedingMethod.GROWTH_CONES:
        return get_growth_cone_positions(image)

    raise ValueError(f"Unknown seeding method: {method}")
def get_growth_cone_positions(image):
    """
    Detect growth cone positions using morphological operations (OpenHCS-compatible).

    The image is Otsu-thresholded, closed with a radius-3 disk, labeled, and
    the centroid of every labeled region is returned as a seed.

    Args:
        image: Input image for growth cone detection

    Returns:
        seed_xx, seed_yy: int64 arrays of growth cone center coordinates
        (empty arrays when no region is found)
    """
    # Threshold the image to create a binary mask
    mask = image > skimage.filters.threshold_otsu(image)

    # Use morphological closing to fill in small gaps in the mask
    mask = skimage.morphology.closing(mask, skimage.morphology.disk(3))
    labeled = skimage.measure.label(mask)
    props = skimage.measure.regionprops(labeled)

    # reshape(-1, 2) keeps the column indexing valid when no region is found
    centroids = np.array([prop.centroid for prop in props], dtype=np.float64).reshape(-1, 2)

    # Centroids are fractional; round to whole pixels so the seeds are integer
    # coordinates like those returned by the other seed generators in this
    # module.
    pixels = np.rint(centroids).astype(np.int64)

    # centroid is (row, col) -> return x (col) first, then y (row)
    return pixels[:, 1], pixels[:, 0]
def selected_seeding(image, seed_xx, seed_yy, chain_level=1.05, total_node=8, node_r=None, line_length_min=32):
    """
    Original CPU version for reference.

    Args:
        image: 2D image to trace (copied before being handed to the HMM)
        seed_xx: Seed x coordinates
        seed_yy: Seed y coordinates
        chain_level: Value forwarded to ``pair_HMM_chain``
        total_node: Value forwarded to ``AlvaHmm``
        node_r: Value forwarded to ``AlvaHmm``
        line_length_min: Value forwarded to ``connect_way``

    Returns:
        Whatever ``alva_branch.connect_way`` returns for the traced chain image
    """
    # `alva_MCMC` is not bound at module level in this file (only the torbi
    # variant is imported), so calling this function raised NameError.
    # NOTE(review): assumes the CPU implementation is
    # alva_machinery.markov.aChain - confirm against the installed package.
    from alva_machinery.markov import aChain as alva_MCMC

    im_copy = np.copy(image)
    alva_HMM = alva_MCMC.AlvaHmm(
        im_copy,
        total_node=total_node,
        total_path=None,
        node_r=node_r,
        node_angle_max=None,
    )
    # The paired seed coordinates returned here are not needed afterwards
    chain_HMM_1st, pair_chain_HMM, _pair_seed_xx, _pair_seed_yy = alva_HMM.pair_HMM_chain(
        seed_xx=seed_xx,
        seed_yy=seed_yy,
        chain_level=chain_level,
    )

    chain_im_fine = alva_HMM.chain_image(chain_HMM_1st, pair_chain_HMM)
    return alva_branch.connect_way(
        chain_im_fine,
        line_length_min=line_length_min,
        free_zone_from_y0=None,
    )
def selected_seeding_torbi(image, seed_xx, seed_yy, chain_level=1.05, total_node=8, node_r=None, line_length_min=32, device=None):
    """
    Torbi-accelerated version of selected_seeding with batched processing.

    Args:
        image: 2D image to trace (copied before being handed to the HMM)
        seed_xx: Seed x coordinates
        seed_yy: Seed y coordinates
        chain_level: Value forwarded to ``pair_HMM_chain_batched``
        total_node: Value forwarded to ``AlvaHmmTorbi``
        node_r: Value forwarded to ``AlvaHmmTorbi``
        line_length_min: Value forwarded to ``connect_way``
        device: Device forwarded to ``AlvaHmmTorbi``

    Returns:
        Whatever ``alva_branch.connect_way`` returns for the traced chain image
    """
    print(f"🚀 Processing {len(seed_xx)} seeds in parallel with torbi GPU acceleration")

    im_copy = np.copy(image)

    # Use torbi-accelerated HMM class with batched processing
    alva_HMM = alva_MCMC_torbi.AlvaHmmTorbi(
        im_copy,
        total_node=total_node,
        total_path=None,
        node_r=node_r,
        node_angle_max=None,
        device=device
    )

    # Batched bidirectional HMM tracing; the paired seed coordinates returned
    # here are not needed afterwards
    chain_HMM_1st, pair_chain_HMM, _pair_seed_xx, _pair_seed_yy = alva_HMM.pair_HMM_chain_batched(
        seed_xx=seed_xx,
        seed_yy=seed_yy,
        chain_level=chain_level
    )

    # A loop that only unpacked both chains into unused locals was removed
    # here; it had no effect on the result.
    chain_im_fine = alva_HMM.chain_image(chain_HMM_1st, pair_chain_HMM)
    return alva_branch.connect_way(
        chain_im_fine,
        line_length_min=line_length_min,
        free_zone_from_y0=None
    )
def euclidian_distance(x1, y1, x2, y2):
    """Return the straight-line distance between points (x1, y1) and (x2, y2)."""
    dx = x1 - x2
    dy = y1 - y2
    return math.sqrt(dx ** 2 + dy ** 2)
def extract_graph(root_tree_xx, root_tree_yy):
    """
    Convert traced paths into an undirected NetworkX graph.

    Every (x, y) point of every path becomes a node; consecutive points of a
    path are joined by an edge whose 'weight' is their Euclidean distance.

    Args:
        root_tree_xx: Sequence of paths, each a sequence of x coordinates
        root_tree_yy: Sequence of paths, each a sequence of y coordinates

    Returns:
        nx.Graph with (x, y) tuples as nodes
    """
    traced = nx.Graph()
    for xs, ys in zip(root_tree_xx, root_tree_yy):
        # Register all points first so single-point paths still yield a node
        traced.add_nodes_from(zip(xs, ys))

        # Link each point to its successor along the path
        for nxt in range(1, len(xs)):
            cur = nxt - 1
            start = (xs[cur], ys[cur])
            end = (xs[nxt], ys[nxt])
            step = euclidian_distance(start[0], start[1], end[0], end[1])
            traced.add_edge(start, end, weight=step)
    return traced
def graph_to_length(graph):
    """
    Sum the 'weight' attribute over all edges of a graph.

    Args:
        graph: Graph whose ``edges(data=True)`` yields (u, v, attrs) triples
            with a 'weight' entry in ``attrs``

    Returns:
        Total of all edge weights (0 for a graph without edges)
    """
    length = 0
    for _u, _v, attrs in graph.edges(data=True):
        # Plain left-to-right accumulation, one edge at a time
        length = length + attrs['weight']
    return length
def create_overlay_from_graph(original_image: np.ndarray, graph: nx.Graph) -> np.ndarray:
    """
    Paint the endpoints of every graph edge onto a copy of the image.

    Args:
        original_image: Original 2D input image (numpy array)
        graph: NetworkX graph containing trace coordinates as (x, y) tuples

    Returns:
        Copy of the image in which each in-bounds edge endpoint is set to the
        image maximum (or 1 for an empty image)
    """
    canvas = original_image.copy()
    # Trace pixels get the brightest value present so they stand out
    highlight = np.max(original_image) if original_image.size > 0 else 1

    for edge in graph.edges:
        # Each edge is a pair of nodes stored as (x, y) tuples
        for px, py in edge:
            # Skip endpoints that fall outside the image
            if 0 <= py < original_image.shape[0] and 0 <= px < original_image.shape[1]:
                canvas[py, px] = highlight

    return canvas
def create_visualization_array(
    original_image: np.ndarray,
    graph: nx.Graph,
    mode: VisualizationMode
) -> np.ndarray:
    """
    Render a traced graph according to the requested visualization mode.

    Args:
        original_image: Original 2D image
        graph: NetworkX graph with traced neurites, nodes as (x, y) tuples
        mode: Visualization mode

    Returns:
        2D array: zeros with the image's dtype (NONE), a uint8 mask of edge
        endpoints (TRACE_ONLY), or the image with endpoints highlighted
        (OVERLAY)

    Raises:
        ValueError: If ``mode`` is not a known visualization mode
    """
    if mode == VisualizationMode.NONE:
        # Blank array, same shape and dtype as the input
        return np.zeros_like(original_image, dtype=original_image.dtype)

    if mode == VisualizationMode.OVERLAY:
        # Use shared overlay function
        return create_overlay_from_graph(original_image, graph)

    if mode == VisualizationMode.TRACE_ONLY:
        mask = np.zeros_like(original_image, dtype=np.uint8)
        for edge in graph.edges:
            # Each edge is a pair of nodes stored as (x, y) tuples
            for px, py in edge:
                # Skip endpoints that fall outside the image
                if 0 <= py < original_image.shape[0] and 0 <= px < original_image.shape[1]:
                    mask[py, px] = 1
        return mask

    raise ValueError(f"Unknown visualization mode: {mode}")
@special_outputs(("hmm_analysis", materialize_hmm_analysis), ("trace_visualizations", materialize_trace_visualizations))
@torch_func
def trace_neurites_rrs_alva_torbi(
    image_stack: torch.Tensor,
    seeding_method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    return_trace_visualizations: bool = False,
    trace_visualization_mode: VisualizationMode = VisualizationMode.TRACE_ONLY,
    chain_level: float = 1.05,
    node_r: Optional[int] = None,
    total_node: Optional[int] = None,
    line_length_min: int = 32,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02,
    normalize_image: bool = False,
    percentile: float = 99.9
) -> Tuple[torch.Tensor, Dict[str, Any], List[np.ndarray]]:
    """
    Trace neurites using the alvahmm RRS algorithm with torbi GPU acceleration.

    Each Z slice is converted to a float64 NumPy array, seeded, traced with
    ``selected_seeding_torbi`` on the input tensor's device, and converted to
    a graph. This function itself contains no CPU fallback.

    Args:
        image_stack: 3D torch tensor of shape (Z, Y, X) - input image stack
        seeding_method: Method for seed generation (RANDOM=paper default, BLOB_DETECTION=enhanced)
        return_trace_visualizations: Whether to generate trace visualizations as special output
        trace_visualization_mode: How to visualize results (NONE, TRACE_ONLY, OVERLAY)
        chain_level: Validation threshold for HMM chains (default: 1.05)
        node_r: Path length between adjacent nodes (default: None, forwarded as-is)
        total_node: Number of HMM nodes in chain (default: None, forwarded as-is)
        line_length_min: Minimum line length for connection (default: 32)
        num_seeds: Number of seeds for random seeding method (default: 100)
        min_sigma: Minimum sigma for blob detection (default: 1.0)
        max_sigma: Maximum sigma for blob detection (default: 2.0)
        threshold: Threshold for blob detection (default: 0.02)
        normalize_image: Whether to apply percentile normalization (default: False, paper doesn't use)
        percentile: Percentile for normalization if enabled (default: 99.9)

    Returns:
        result_image: Original image stack unchanged (Z, Y, X)
        analysis_results: HMM analysis data structure; its graph and summary
            metrics describe the FIRST slice only
        trace_visualizations: (Special output) One array per slice if
            return_trace_visualizations=True, otherwise an empty list

    Raises:
        ValueError: If ``image_stack`` is not 3-dimensional
    """
    # Validate input is 3D
    if image_stack.ndim != 3:
        raise ValueError(f"Expected 3D tensor, got {image_stack.ndim}D")

    # The HMM runs on whatever device the input lives on
    device = image_stack.device

    num_slices = image_stack.shape[0]
    all_graphs = []
    trace_visualizations = []

    for z in range(num_slices):
        # detach() first: Tensor.numpy() refuses tensors that require grad
        im_axon = image_stack[z].detach().cpu().numpy().astype(np.float64)

        # Optional normalization (removed from default, paper doesn't use)
        if normalize_image:
            im_axon = normalize(im_axon, percentile=percentile)

        # Generate seeds using selected method
        seed_xx, seed_yy = generate_seeds_by_method(
            im_axon,
            method=seeding_method,
            num_seeds=num_seeds,
            min_sigma=min_sigma,
            max_sigma=max_sigma,
            threshold=threshold
        )

        # RRS tracing with bidirectional HMM chains (torbi-accelerated);
        # the tip coordinates are not used here
        root_tree_yy, root_tree_xx, _root_tip_yy, _root_tip_xx = selected_seeding_torbi(
            im_axon,
            seed_xx,
            seed_yy,
            chain_level=chain_level,
            node_r=node_r,
            total_node=total_node,
            line_length_min=line_length_min,
            device=device
        )

        # Extract graph representation for this slice
        graph = extract_graph(root_tree_xx, root_tree_yy)
        all_graphs.append(graph)

        # Create visualization for this slice if requested
        if return_trace_visualizations:
            visualization = create_visualization_array(im_axon, graph, trace_visualization_mode)
            trace_visualizations.append(visualization)

    # Only the first slice's graph is reported (kept for compatibility)
    combined_graph = all_graphs[0] if all_graphs else nx.Graph()

    # Compile analysis results
    analysis_results = _compile_hmm_analysis_results(
        combined_graph, all_graphs, image_stack.shape,
        seeding_method, trace_visualization_mode, chain_level,
        node_r, total_node, line_length_min
    )

    # Always return original image, analysis results, and trace visualizations
    return image_stack, analysis_results, trace_visualizations
def _compile_hmm_analysis_results(
    combined_graph: nx.Graph,
    all_graphs: List[nx.Graph],
    image_shape: Tuple[int, int, int],
    seeding_method,  # SeedingMethod enum
    visualization_mode,  # VisualizationMode enum
    chain_level: float,
    node_r: Optional[int],
    total_node: Optional[int],
    line_length_min: int
) -> Dict[str, Any]:
    """
    Assemble the summary/graph/metadata dictionary for the torbi version.

    Args:
        combined_graph: Graph whose nodes are (x, y) tuples; all summary
            metrics are computed from it
        all_graphs: Per-slice graphs; only their count is reported
        image_shape: Shape recorded in the metadata
        seeding_method: SeedingMethod enum; its ``.value`` is recorded
        visualization_mode: VisualizationMode enum; its ``.value`` is recorded
        chain_level: Recorded in the metadata
        node_r: Recorded in the metadata
        total_node: Recorded in the metadata
        line_length_min: Recorded in the metadata

    Returns:
        Dict with keys 'summary', 'graph' and 'metadata'
    """
    from datetime import datetime

    node_count = combined_graph.number_of_nodes()
    edge_count = combined_graph.number_of_edges()

    # Edge lengths are recomputed from the node coordinates rather than read
    # from the stored edge attributes
    edge_lengths = []
    total_length = 0.0
    for (x1, y1), (x2, y2), _attrs in combined_graph.edges(data=True):
        length = ((x2 - x1)**2 + (y2 - y1)**2)**0.5
        edge_lengths.append(length)
        total_length += length

    if edge_lengths:
        mean_edge_length = float(sum(edge_lengths) / len(edge_lengths))
        max_edge_length = float(max(edge_lengths))
    else:
        mean_edge_length = 0.0
        max_edge_length = 0.0

    # Density is only computed for graphs with more than one node
    density = float(nx.density(combined_graph)) if node_count > 1 else 0.0

    summary = {
        'total_trace_length': float(total_length),
        'num_nodes': int(node_count),
        'num_edges': int(edge_count),
        'num_slices_processed': len(all_graphs),
        'mean_edge_length': mean_edge_length,
        'max_edge_length': max_edge_length,
        'graph_density': density,
        'num_connected_components': int(nx.number_connected_components(combined_graph)),
    }

    metadata = {
        'algorithm': 'alvahmm_rrs_torbi',
        'seeding_method': seeding_method.value,
        'visualization_mode': visualization_mode.value,
        'chain_level': chain_level,
        'node_r': node_r,
        'total_node': total_node,
        'line_length_min': line_length_min,
        'image_shape': image_shape,
        'processing_timestamp': datetime.now().isoformat(),
        'gpu_accelerated': True,
    }

    return {
        'summary': summary,
        'graph': combined_graph,
        'metadata': metadata
    }
620# Legacy file-based processing function (kept for reference)
def process_file_legacy(filename, input_folder, output_folder, **kwargs):
    """
    Placeholder for the legacy file-based entry point.

    Always raises; none of the arguments are used. The error message points
    callers at the array-in/array-out tracing function instead.

    Raises:
        NotImplementedError: On every call
    """
    message = (
        "Legacy file-based processing not supported in OpenHCS. "
        "Use trace_neurites_rrs_alva() for array-in/array-out processing."
    )
    raise NotImplementedError(message)