Coverage for openhcs/processing/backends/analysis/hmm_axon.py: 3.9%
240 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1# BACKUP: hmm_axon_backup.py created before OpenHCS conversion
2"""
3OpenHCS-compatible neurite tracing using alvahmm RRS algorithm.
5Converted from file-based processing to pure array-in/array-out functions
6following OpenHCS patterns.
7"""
9import numpy as np
10import networkx as nx
11import skimage
12import math
13from enum import Enum
14from typing import Tuple, Dict, List, Optional, Any
15from skimage.feature import canny, blob_dog as local_max
16from skimage.filters import median, threshold_li
17from skimage.morphology import remove_small_objects, skeletonize
18from openhcs.core.memory.decorators import numpy
19from openhcs.core.pipeline.function_contracts import special_outputs
21# Import alvahmm from GitHub dependency
22from alva_machinery.markov import aChain as alva_MCMC
23from alva_machinery.branching import aWay as alva_branch
def materialize_hmm_analysis(
    hmm_analysis_data: Dict[str, Any],
    path: str,
    filemanager,
    **kwargs
) -> str:
    """
    Materialize HMM neurite tracing analysis results to disk.

    Up to three files are written, all sharing the same base name:
    - ``<base>.json``: summary metrics and metadata (always written)
    - ``<base>_graph.graphml``: the NetworkX graph (only if it has nodes)
    - ``<base>_edges.csv``: one row per edge (only if it has edges)

    Args:
        hmm_analysis_data: The HMM analysis results dictionary. The keys
            'summary', 'metadata' and 'graph' are read.
        path: Base path for output files (from special output path). A
            trailing '.pkl' extension is removed to form the base name.
        filemanager: FileManager instance for consistent I/O; its
            ``ensure_directory`` and ``save`` methods are called.
        **kwargs: Additional materialization options (accepted but not
            used by this function).

    Returns:
        str: Path to the primary output file (JSON summary)
    """
    import json
    import networkx as nx
    from pathlib import Path
    from openhcs.constants.constants import Backend

    # Strip only a trailing '.pkl'. str.replace() would also remove the
    # substring from directory names or from the middle of the file name.
    pkl_suffix = '.pkl'
    base_path = path[:-len(pkl_suffix)] if path.endswith(pkl_suffix) else path

    json_path = f"{base_path}.json"
    graphml_path = f"{base_path}_graph.graphml"
    csv_path = f"{base_path}_edges.csv"

    # Ensure output directory exists
    output_dir = Path(json_path).parent
    filemanager.ensure_directory(str(output_dir), Backend.DISK.value)

    # 1. Save summary and metadata as JSON (primary output)
    summary_data = {
        'analysis_type': 'hmm_neurite_tracing',
        'summary': hmm_analysis_data['summary'],
        'metadata': hmm_analysis_data['metadata']
    }
    # default=str stringifies anything json cannot encode natively
    json_content = json.dumps(summary_data, indent=2, default=str)
    filemanager.save(json_content, json_path, Backend.DISK.value)

    # 2. Save NetworkX graph as GraphML
    graph = hmm_analysis_data['graph']
    # Explicit None check: a missing graph must not reach the method calls.
    if graph is not None and graph.number_of_nodes() > 0:
        # Written directly by NetworkX rather than through the filemanager
        nx.write_graphml(graph, graphml_path)

        # 3. Save edge data as CSV
        if graph.number_of_edges() > 0:
            import pandas as pd

            # Nodes are indexed as [0] -> x and [1] -> y; edge attributes
            # are appended as extra columns.
            edge_rows = [
                {
                    'source_x': u[0], 'source_y': u[1],
                    'target_x': v[0], 'target_y': v[1],
                    **attrs
                }
                for u, v, attrs in graph.edges(data=True)
            ]
            csv_content = pd.DataFrame(edge_rows).to_csv(index=False)
            filemanager.save(csv_content, csv_path, Backend.DISK.value)

    return json_path
