Coverage for openhcs/processing/backends/analysis/hmm_axon.py: 3.8%
241 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1# BACKUP: hmm_axon_backup.py created before OpenHCS conversion
2"""
3OpenHCS-compatible neurite tracing using alvahmm RRS algorithm.
5Converted from file-based processing to pure array-in/array-out functions
6following OpenHCS patterns.
7"""
9import numpy as np
10import networkx as nx
11import skimage
12import math
13from enum import Enum
14from typing import Tuple, Dict, List, Optional, Any
15from skimage.feature import canny, blob_dog as local_max
16from skimage.filters import median, threshold_li
17from skimage.morphology import skeletonize
18from openhcs.core.memory.decorators import numpy
19from openhcs.core.pipeline.function_contracts import special_outputs
21# Import alvahmm from GitHub dependency
22from alva_machinery.markov import aChain as alva_MCMC
23from alva_machinery.branching import aWay as alva_branch
def materialize_hmm_analysis(
    hmm_analysis_data: Dict[str, Any],
    path: str,
    filemanager,
    backend: str,
    **kwargs
) -> str:
    """
    Materialize HMM neurite tracing analysis results.

    Writes up to three files derived from ``path`` (``<base>`` is ``path`` with a
    trailing ``.pkl`` extension removed):

    - ``<base>.json``: analysis type, summary metrics and metadata (always written)
    - ``<base>_graph.graphml``: the NetworkX graph (disk backend only, non-empty graph)
    - ``<base>_edges.csv``: one row per edge with endpoint coordinates and all
      edge attributes (non-empty graph with at least one edge)

    Args:
        hmm_analysis_data: Result dictionary with ``'summary'`` and ``'metadata'``
            entries and an optional ``'graph'`` entry (a NetworkX graph whose
            nodes are indexable coordinate pairs).
        path: Special-output path; used only to derive the output file names.
        filemanager: FileManager providing ``ensure_directory`` and ``save``.
        backend: Backend name forwarded to the filemanager.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        str: Path of the JSON summary file (the primary output).

    Raises:
        KeyError: if ``'summary'`` or ``'metadata'`` is missing.
    """
    import json
    import networkx as nx
    from pathlib import Path
    from openhcs.constants.constants import Backend

    # Strip only a trailing ".pkl": str.replace would also rewrite a ".pkl"
    # occurring elsewhere in the path (e.g. inside a directory name).
    base_path = path[:-len('.pkl')] if path.endswith('.pkl') else path
    json_path = f"{base_path}.json"
    graphml_path = f"{base_path}_graph.graphml"
    csv_path = f"{base_path}_edges.csv"

    is_disk = backend == Backend.DISK.value
    if is_disk:
        filemanager.ensure_directory(str(Path(json_path).parent), backend)

    # 1. Summary and metadata as JSON (primary output). default=str turns any
    #    value json cannot encode natively into its string form.
    summary_data = {
        'analysis_type': 'hmm_neurite_tracing',
        'summary': hmm_analysis_data['summary'],
        'metadata': hmm_analysis_data['metadata'],
    }
    filemanager.save(json.dumps(summary_data, indent=2, default=str), json_path, backend)

    graph = hmm_analysis_data.get('graph')
    if graph is None or graph.number_of_nodes() == 0:
        return json_path

    # 2. GraphML: nx.write_graphml writes straight to a filesystem path, so it
    #    is only produced for the disk backend and skipped for other backends.
    if is_disk:
        nx.write_graphml(graph, graphml_path)

    # 3. Edge table as CSV: endpoint coordinates plus every edge attribute.
    if graph.number_of_edges() > 0:
        import pandas as pd
        edge_rows = [
            {
                'source_x': u[0], 'source_y': u[1],
                'target_x': v[0], 'target_y': v[1],
                **attrs,
            }
            for u, v, attrs in graph.edges(data=True)
        ]
        filemanager.save(pd.DataFrame(edge_rows).to_csv(index=False), csv_path, backend)

    return json_path