def materialize_trace_visualizations(data: List[np.ndarray], path: str, filemanager) -> str:
    """
    Materialize trace visualizations as individual TIFF files.

    Each array in ``data`` is converted to uint16 and saved as
    ``<base>_slice_NNN.tif``; a text summary is saved as
    ``<base>_trace_summary.txt``. When ``data`` is empty only the summary
    file is written.

    Args:
        data: List of visualization arrays (one per slice).
        path: Base path for output files; a trailing '.pkl' is removed.
        filemanager: FileManager instance whose ``save`` method is used.

    Returns:
        str: Path to the summary text file.
    """
    # Imported once up front so every branch below can rely on it.
    from openhcs.constants.constants import Backend

    # Strip only a trailing '.pkl' (str.replace would hit any occurrence).
    pkl_suffix = '.pkl'
    base_path = path[:-len(pkl_suffix)] if path.endswith(pkl_suffix) else path
    summary_path = f"{base_path}_trace_summary.txt"

    if not data:
        # Create summary file to indicate no visualizations were generated
        summary_content = "No trace visualizations generated (return_trace_visualizations=False)\n"
        filemanager.save(summary_content, summary_path, Backend.DISK.value)
        return summary_path

    uint16_max = 65535

    # Save each visualization as a separate TIFF file
    for i, visualization in enumerate(data):
        viz_filename = f"{base_path}_slice_{i:03d}.tif"

        if visualization.dtype == np.uint16:
            viz_uint16 = visualization
        elif visualization.size > 0 and visualization.max() <= 1.0:
            # Values in [0, 1] (e.g. binary masks) are stretched to the full
            # uint16 range. Integer/bool arrays are promoted to float first:
            # multiplying a uint8 mask by 65535 overflows under NumPy 2
            # promotion rules.
            if visualization.dtype.kind == 'f':
                as_float = visualization
            else:
                as_float = visualization.astype(np.float64)
            viz_uint16 = np.clip(as_float * uint16_max, 0, uint16_max).astype(np.uint16)
        else:
            # Clip so out-of-range values saturate instead of wrapping.
            viz_uint16 = np.clip(visualization, 0, uint16_max).astype(np.uint16)

        filemanager.save(viz_uint16, viz_filename, Backend.DISK.value)

    # The summary reports the dtype/shape of the first input array.
    summary_content = f"Trace visualizations saved: {len(data)} files\n"
    summary_content += f"Base filename pattern: {base_path}_slice_XXX.tif\n"
    summary_content += f"Visualization dtype: {data[0].dtype}\n"
    summary_content += f"Visualization shape: {data[0].shape}\n"

    filemanager.save(summary_content, summary_path, Backend.DISK.value)

    return summary_path
class SeedingMethod(Enum):
    """Available strategies for placing tracing seeds on an image."""

    # Paper's original method - random seeds across entire image
    RANDOM = "random"
    # Enhanced method - seeds on detected blob structures
    BLOB_DETECTION = "blob"
    # Alternative - seeds on Canny edge detection
    CANNY_EDGES = "canny"
    # Alternative - seeds on detected growth cones
    GROWTH_CONES = "growth_cones"
class VisualizationMode(Enum):
    """How the traced neurites are rendered in the visualization output."""

    # Return zeros array (no visualization)
    NONE = "none"
    # Show only traced neurites (binary mask)
    TRACE_ONLY = "trace"
    # Show original image with traced neurites overlaid
    OVERLAY = "overlay"
class OutputMode(Enum):
    """Output visualization modes (note the values differ from VisualizationMode)."""

    # Binary mask of traced neurites only
    TRACE_ONLY = "trace_only"
    # Original image with traces overlaid
    OVERLAY = "overlay"
    # Return original image unchanged
    NONE = "none"
def normalize(img, percentile=99.9):
    """
    Scale an image by one of its percentiles and clip the result to [0, 100].

    Args:
        img: Input image (array-like).
        percentile: Percentile of ``img`` used as the divisor (default 99.9).

    Returns:
        ``img / percentile_value`` clipped to the range [0, 100]. When the
        percentile value is exactly 0 (e.g. an all-zero image) the image is
        left unscaled instead of being divided by zero.
    """
    img = np.asarray(img)
    percentile_value = np.percentile(img, percentile)

    # Dividing by a zero percentile would fill the result with NaN/inf and
    # poison everything downstream; fall back to a neutral divisor.
    if percentile_value == 0:
        percentile_value = 1.0

    scaled = img / percentile_value  # Scale the image to the nth percentile value
    # Upper bound is 100, not 1, so values above the percentile are kept
    # unless they exceed 100x the percentile value.
    return np.clip(scaled, 0, 100)