def materialize_trace_visualizations(data: List[np.ndarray], path: str, filemanager, backend: str) -> str:
    """Materialize trace visualizations as individual TIFF files.

    Each array in ``data`` is saved as ``<base>_slice_NNN.tif`` (NNN is the
    zero-padded list index), and a short text summary is saved as
    ``<base>_trace_summary.txt``. ``<base>`` is ``path`` with a trailing
    ``.pkl`` extension removed.

    Arrays are converted to uint16 before saving:

    - uint16 arrays are saved unchanged;
    - other arrays whose maximum is <= 1.0 are treated as [0, 1]-scaled and
      multiplied by 65535;
    - all remaining arrays are cast directly.

    The scaling is done in float64 and the result is clipped to [0, 65535].
    Small integer dtypes (e.g. a uint8 trace mask) therefore cannot overflow,
    and out-of-range values saturate instead of wrapping around.

    Args:
        data: Visualization arrays, one per slice; may be empty.
        path: Special-output path used to derive output file names.
        filemanager: FileManager instance providing ``save``.
        backend: Backend used for the summary file when ``data`` is empty.

    Returns:
        str: Path of the written summary text file.
    """
    base_path = path[:-len('.pkl')] if path.endswith('.pkl') else path
    summary_path = f"{base_path}_trace_summary.txt"

    if not data:
        # Write a marker summary so downstream steps can see that nothing was produced.
        summary_content = "No trace visualizations generated (return_trace_visualizations=False)\n"
        filemanager.save(summary_content, summary_path, backend)
        return summary_path

    from openhcs.constants.constants import Backend
    # NOTE(review): TIFFs and the summary are always saved to the disk backend
    # here, while the empty case above honours ``backend``; kept as-is.
    disk_backend = Backend.DISK.value

    for i, visualization in enumerate(data):
        viz_filename = f"{base_path}_slice_{i:03d}.tif"
        if visualization.dtype == np.uint16:
            viz_uint16 = visualization
        else:
            as_float = np.asarray(visualization, dtype=np.float64)
            if as_float.max() <= 1.0:
                as_float = as_float * 65535
            viz_uint16 = np.clip(as_float, 0, 65535).astype(np.uint16)
        filemanager.save(viz_uint16, viz_filename, disk_backend)

    summary_content = f"Trace visualizations saved: {len(data)} files\n"
    summary_content += f"Base filename pattern: {base_path}_slice_XXX.tif\n"
    summary_content += f"Visualization dtype: {data[0].dtype}\n"
    summary_content += f"Visualization shape: {data[0].shape}\n"

    filemanager.save(summary_content, summary_path, disk_backend)

    return summary_path
class SeedingMethod(Enum):
    """How seed points for HMM neurite tracing are chosen."""

    # Uniformly random seeds over the whole image (the paper's original method).
    RANDOM = "random"
    # Seeds sampled from blobs found by difference-of-Gaussian detection.
    BLOB_DETECTION = "blob"
    # Seeds sampled from Canny edge pixels.
    CANNY_EDGES = "canny"
    # Seeds placed at centroids of detected growth-cone regions.
    GROWTH_CONES = "growth_cones"
class VisualizationMode(Enum):
    """Which per-slice array is produced as a trace visualization."""

    # All-zero array with the input's shape and dtype.
    NONE = "none"
    # uint8 mask: 1 on traced pixels, 0 elsewhere.
    TRACE_ONLY = "trace"
    # Copy of the input slice with traced pixels set to its maximum value.
    OVERLAY = "overlay"
class OutputMode(Enum):
    """Output visualization modes.

    NOTE(review): not referenced by the code visible in this module;
    ``VisualizationMode`` is what the tracing function uses. Confirm whether
    external callers still depend on this enum.
    """

    # Binary mask of traced neurites only.
    TRACE_ONLY = "trace_only"
    # Original image with traces overlaid.
    OVERLAY = "overlay"
    # Original image returned unchanged.
    NONE = "none"