def boundary_masking_canny(image):
    """
    Return a Canny edge map of ``image`` with a 2-pixel border cleared.

    Args:
        image: Input image passed straight to ``canny``.

    Returns:
        int64 array holding the edge map; the outermost two rows and
        columns on every side are forced to 0.
    """
    edges = canny(image)

    # Blank the two outermost columns and rows on each side.
    border_slices = (
        np.s_[:, :2],
        np.s_[:, -2:],
        np.s_[:2, :],
        np.s_[-2:, :],
    )
    for border in border_slices:
        edges[border] = False

    return np.array(edges, dtype=np.int64)
def boundary_masking_threshold(image, threshold=threshold_li, min_size=2):
    """
    Return a thresholded mask of ``image`` with a 2-pixel border cleared.

    Args:
        image: Input image.
        threshold: Callable that takes the image and returns the cut-off
            value (default: ``threshold_li``).
        min_size: Currently unused; kept so existing callers keep working.

    Returns:
        int64 array that is 1 where ``image`` exceeds the threshold and 0
        elsewhere, with the outermost two rows/columns on every side set to 0.
    """
    cutoff = threshold(image)
    mask = image > cutoff

    # Blank the two outermost columns and rows on each side.
    mask[:, :2] = False
    mask[:, -2:] = False
    mask[:2, :] = False
    mask[-2:, :] = False

    # NOTE(review): the previous version also ran skeletonize(mask) but
    # discarded the result and returned the un-skeletonized mask. The dead
    # call is removed; the returned value is unchanged. Confirm whether the
    # skeleton was meant to be the return value.
    return np.array(mask, dtype=np.int64)
def boundary_masking_blob(image, min_sigma=1, max_sigma=2, threshold=0.02):
    """
    Build a mask that marks detected blob centres in ``image``.

    The image is median-filtered, blobs are detected with ``blob_dog``
    (imported here as ``local_max``), and the pixel at each detection is
    set to 1 in an otherwise all-zero array.

    Args:
        image: Input image.
        min_sigma: Passed to the blob detector; ``None`` falls back to 1.
        max_sigma: Passed to the blob detector; ``None`` falls back to 2.
        threshold: Passed to the blob detector; ``None`` falls back to 0.02.

    Returns:
        Array with the same shape and dtype as ``image``: 1 at detected
        blob positions, 0 elsewhere.
    """
    # Explicit None from a caller means "use the default".
    if min_sigma is None:
        min_sigma = 1
    if max_sigma is None:
        max_sigma = 2
    if threshold is None:
        threshold = 0.02

    image_median = median(image)
    galaxy = local_max(image_median, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold)

    # Column 0 is used as the row (y) index and column 1 as the column (x)
    # index; the values are truncated to integers for indexing.
    yy = np.int64(galaxy[:, 0])
    xx = np.int64(galaxy[:, 1])

    # zeros_like instead of `copy * 0`: multiplying by zero would carry any
    # NaN/inf pixel of the input into the mask.
    boundary_mask = np.zeros_like(image)
    boundary_mask[yy, xx] = 1
    return boundary_mask
def random_seed_by_edge_map(edge_map):
    """
    Draw seed coordinates from the non-zero pixels of ``edge_map``.

    As many seeds are drawn as there are non-zero pixels, sampling with
    replacement. If the map has no non-zero pixel, 100 uniformly random
    seeds over the map's shape are returned instead.

    Returns:
        seed_xx, seed_yy: Arrays of seed column and row coordinates.
    """
    rows, cols = edge_map.nonzero()
    candidate_count = len(cols)

    if candidate_count == 0:
        # No edges detected, fall back to random seeding
        return generate_random_seeds(edge_map.shape, num_seeds=100)

    picks = np.random.choice(candidate_count, candidate_count)
    return cols[picks], rows[picks]