def normalize(img, percentile=99.9, clip_max=100):
    """
    Scale an image by its ``percentile``-th percentile and clip to ``[0, clip_max]``.

    After scaling, the chosen percentile maps to 1.0. Brighter pixels map above
    1.0 and are clipped at ``clip_max`` (default 100).

    If the percentile value is not positive (e.g. a mostly-black image whose
    99.9th percentile is 0), dividing by it would produce inf/NaN. In that case
    the image maximum is used as the scale instead. If the maximum is also not
    positive, every pixel would be clipped to 0 anyway, so an all-zero float64
    array is returned.

    Args:
        img: Image array.
        percentile: Percentile (0-100) used as the scale reference.
        clip_max: Upper clip bound applied after scaling.

    Returns:
        Scaled and clipped array with the same shape as ``img``.
    """
    scale = np.percentile(img, percentile)
    if not scale > 0:  # also catches NaN
        scale = np.max(img)
    if not scale > 0:
        return np.zeros(np.shape(img), dtype=np.float64)
    return np.clip(img / scale, 0, clip_max)
def boundary_masking_canny(image):
    """Return the Canny edge map of ``image`` as an int64 0/1 array.

    The two outermost rows and columns on every side are forced to 0, so no
    seed can be drawn right at the image border.
    """
    margin = 2
    edge_mask = canny(image)
    edge_mask[:margin, :] = False
    edge_mask[-margin:, :] = False
    edge_mask[:, :margin] = False
    edge_mask[:, -margin:] = False
    return edge_mask.astype(np.int64)
def boundary_masking_threshold(image, threshold=threshold_li, min_size=2):
    """
    Return a binary foreground mask of ``image`` with a 2-pixel border cleared.

    Args:
        image: 2D image.
        threshold: Callable mapping the image to a scalar threshold
            (default: skimage's ``threshold_li``).
        min_size: Unused; kept so the signature stays backward compatible.

    Returns:
        int64 array with 1 where ``image > threshold(image)`` and 0 elsewhere.
        The two outermost rows and columns are always 0.
    """
    # The comparison creates a fresh bool array, so clearing the border below
    # does not touch the caller's image.
    foreground = image > threshold(image)
    foreground[:, :2] = False
    foreground[:, -2:] = False
    foreground[:2, :] = False
    foreground[-2:, :] = False
    return np.array(foreground, dtype=np.int64)
def boundary_masking_blob(image, min_sigma=1, max_sigma=2, threshold=0.02):
    """Mark blob centres detected in a median-filtered copy of ``image``.

    Blobs are found with ``local_max`` (skimage's ``blob_dog``, imported under
    that alias). Passing ``None`` for any detector parameter selects its
    default (1, 2 and 0.02 respectively).

    Returns:
        Array shaped like ``image`` (same dtype) with 1 at each detected blob
        centre and 0 elsewhere.
    """
    min_sigma = 1 if min_sigma is None else min_sigma
    max_sigma = 2 if max_sigma is None else max_sigma
    threshold = 0.02 if threshold is None else threshold

    smoothed = median(image)
    blobs = local_max(smoothed, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold)
    # Each blob row starts with (row, col) coordinates.
    rows = np.int64(blobs[:, 0])
    cols = np.int64(blobs[:, 1])

    blob_mask = np.copy(image) * 0
    blob_mask[rows, cols] = 1
    return blob_mask
def random_seed_by_edge_map(edge_map):
    """Generate random seeds from detected edge/blob locations.

    Draws as many samples as there are non-zero pixels, uniformly *with
    replacement*, using NumPy's global random state. If ``edge_map`` has no
    non-zero pixels, it falls back to 100 uniformly random seeds over the
    map's shape.

    Returns:
        (seed_xx, seed_yy): column and row coordinates of the seeds.
    """
    rows, cols = edge_map.nonzero()
    n_candidates = len(cols)
    if not n_candidates:
        return generate_random_seeds(edge_map.shape, num_seeds=100)
    picks = np.random.choice(n_candidates, n_candidates)
    return cols[picks], rows[picks]
def generate_random_seeds(image_shape: Tuple[int, int], num_seeds: int = 100):
    """Draw ``num_seeds`` uniformly random seed positions over an image (paper's original method).

    Args:
        image_shape: ``(height, width)`` of the image.
        num_seeds: Number of seeds to draw.

    Returns:
        (seed_xx, seed_yy): integer column indices in ``[0, width)`` and row
        indices in ``[0, height)``. The x values are drawn first from NumPy's
        global random state.
    """
    n_rows, n_cols = image_shape
    return tuple(np.random.randint(0, upper, num_seeds) for upper in (n_cols, n_rows))