def generate_random_seeds(image_shape: Tuple[int, int], num_seeds: int = 100):
    """
    Draw uniformly random seed coordinates over the whole image (paper's original method).

    Args:
        image_shape: (height, width) of the image.
        num_seeds: Number of seeds to draw.

    Returns:
        seed_xx, seed_yy: Integer arrays of length ``num_seeds``; x lies in
        [0, width) and y in [0, height).
    """
    n_rows, n_cols = image_shape
    # x is drawn before y; the order matters for reproducibility under a
    # fixed global NumPy seed.
    seed_xx = np.random.randint(low=0, high=n_cols, size=num_seeds)
    seed_yy = np.random.randint(low=0, high=n_rows, size=num_seeds)
    return seed_xx, seed_yy
def generate_seeds_by_method(
    image: np.ndarray,
    method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Dispatch seed generation to the strategy selected by ``method``.

    Args:
        image: Input image for seed generation
        method: Seeding method to use
        num_seeds: Number of seeds (only used by the RANDOM method)
        min_sigma: Min sigma (only used by BLOB_DETECTION)
        max_sigma: Max sigma (only used by BLOB_DETECTION)
        threshold: Threshold (only used by BLOB_DETECTION)

    Returns:
        seed_xx, seed_yy: Arrays of seed coordinates

    Raises:
        ValueError: If ``method`` is not one of the handled SeedingMethod members.
    """
    # Paper's original method - pure random seeding
    if method == SeedingMethod.RANDOM:
        return generate_random_seeds(image.shape, num_seeds)

    # Enhanced method - seeds drawn from detected blob positions
    if method == SeedingMethod.BLOB_DETECTION:
        blob_map = boundary_masking_blob(image, min_sigma, max_sigma, threshold)
        return random_seed_by_edge_map(blob_map)

    # Alternative - seeds drawn from Canny edge pixels
    if method == SeedingMethod.CANNY_EDGES:
        canny_map = boundary_masking_canny(image)
        return random_seed_by_edge_map(canny_map)

    # Alternative - seeds on growth cones
    if method == SeedingMethod.GROWTH_CONES:
        return get_growth_cone_positions(image)

    raise ValueError(f"Unknown seeding method: {method}")
def get_growth_cone_positions(image):
    """
    Locate candidate growth cones as centroids of thresholded regions.

    The image is binarized with an Otsu threshold, closed with a disk of
    radius 3, labelled, and the centroid of each labelled region is returned.

    Args:
        image: Input image for growth cone detection

    Returns:
        seed_xx, seed_yy: Arrays of region-centroid x and y coordinates.
        NOTE(review): the centroids are returned as-is, not rounded to
        integer pixel positions as the other seeding methods produce —
        confirm the tracer accepts that.
    """
    # Threshold the image to create a binary mask
    otsu_level = skimage.filters.threshold_otsu(image)
    foreground = image > otsu_level

    # Morphological closing fills small gaps before labelling
    closed = skimage.morphology.closing(foreground, skimage.morphology.disk(3))
    regions = skimage.measure.regionprops(skimage.measure.label(closed))

    # centroid is indexed as [0] -> y and [1] -> x
    seed_xx = [region.centroid[1] for region in regions]
    seed_yy = [region.centroid[0] for region in regions]

    return np.array(seed_xx), np.array(seed_yy)
def selected_seeding(image, seed_xx, seed_yy, chain_level=1.05, total_node=8, node_r=None, line_length_min=32):
    """
    Run the alvahmm HMM tracing from the given seeds and connect the result.

    Steps: build an ``AlvaHmm`` on a copy of the image, compute the paired
    HMM chains from the seeds, render them with ``chain_image``, and pass
    that image to ``alva_branch.connect_way``.

    Args:
        image: Input image; it is copied before being handed to AlvaHmm.
        seed_xx: Seed x coordinates.
        seed_yy: Seed y coordinates.
        chain_level: Forwarded to ``pair_HMM_chain``.
        total_node: Forwarded to the ``AlvaHmm`` constructor.
        node_r: Forwarded to the ``AlvaHmm`` constructor.
        line_length_min: Forwarded to ``connect_way``.

    Returns:
        Whatever ``alva_branch.connect_way`` returns. The caller in this
        file unpacks it as (root_tree_yy, root_tree_xx, root_tip_yy,
        root_tip_xx).
    """
    alva_HMM = alva_MCMC.AlvaHmm(
        np.copy(image),
        total_node=total_node,
        total_path=None,
        node_r=node_r,
        node_angle_max=None,
    )

    # Only the two chain objects are needed below; the paired seed
    # coordinates that come back with them are not used.
    chain_HMM_1st, pair_chain_HMM, _pair_seed_xx, _pair_seed_yy = alva_HMM.pair_HMM_chain(
        seed_xx=seed_xx,
        seed_yy=seed_yy,
        chain_level=chain_level,
    )

    # A loop that unpacked each chain into locals nobody read was removed
    # here; it had no effect on the result.
    chain_im_fine = alva_HMM.chain_image(chain_HMM_1st, pair_chain_HMM)

    return alva_branch.connect_way(
        chain_im_fine,
        line_length_min=line_length_min,
        free_zone_from_y0=None,
    )
def euclidian_distance(x1, y1, x2, y2):
    """
    Return the Euclidean distance between points (x1, y1) and (x2, y2).

    Coordinates are converted to Python floats before subtracting, so
    unsigned NumPy integers do not wrap around when the difference is
    negative.

    Returns:
        float: The straight-line distance.
    """
    dx = float(x1) - float(x2)
    dy = float(y1) - float(y2)
    return math.hypot(dx, dy)
def extract_graph(root_tree_xx, root_tree_yy):
    """
    Convert traced paths into an undirected graph.

    Every (x, y) point of every path becomes a node, and consecutive
    points within a path are joined by an edge whose 'weight' attribute is
    their Euclidean distance.

    Args:
        root_tree_xx: Iterable of paths, each a sequence of x coordinates.
        root_tree_yy: Iterable of paths, each a sequence of y coordinates.

    Returns:
        nx.Graph with (x, y) tuples as nodes.
    """
    traced = nx.Graph()

    for xs, ys in zip(root_tree_xx, root_tree_yy):
        # Register every point first so single-point paths still appear.
        for point in zip(xs, ys):
            traced.add_node(point)

        # Link each point to its predecessor along the path.
        for idx in range(1, len(xs)):
            segment_length = euclidian_distance(xs[idx - 1], ys[idx - 1], xs[idx], ys[idx])
            traced.add_edge(
                (xs[idx - 1], ys[idx - 1]),
                (xs[idx], ys[idx]),
                weight=segment_length,
            )

    return traced
def graph_to_length(graph):
    """
    Return the sum of the 'weight' attribute over all edges of ``graph``.

    An edgeless graph yields 0.
    """
    length_so_far = 0
    for _source, _target, attrs in graph.edges(data=True):
        length_so_far = length_so_far + attrs['weight']
    return length_so_far
def create_overlay_from_graph(original_image: np.ndarray, graph: nx.Graph) -> np.ndarray:
    """
    Paint the endpoints of every graph edge onto a copy of the image.

    Args:
        original_image: Original input image (numpy array); it is not modified.
        graph: NetworkX graph whose nodes are (x, y) tuples

    Returns:
        Copy of ``original_image`` where each in-bounds edge endpoint is
        set to the image's maximum value (or 1 for an empty image). Only
        edge endpoints are painted, not the pixels between them.
    """
    canvas = original_image.copy()
    # Use the brightest existing value so the trace stands out
    highlight = np.max(original_image) if original_image.size > 0 else 1
    dims = original_image.shape

    for start, end in graph.edges:
        # Nodes are stored as (x, y); array indexing is [y, x]
        (xa, ya), (xb, yb) = start, end
        for px, py in ((xa, ya), (xb, yb)):
            # Skip endpoints that fall outside the image
            if 0 <= py < dims[0] and 0 <= px < dims[1]:
                canvas[py, px] = highlight

    return canvas
def create_visualization_array(
    original_image: np.ndarray,
    graph: nx.Graph,
    mode: VisualizationMode
) -> np.ndarray:
    """
    Render the traced graph according to ``mode``.

    Args:
        original_image: Original 2D image
        graph: NetworkX graph with traced neurites; nodes are (x, y) tuples
        mode: Visualization mode

    Returns:
        - NONE: zeros with the image's shape and dtype
        - TRACE_ONLY: uint8 mask with 1 at every in-bounds edge endpoint
        - OVERLAY: result of ``create_overlay_from_graph``

    Raises:
        ValueError: If ``mode`` is not one of the handled VisualizationMode members.
    """
    if mode == VisualizationMode.NONE:
        return np.zeros_like(original_image)

    if mode == VisualizationMode.TRACE_ONLY:
        mask = np.zeros_like(original_image, dtype=np.uint8)
        dims = original_image.shape
        for start, end in graph.edges:
            # Nodes are stored as (x, y); array indexing is [y, x]
            (xa, ya), (xb, yb) = start, end
            for px, py in ((xa, ya), (xb, yb)):
                # Skip endpoints that fall outside the image
                if 0 <= py < dims[0] and 0 <= px < dims[1]:
                    mask[py, px] = 1
        return mask

    if mode == VisualizationMode.OVERLAY:
        # Use shared overlay function
        return create_overlay_from_graph(original_image, graph)

    raise ValueError(f"Unknown visualization mode: {mode}")
@special_outputs(("hmm_analysis", materialize_hmm_analysis), ("trace_visualizations", materialize_trace_visualizations))
@numpy
def trace_neurites_rrs_alva(
    image_stack: np.ndarray,
    seeding_method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    return_trace_visualizations: bool = False,
    trace_visualization_mode: VisualizationMode = VisualizationMode.TRACE_ONLY,
    chain_level: float = 1.05,
    node_r: Optional[int] = None,
    total_node: Optional[int] = None,
    line_length_min: int = 32,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02,
    normalize_image: bool = False,
    percentile: float = 99.9
) -> Tuple[np.ndarray, Dict[str, Any], List[np.ndarray]]:
    """
    Trace neurites using the alvahmm RRS (Random-Reaction-Seed) algorithm.

    Each Z slice is processed independently: cast to float64, optionally
    percentile-normalized, seeded, traced with ``selected_seeding`` and
    converted to a graph with ``extract_graph``.

    Args:
        image_stack: 3D array of shape (Z, Y, X) - input image stack
        seeding_method: Method for seed generation (RANDOM=paper default, BLOB_DETECTION=enhanced)
        return_trace_visualizations: Whether to build a visualization array per slice
        trace_visualization_mode: How to visualize results (NONE, TRACE_ONLY, OVERLAY)
        chain_level: Validation threshold for HMM chains (default: 1.05)
        node_r: Path length between adjacent nodes (default: None, uses alvahmm default)
        total_node: Number of HMM nodes in chain (default: None, uses alvahmm default)
        line_length_min: Minimum line length for connection (default: 32)
        num_seeds: Number of seeds for random seeding method (default: 100)
        min_sigma: Minimum sigma for blob detection (default: 1.0)
        max_sigma: Maximum sigma for blob detection (default: 2.0)
        threshold: Threshold for blob detection (default: 0.02)
        normalize_image: Whether to apply percentile normalization (default: False, paper doesn't use)
        percentile: Percentile for normalization if enabled (default: 99.9)

    Returns:
        result_image: Original image stack unchanged (Z, Y, X)
        analysis_results: HMM analysis data structure with graph and metrics
        trace_visualizations: (Special output) One array per slice when
            return_trace_visualizations=True, otherwise an empty list

    Raises:
        ValueError: If ``image_stack`` is not 3-dimensional.
    """
    if image_stack.ndim != 3:
        raise ValueError(f"Expected 3D array, got {image_stack.ndim}D")

    per_slice_graphs = []
    trace_visualizations = []

    # Iterating the stack walks the Z axis one 2D plane at a time.
    for plane in image_stack:
        im_axon = plane.astype(np.float64)

        # Optional normalization (off by default, paper doesn't use it)
        if normalize_image:
            im_axon = normalize(im_axon, percentile=percentile)

        seed_xx, seed_yy = generate_seeds_by_method(
            im_axon,
            method=seeding_method,
            num_seeds=num_seeds,
            min_sigma=min_sigma,
            max_sigma=max_sigma,
            threshold=threshold
        )

        # The tip coordinates are part of the tracer's return value but
        # are not used further here.
        root_tree_yy, root_tree_xx, _root_tip_yy, _root_tip_xx = selected_seeding(
            im_axon,
            seed_xx,
            seed_yy,
            chain_level=chain_level,
            node_r=node_r,
            total_node=total_node,
            line_length_min=line_length_min
        )

        slice_graph = extract_graph(root_tree_xx, root_tree_yy)
        per_slice_graphs.append(slice_graph)

        # Visualizations are drawn on the float64 (possibly normalized)
        # slice, not on the raw input plane.
        if return_trace_visualizations:
            trace_visualizations.append(
                create_visualization_array(im_axon, slice_graph, trace_visualization_mode)
            )

    # NOTE(review): for compatibility only the first slice's graph is
    # reported; graphs of later slices affect nothing but the slice count
    # in the summary. Confirm this is intended for multi-slice stacks.
    combined_graph = per_slice_graphs[0] if per_slice_graphs else nx.Graph()

    analysis_results = _compile_hmm_analysis_results(
        combined_graph, per_slice_graphs, image_stack.shape,
        seeding_method, trace_visualization_mode, chain_level,
        node_r, total_node, line_length_min
    )

    # Always return original image, analysis results, and trace visualizations
    return image_stack, analysis_results, trace_visualizations
def _compile_hmm_analysis_results(
    combined_graph: nx.Graph,
    all_graphs: List[nx.Graph],
    image_shape: Tuple[int, int, int],
    seeding_method: SeedingMethod,
    visualization_mode: VisualizationMode,
    chain_level: float,
    node_r: Optional[int],
    total_node: Optional[int],
    line_length_min: int
) -> Dict[str, Any]:
    """
    Assemble the analysis dictionary for an HMM tracing run.

    All graph metrics are computed from ``combined_graph`` only;
    ``all_graphs`` contributes just the processed-slice count.

    Returns:
        Dict with three keys: 'summary' (numeric metrics), 'graph' (the
        combined graph object itself) and 'metadata' (the run parameters
        plus an ISO-format timestamp).
    """
    from datetime import datetime

    node_count = combined_graph.number_of_nodes()
    edge_count = combined_graph.number_of_edges()

    # Edge length is recomputed from the (x, y) node coordinates rather
    # than read from the edge attributes.
    edge_lengths = [
        ((x2 - x1)**2 + (y2 - y1)**2)**0.5
        for (x1, y1), (x2, y2) in combined_graph.edges()
    ]
    total_length = 0.0
    for segment in edge_lengths:
        total_length += segment

    has_edges = bool(edge_lengths)
    mean_length = float(sum(edge_lengths) / len(edge_lengths)) if has_edges else 0.0
    longest = float(max(edge_lengths)) if has_edges else 0.0
    # Density is only computed when there is more than one node.
    density = float(nx.density(combined_graph)) if node_count > 1 else 0.0

    summary = {
        'total_trace_length': float(total_length),
        'num_nodes': int(node_count),
        'num_edges': int(edge_count),
        'num_slices_processed': len(all_graphs),
        'mean_edge_length': mean_length,
        'max_edge_length': longest,
        'graph_density': density,
        'num_connected_components': int(nx.number_connected_components(combined_graph)),
    }

    metadata = {
        'algorithm': 'alvahmm_rrs',
        'seeding_method': seeding_method.value,
        'visualization_mode': visualization_mode.value,
        'chain_level': chain_level,
        'node_r': node_r,
        'total_node': total_node,
        'line_length_min': line_length_min,
        'image_shape': image_shape,
        'processing_timestamp': datetime.now().isoformat(),
    }

    return {
        'summary': summary,
        'graph': combined_graph,
        'metadata': metadata
    }
567# Legacy file-based processing function (kept for reference)
def process_file_legacy(filename, input_folder, output_folder, **kwargs):
    """
    Placeholder for the removed file-based processing entry point.

    Kept for reference only; every call fails. Use
    trace_neurites_rrs_alva() instead for array-in/array-out processing.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    message = (
        "Legacy file-based processing not supported in OpenHCS. "
        "Use trace_neurites_rrs_alva() for array-in/array-out processing."
    )
    raise NotImplementedError(message)