def generate_seeds_by_method(
    image: np.ndarray,
    method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate seeds using the specified method.

    Args:
        image: 2D input image for seed generation.
        method: ``SeedingMethod`` member, or its string value (e.g. ``"blob"``).
        num_seeds: Number of seeds for random seeding. This also applies when
            blob/Canny seeding finds no candidate pixels and falls back to
            random seeding.
        min_sigma: Min sigma for blob detection.
        max_sigma: Max sigma for blob detection.
        threshold: Threshold for blob detection.

    Returns:
        seed_xx, seed_yy: Arrays of seed column (x) and row (y) coordinates.

    Raises:
        ValueError: if ``method`` is not a valid ``SeedingMethod`` member or value.
    """
    try:
        # Enum(member) returns the member itself, so enum inputs are unchanged.
        method = SeedingMethod(method)
    except ValueError:
        raise ValueError(f"Unknown seeding method: {method}") from None

    if method == SeedingMethod.RANDOM:
        # Paper's original method - pure random seeding
        return generate_random_seeds(image.shape, num_seeds)

    if method == SeedingMethod.GROWTH_CONES:
        return get_growth_cone_positions(image)

    if method == SeedingMethod.BLOB_DETECTION:
        edge_map = boundary_masking_blob(image, min_sigma, max_sigma, threshold)
    elif method == SeedingMethod.CANNY_EDGES:
        edge_map = boundary_masking_canny(image)
    else:
        raise ValueError(f"Unknown seeding method: {method}")

    # Handle "nothing detected" here so the random fallback honours num_seeds
    # (random_seed_by_edge_map's own fallback always uses 100 seeds).
    if not np.any(edge_map):
        return generate_random_seeds(image.shape, num_seeds)
    return random_seed_by_edge_map(edge_map)
def get_growth_cone_positions(image):
    """
    Locate candidate growth cones as centroids of bright connected regions.

    The image is thresholded with Otsu's method and closed with a radius-3 disk
    to bridge small gaps. The closed mask is then labelled into connected
    regions, and one seed is emitted per region at its centroid (float
    coordinates).

    Args:
        image: 2D input image for growth cone detection.

    Returns:
        seed_xx, seed_yy: Arrays of region-centroid column (x) and row (y)
        coordinates; empty if no region is found.
    """
    otsu_level = skimage.filters.threshold_otsu(image)
    foreground = skimage.morphology.closing(image > otsu_level, skimage.morphology.disk(3))
    regions = skimage.measure.regionprops(skimage.measure.label(foreground))

    # regionprops centroids are ordered (row, col), i.e. (y, x).
    centroid_rows = [region.centroid[0] for region in regions]
    centroid_cols = [region.centroid[1] for region in regions]
    return np.array(centroid_cols), np.array(centroid_rows)
def selected_seeding(image, seed_xx, seed_yy, chain_level=1.05, total_node=8, node_r=None, line_length_min=32):
    """
    Run bidirectional alvahmm HMM chain tracing from the given seeds and connect the chains.

    Args:
        image: 2D image to trace. A copy is handed to ``AlvaHmm`` so the
            caller's array is not passed to the library directly.
        seed_xx: Seed x (column) coordinates.
        seed_yy: Seed y (row) coordinates.
        chain_level: Forwarded to ``AlvaHmm.pair_HMM_chain``.
        total_node: Forwarded to the ``AlvaHmm`` constructor.
        node_r: Forwarded to the ``AlvaHmm`` constructor. ``None`` presumably
            selects the library default (confirm against alvahmm).
        line_length_min: Forwarded to ``alva_branch.connect_way``.

    Returns:
        The result of ``alva_branch.connect_way``. ``trace_neurites_rrs_alva``
        unpacks it as ``(root_tree_yy, root_tree_xx, root_tip_yy, root_tip_xx)``.
    """
    im_copy = np.copy(image)
    alva_HMM = alva_MCMC.AlvaHmm(
        im_copy,
        total_node=total_node,
        total_path=None,
        node_r=node_r,
        node_angle_max=None,
    )
    # Trace chains from each seed, plus the paired chains running the opposite way.
    chain_HMM_1st, pair_chain_HMM, pair_seed_xx, pair_seed_yy = alva_HMM.pair_HMM_chain(
        seed_xx=seed_xx,
        seed_yy=seed_yy,
        chain_level=chain_level,
    )
    # Render both chain sets into one image, then link segments into trees.
    chain_im_fine = alva_HMM.chain_image(chain_HMM_1st, pair_chain_HMM)
    return alva_branch.connect_way(
        chain_im_fine,
        line_length_min=line_length_min,
        free_zone_from_y0=None,
    )
def euclidian_distance(x1, y1, x2, y2):
    """Return the Euclidean distance between points ``(x1, y1)`` and ``(x2, y2)``.

    Uses ``math.hypot``, which avoids the intermediate overflow/underflow of
    squaring the coordinate differences by hand.
    """
    return math.hypot(x1 - x2, y1 - y2)
def extract_graph(root_tree_xx, root_tree_yy):
    """Build an undirected graph from traced paths.

    Each path is a pair of coordinate sequences ``(xs, ys)``. Every point
    becomes a node keyed by its ``(x, y)`` tuple. Consecutive points of a path
    are joined by an edge whose ``weight`` is their Euclidean distance. Points
    shared between paths collapse into the same node.
    """
    trace_graph = nx.Graph()
    for xs, ys in zip(root_tree_xx, root_tree_yy):
        # Add all of a path's nodes before its edges (preserves node order).
        for point in zip(xs, ys):
            trace_graph.add_node(point)
        for k in range(1, len(xs)):
            start = (xs[k - 1], ys[k - 1])
            end = (xs[k], ys[k])
            trace_graph.add_edge(start, end, weight=euclidian_distance(*start, *end))
    return trace_graph
def graph_to_length(graph):
    """Return the sum of the ``weight`` attribute over all edges (0 for an edgeless graph)."""
    return sum(attrs['weight'] for _, _, attrs in graph.edges(data=True))
def create_overlay_from_graph(original_image: np.ndarray, graph: nx.Graph) -> np.ndarray:
    """
    Create overlay visualization with traces on original image.

    Every edge endpoint that falls inside the image is set to the image's
    maximum value (or 1 for an empty image) on a copy of ``original_image``.
    Out-of-bounds points and isolated nodes (nodes with no edges) are not drawn.

    Args:
        original_image: 2D input image (numpy array); it is not modified.
        graph: NetworkX graph whose nodes are ``(x, y)`` coordinate tuples.

    Returns:
        A copy of ``original_image`` with traced pixels highlighted.
    """
    overlay = original_image.copy()
    highlight = np.max(original_image) if original_image.size > 0 else 1

    def paint(node):
        col, row = node  # nodes are (x, y) tuples
        if 0 <= row < original_image.shape[0] and 0 <= col < original_image.shape[1]:
            overlay[row, col] = highlight

    for start, end in graph.edges:
        paint(start)
        paint(end)

    return overlay
def create_visualization_array(
    original_image: np.ndarray,
    graph: nx.Graph,
    mode: VisualizationMode
) -> np.ndarray:
    """
    Create visualization array based on the specified mode.

    - ``NONE``: zeros with the shape and dtype of ``original_image``.
    - ``TRACE_ONLY``: uint8 mask, 1 at every in-bounds edge endpoint, else 0.
    - ``OVERLAY``: delegates to ``create_overlay_from_graph``.

    Args:
        original_image: Original 2D image.
        graph: NetworkX graph whose nodes are ``(x, y)`` coordinate tuples.
        mode: Visualization mode.

    Returns:
        2D array for visualization.

    Raises:
        ValueError: if ``mode`` is not one of the modes above.
    """
    if mode == VisualizationMode.OVERLAY:
        return create_overlay_from_graph(original_image, graph)

    if mode == VisualizationMode.NONE:
        return np.zeros_like(original_image, dtype=original_image.dtype)

    if mode != VisualizationMode.TRACE_ONLY:
        raise ValueError(f"Unknown visualization mode: {mode}")

    mask = np.zeros_like(original_image, dtype=np.uint8)
    for endpoint in (node for edge in graph.edges for node in edge):
        col, row = endpoint  # nodes are (x, y) tuples
        if 0 <= row < original_image.shape[0] and 0 <= col < original_image.shape[1]:
            mask[row, col] = 1
    return mask
@special_outputs(("hmm_analysis", materialize_hmm_analysis), ("trace_visualizations", materialize_trace_visualizations))
@numpy
def trace_neurites_rrs_alva(
    image_stack: np.ndarray,
    seeding_method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    return_trace_visualizations: bool = False,
    trace_visualization_mode: VisualizationMode = VisualizationMode.TRACE_ONLY,
    chain_level: float = 1.05,
    node_r: Optional[int] = None,
    total_node: Optional[int] = None,
    line_length_min: int = 32,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02,
    normalize_image: bool = False,
    percentile: float = 99.9
) -> Tuple[np.ndarray, Dict[str, Any], List[np.ndarray]]:
    """
    Trace neurites using the alvahmm RRS (Random-Reaction-Seed) algorithm.

    This is the OpenHCS-compatible version of the original alvahmm implementation.
    Each Z slice is traced independently: seeds are generated, bidirectional HMM
    chains are traced and connected, and the resulting paths are converted into
    a NetworkX graph.

    Args:
        image_stack: 3D array of shape (Z, Y, X) - input image stack.
        seeding_method: Seed generation method, as a ``SeedingMethod`` member or
            its string value (e.g. ``"blob"``). RANDOM is the paper's default;
            BLOB_DETECTION is the enhanced method.
        return_trace_visualizations: Whether to generate per-slice trace
            visualizations as a special output.
        trace_visualization_mode: ``VisualizationMode`` member or its string
            value (NONE, TRACE_ONLY, OVERLAY).
        chain_level: Validation threshold for HMM chains (default: 1.05).
        node_r: Path length between adjacent nodes (default: None, uses alvahmm default).
        total_node: Number of HMM nodes in chain (default: None, uses alvahmm default).
        line_length_min: Minimum line length for connection (default: 32).
        num_seeds: Number of seeds for random seeding (default: 100).
        min_sigma: Minimum sigma for blob detection (default: 1.0).
        max_sigma: Maximum sigma for blob detection (default: 2.0).
        threshold: Threshold for blob detection (default: 0.02).
        normalize_image: Whether to apply percentile normalization before tracing
            (default: False, the paper does not use it).
        percentile: Percentile for normalization if enabled (default: 99.9).

    Returns:
        result_image: The input ``image_stack``, returned unchanged.
        analysis_results: Dict with ``'summary'``, ``'graph'`` and ``'metadata'``.
            NOTE: the graph and summary metrics currently describe only the
            first slice (``all_graphs[0]``). ``num_slices_processed`` still
            counts every slice.
        trace_visualizations: (Special output) list of per-slice 2D arrays;
            empty unless ``return_trace_visualizations`` is True.

    Raises:
        ValueError: if ``image_stack`` is not 3D, or if ``seeding_method`` /
            ``trace_visualization_mode`` is not a valid member or value.
    """
    # Validate input is 3D
    if image_stack.ndim != 3:
        raise ValueError(f"Expected 3D array, got {image_stack.ndim}D")

    # Accept enum members or their string values. Enum(member) returns the
    # member itself, so this is a no-op for enum inputs, and invalid values fail
    # here instead of later on ``.value`` access.
    seeding_method = SeedingMethod(seeding_method)
    trace_visualization_mode = VisualizationMode(trace_visualization_mode)

    Z, Y, X = image_stack.shape
    all_graphs = []
    trace_visualizations = []

    for z in range(Z):
        im_axon = image_stack[z].astype(np.float64)

        # Optional normalization (off by default, the paper doesn't use it)
        if normalize_image:
            im_axon = normalize(im_axon, percentile=percentile)

        seed_xx, seed_yy = generate_seeds_by_method(
            im_axon,
            method=seeding_method,
            num_seeds=num_seeds,
            min_sigma=min_sigma,
            max_sigma=max_sigma,
            threshold=threshold
        )

        # Perform RRS tracing with bidirectional HMM chains
        root_tree_yy, root_tree_xx, root_tip_yy, root_tip_xx = selected_seeding(
            im_axon,
            seed_xx,
            seed_yy,
            chain_level=chain_level,
            node_r=node_r,
            total_node=total_node,
            line_length_min=line_length_min
        )

        graph = extract_graph(root_tree_xx, root_tree_yy)
        all_graphs.append(graph)

        if return_trace_visualizations:
            visualization = create_visualization_array(im_axon, graph, trace_visualization_mode)
            trace_visualizations.append(visualization)

    # Graph nodes carry no z coordinate, so slices are not merged; for
    # compatibility the first slice's graph is reported.
    combined_graph = all_graphs[0] if all_graphs else nx.Graph()

    analysis_results = _compile_hmm_analysis_results(
        combined_graph, all_graphs, image_stack.shape,
        seeding_method, trace_visualization_mode, chain_level,
        node_r, total_node, line_length_min
    )

    # Always return original image, analysis results, and trace visualizations
    return image_stack, analysis_results, trace_visualizations
def _compile_hmm_analysis_results(
    combined_graph: nx.Graph,
    all_graphs: List[nx.Graph],
    image_shape: Tuple[int, int, int],
    seeding_method: SeedingMethod,
    visualization_mode: VisualizationMode,
    chain_level: float,
    node_r: Optional[int],
    total_node: Optional[int],
    line_length_min: int
) -> Dict[str, Any]:
    """Assemble the analysis dictionary: summary metrics, the graph, and run metadata.

    Edge lengths are recomputed from node coordinates (nodes are unpacked as
    ``(x, y)`` pairs) rather than read from the ``weight`` attribute. Metrics
    come from ``combined_graph`` only. ``num_slices_processed`` is
    ``len(all_graphs)``.

    Returns:
        ``{'summary': {...}, 'graph': combined_graph, 'metadata': {...}}``
    """
    from datetime import datetime

    node_count = combined_graph.number_of_nodes()
    edge_count = combined_graph.number_of_edges()

    edge_lengths = [
        ((bx - ax) ** 2 + (by - ay) ** 2) ** 0.5
        for (ax, ay), (bx, by), _attrs in combined_graph.edges(data=True)
    ]
    total_length = sum(edge_lengths)
    has_edges = bool(edge_lengths)

    summary = {
        'total_trace_length': float(total_length),
        'num_nodes': int(node_count),
        'num_edges': int(edge_count),
        'num_slices_processed': len(all_graphs),
        'mean_edge_length': float(total_length / len(edge_lengths)) if has_edges else 0.0,
        'max_edge_length': float(max(edge_lengths)) if has_edges else 0.0,
        'graph_density': float(nx.density(combined_graph)) if node_count > 1 else 0.0,
        'num_connected_components': int(nx.number_connected_components(combined_graph)),
    }

    metadata = {
        'algorithm': 'alvahmm_rrs',
        'seeding_method': seeding_method.value,
        'visualization_mode': visualization_mode.value,
        'chain_level': chain_level,
        'node_r': node_r,
        'total_node': total_node,
        'line_length_min': line_length_min,
        'image_shape': image_shape,
        # Local, timezone-naive timestamp.
        'processing_timestamp': datetime.now().isoformat(),
    }

    return {
        'summary': summary,
        'graph': combined_graph,
        'metadata': metadata
    }
# Legacy file-based processing function (kept for reference)
def process_file_legacy(filename, input_folder, output_folder, **kwargs):
    """
    Former file-based entry point; always raises.

    Retained only so that old call sites fail loudly and point to the
    replacement, ``trace_neurites_rrs_alva()``, which works array-in/array-out.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = (
        "Legacy file-based processing not supported in OpenHCS. "
        "Use trace_neurites_rrs_alva() for array-in/array-out processing."
    )
    raise NotImplementedError(message)