Coverage for openhcs/processing/backends/analysis/hmm_axon.py: 3.9%

240 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

# BACKUP: hmm_axon_backup.py created before OpenHCS conversion
"""
OpenHCS-compatible neurite tracing using the alvahmm RRS algorithm.

Converted from file-based processing to pure array-in/array-out functions
following OpenHCS patterns.
"""

import numpy as np
import networkx as nx
import skimage
import math
from enum import Enum
from typing import Tuple, Dict, List, Optional, Any
from skimage.feature import canny, blob_dog as local_max
from skimage.filters import median, threshold_li
from skimage.morphology import remove_small_objects, skeletonize
from openhcs.core.memory.decorators import numpy
from openhcs.core.pipeline.function_contracts import special_outputs

# Import alvahmm from GitHub dependency
from alva_machinery.markov import aChain as alva_MCMC
from alva_machinery.branching import aWay as alva_branch

def materialize_hmm_analysis(
    hmm_analysis_data: Dict[str, Any],
    path: str,
    filemanager,
    **kwargs
) -> str:
    """
    Materialize HMM neurite tracing analysis results to disk.

    Creates multiple output files:
    - JSON file with graph data and summary metrics
    - GraphML file with the NetworkX graph
    - CSV file with edge data

    Args:
        hmm_analysis_data: The HMM analysis results dictionary
        path: Base path for output files (from special output path)
        filemanager: FileManager instance for consistent I/O
        **kwargs: Additional materialization options

    Returns:
        str: Path to the primary output file (JSON summary)
    """
    import json
    import networkx as nx
    from pathlib import Path
    from openhcs.constants.constants import Backend

    # Generate output file paths
    base_path = path.replace('.pkl', '')
    json_path = f"{base_path}.json"
    graphml_path = f"{base_path}_graph.graphml"
    csv_path = f"{base_path}_edges.csv"

    # Ensure output directory exists
    output_dir = Path(json_path).parent
    filemanager.ensure_directory(str(output_dir), Backend.DISK.value)

    # 1. Save summary and metadata as JSON (primary output)
    summary_data = {
        'analysis_type': 'hmm_neurite_tracing',
        'summary': hmm_analysis_data['summary'],
        'metadata': hmm_analysis_data['metadata']
    }
    json_content = json.dumps(summary_data, indent=2, default=str)
    filemanager.save(json_content, json_path, Backend.DISK.value)

    # 2. Save NetworkX graph as GraphML
    graph = hmm_analysis_data['graph']
    if graph and graph.number_of_nodes() > 0:
        # Use direct file I/O for GraphML (NetworkX doesn't support string I/O)
        nx.write_graphml(graph, graphml_path)

        # 3. Save edge data as CSV
        if graph.number_of_edges() > 0:
            import pandas as pd

            edge_data = []
            for u, v, data in graph.edges(data=True):
                edge_info = {
                    'source_x': u[0], 'source_y': u[1],
                    'target_x': v[0], 'target_y': v[1],
                    **data  # Include any edge attributes
                }
                edge_data.append(edge_info)

            edge_df = pd.DataFrame(edge_data)
            csv_content = edge_df.to_csv(index=False)
            filemanager.save(csv_content, csv_path, Backend.DISK.value)

    return json_path
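
# Illustrative sketch (not part of the original module): how the files written by
# materialize_hmm_analysis above might be loaded back for inspection. The helper name
# and its base-path argument are hypothetical; it assumes the JSON/GraphML/CSV layout
# generated above (the graph and edge files exist only when the graph was non-empty)
# and uses only standard json, NetworkX, and pandas readers.
def _example_load_hmm_outputs(base_path: str):
    """Load the JSON summary, GraphML graph, and edge CSV produced for one result."""
    import json
    import pandas as pd

    with open(f"{base_path}.json") as f:
        summary = json.load(f)                              # {'analysis_type', 'summary', 'metadata'}
    graph = nx.read_graphml(f"{base_path}_graph.graphml")   # traced neurite graph
    edges = pd.read_csv(f"{base_path}_edges.csv")           # per-edge coordinates plus attributes
    return summary, graph, edges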

def materialize_trace_visualizations(data: List[np.ndarray], path: str, filemanager) -> str:
    """Materialize trace visualizations as individual TIFF files."""
    from openhcs.constants.constants import Backend

    if not data:
        # Create an empty summary file to indicate that no visualizations were generated
        summary_path = path.replace('.pkl', '_trace_summary.txt')
        summary_content = "No trace visualizations generated (return_trace_visualizations=False)\n"
        filemanager.save(summary_content, summary_path, Backend.DISK.value)
        return summary_path

    # Generate output file paths based on the input path
    base_path = path.replace('.pkl', '')

    # Save each visualization as a separate TIFF file
    for i, visualization in enumerate(data):
        viz_filename = f"{base_path}_slice_{i:03d}.tif"

        # Convert visualization to an appropriate dtype for saving (uint16 to match input images)
        if visualization.dtype != np.uint16:
            # Normalize to the uint16 range if needed
            if visualization.max() <= 1.0:
                viz_uint16 = (visualization * 65535).astype(np.uint16)
            else:
                viz_uint16 = visualization.astype(np.uint16)
        else:
            viz_uint16 = visualization

        # Save using the filemanager
        filemanager.save(viz_uint16, viz_filename, Backend.DISK.value)

    # Write a summary file describing the saved visualizations
    summary_path = f"{base_path}_trace_summary.txt"
    summary_content = f"Trace visualizations saved: {len(data)} files\n"
    summary_content += f"Base filename pattern: {base_path}_slice_XXX.tif\n"
    summary_content += f"Visualization dtype: {data[0].dtype}\n"
    summary_content += f"Visualization shape: {data[0].shape}\n"

    filemanager.save(summary_content, summary_path, Backend.DISK.value)

    return summary_path
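
# Illustrative sketch (not in the original module): the uint16 conversion rule used in
# the loop above, factored into a standalone helper so it can be tested in isolation.
# The helper name is hypothetical; the logic mirrors the branch above: float arrays in
# [0, 1] are rescaled to the full uint16 range, everything else is cast directly.
def _example_to_uint16(visualization: np.ndarray) -> np.ndarray:
    """Return a uint16 version of a visualization array using the same rule as above."""
    if visualization.dtype == np.uint16:
        return visualization
    if visualization.max() <= 1.0:
        return (visualization * 65535).astype(np.uint16)
    return visualization.astype(np.uint16)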

class SeedingMethod(Enum):
    """Seeding methods for neurite tracing."""
    RANDOM = "random"              # Paper's original method - random seeds across entire image
    BLOB_DETECTION = "blob"        # Enhanced method - seeds on detected blob structures
    CANNY_EDGES = "canny"          # Alternative - seeds on Canny edge detection
    GROWTH_CONES = "growth_cones"  # Alternative - seeds on detected growth cones


class VisualizationMode(Enum):
    """Visualization modes for trace output."""
    NONE = "none"         # Return zeros array (no visualization)
    TRACE_ONLY = "trace"  # Show only traced neurites (binary mask)
    OVERLAY = "overlay"   # Show original image with traced neurites overlaid


class OutputMode(Enum):
    """Output visualization modes."""
    TRACE_ONLY = "trace_only"  # Binary mask of traced neurites only
    OVERLAY = "overlay"        # Original image with traces overlaid
    NONE = "none"              # Return original image unchanged

def normalize(img, percentile=99.9):
    """Scale the image so that the given intensity percentile maps to 1.0, then clip."""
    percentile_value = np.percentile(img, percentile)
    img = img / percentile_value  # Scale the image to the nth percentile value
    img = np.clip(img, 0, 100)    # Clip to [0, 100]; use an upper bound of 1 to keep fractional values
    #img = img - img.min()
    #img = img / img.max()
    return img

def boundary_masking_canny(image):
    """Detect Canny edges and zero out a 2-pixel border before returning an integer mask."""
    bool_im_axon_edit = canny(image)
    bool_im_axon_edit[:, :2] = False
    bool_im_axon_edit[:, -2:] = False
    bool_im_axon_edit[:2, :] = False
    bool_im_axon_edit[-2:, :] = False
    return np.array(bool_im_axon_edit, dtype=np.int64)

def boundary_masking_threshold(image, threshold=threshold_li, min_size=2):
    """Threshold the image, zero out a 2-pixel border, and return an integer mask."""
    threshed = threshold(image)
    bool_image = image > threshed
    bool_image[:, :2] = False
    bool_image[:, -2:] = False
    bool_image[:2, :] = False
    bool_image[-2:, :] = False
    # NOTE: the skeletonized result below is computed but not used; the un-skeletonized
    # mask is returned, preserving the original behavior.
    cleaned_bool_im_axon_edit = skeletonize(bool_image)
    return np.array(bool_image, dtype=np.int64)

def boundary_masking_blob(image, min_sigma=1, max_sigma=2, threshold=0.02):
    """Build a seed mask by marking blob centers detected on a median-filtered image."""
    if min_sigma is None:
        min_sigma = 1
    if max_sigma is None:
        max_sigma = 2
    if threshold is None:
        threshold = 0.02

    image_median = median(image)
    galaxy = local_max(image_median, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold)
    yy = np.int64(galaxy[:, 0])
    xx = np.int64(galaxy[:, 1])
    boundary_mask = np.copy(image) * 0
    boundary_mask[yy, xx] = 1
    return boundary_mask

def random_seed_by_edge_map(edge_map):
    """Generate random seeds from detected edge/blob locations (sampled with replacement)."""
    yy, xx = edge_map.nonzero()
    if len(xx) == 0:
        # No edges detected, fall back to random seeding
        return generate_random_seeds(edge_map.shape, num_seeds=100)
    seed_index = np.random.choice(len(xx), len(xx))
    seed_xx = xx[seed_index]
    seed_yy = yy[seed_index]
    return seed_xx, seed_yy

def generate_random_seeds(image_shape: Tuple[int, int], num_seeds: int = 100):
    """Generate completely random seeds across the entire image (paper's original method)."""
    height, width = image_shape
    seed_xx = np.random.randint(0, width, num_seeds)
    seed_yy = np.random.randint(0, height, num_seeds)
    return seed_xx, seed_yy

def generate_seeds_by_method(
    image: np.ndarray,
    method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate seeds using the specified method.

    Args:
        image: Input image for seed generation
        method: Seeding method to use
        num_seeds: Number of seeds for random method
        min_sigma: Min sigma for blob detection
        max_sigma: Max sigma for blob detection
        threshold: Threshold for blob detection

    Returns:
        seed_xx, seed_yy: Arrays of seed coordinates
    """
    if method == SeedingMethod.RANDOM:
        # Paper's original method - pure random seeding
        return generate_random_seeds(image.shape, num_seeds)

    elif method == SeedingMethod.BLOB_DETECTION:
        # Enhanced method - seeds on detected blobs
        edge_map = boundary_masking_blob(image, min_sigma, max_sigma, threshold)
        return random_seed_by_edge_map(edge_map)

    elif method == SeedingMethod.CANNY_EDGES:
        # Alternative - seeds on Canny edges
        edge_map = boundary_masking_canny(image)
        return random_seed_by_edge_map(edge_map)

    elif method == SeedingMethod.GROWTH_CONES:
        # Alternative - seeds on growth cones
        return get_growth_cone_positions(image)

    else:
        raise ValueError(f"Unknown seeding method: {method}")
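
# Illustrative sketch (not in the original module): comparing how many seeds each
# seeding method yields on a given 2D image. The helper name is hypothetical; it only
# calls generate_seeds_by_method as defined above with its default blob parameters.
def _example_compare_seeding_methods(image: np.ndarray) -> Dict[str, int]:
    """Return the number of seeds each SeedingMethod produces for the given 2D image."""
    counts = {}
    for method in (SeedingMethod.RANDOM, SeedingMethod.BLOB_DETECTION,
                   SeedingMethod.CANNY_EDGES, SeedingMethod.GROWTH_CONES):
        seed_xx, _ = generate_seeds_by_method(image, method=method, num_seeds=100)
        counts[method.value] = len(seed_xx)
    return counts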

def get_growth_cone_positions(image):
    """
    Detect growth cone positions using morphological operations (OpenHCS-compatible).

    Args:
        image: Input image for growth cone detection

    Returns:
        seed_xx, seed_yy: Arrays of growth cone center coordinates
    """
    # Threshold the image to create a binary mask
    mask = image > skimage.filters.threshold_otsu(image)

    # Use morphological closing to fill in small gaps in the mask
    mask = skimage.morphology.closing(mask, skimage.morphology.disk(3))
    labeled = skimage.measure.label(mask)
    props = skimage.measure.regionprops(labeled)

    # Region centroids are float (row, col) coordinates; collect them as (x, y)
    seed_xx = []
    seed_yy = []
    for prop in props:
        seed_xx.append(prop.centroid[1])  # x coordinate
        seed_yy.append(prop.centroid[0])  # y coordinate

    return np.array(seed_xx), np.array(seed_yy)

def selected_seeding(image, seed_xx, seed_yy, chain_level=1.05, total_node=8, node_r=None, line_length_min=32):
    """Run bidirectional HMM chain tracing from the given seeds and connect the resulting chains."""
    im_copy = np.copy(image)
    alva_HMM = alva_MCMC.AlvaHmm(im_copy,
                                 total_node=total_node,
                                 total_path=None,
                                 node_r=node_r,
                                 node_angle_max=None,)
    chain_HMM_1st, pair_chain_HMM, pair_seed_xx, pair_seed_yy = alva_HMM.pair_HMM_chain(seed_xx=seed_xx,
                                                                                        seed_yy=seed_yy,
                                                                                        chain_level=chain_level,)
    # NOTE: the values unpacked in this loop are not used afterwards; the loop is kept
    # as-is from the original implementation.
    for chain_i in [0, 1]:
        chain_HMM = [chain_HMM_1st, pair_chain_HMM][chain_i]
        real_chain_ii, real_chain_aa, real_chain_xx, real_chain_yy = chain_HMM[0:4]
        seed_node_xx, seed_node_yy = chain_HMM[4:6]

    chain_im_fine = alva_HMM.chain_image(chain_HMM_1st, pair_chain_HMM,)
    return alva_branch.connect_way(chain_im_fine,
                                   line_length_min=line_length_min,
                                   free_zone_from_y0=None,)

def euclidian_distance(x1, y1, x2, y2):
    distance = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
    return distance

def extract_graph(root_tree_xx, root_tree_yy):
    """Build a NetworkX graph from traced paths, weighting edges by Euclidean length."""
    graph = nx.Graph()
    for path_x, path_y in zip(root_tree_xx, root_tree_yy):
        for x, y in zip(path_x, path_y):
            graph.add_node((x, y))
        for i in range(len(path_x) - 1):
            distance = euclidian_distance(path_x[i], path_y[i], path_x[i + 1], path_y[i + 1])
            graph.add_edge((path_x[i], path_y[i]), (path_x[i + 1], path_y[i + 1]), weight=distance)
    return graph

def graph_to_length(graph):
    """Sum the edge weights of a traced graph to obtain the total trace length."""
    total_distance = 0
    for u, v, data in graph.edges(data=True):
        total_distance += data['weight']
    return total_distance
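
# Illustrative sketch (not in the original module): a tiny worked example of
# extract_graph and graph_to_length on one hand-written path. The helper name and the
# coordinates are hypothetical.
def _example_graph_length() -> float:
    """Trace a 3-node path (0,0)->(1,0)->(2,1); the expected length is 1 + sqrt(2) ~= 2.414."""
    root_tree_xx = [[0, 1, 2]]  # one path, x coordinates
    root_tree_yy = [[0, 0, 1]]  # one path, y coordinates
    graph = extract_graph(root_tree_xx, root_tree_yy)
    return graph_to_length(graph)  # 1.0 + 1.414... = 2.414...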

def create_overlay_from_graph(original_image: np.ndarray, graph: nx.Graph) -> np.ndarray:
    """
    Create overlay visualization with traces on original image.

    Args:
        original_image: Original input image (numpy array)
        graph: NetworkX graph containing trace coordinates as (x, y) tuples

    Returns:
        Overlay array with traces highlighted on original image
    """
    overlay = original_image.copy()
    # Set trace pixels to maximum intensity for visibility
    max_val = np.max(original_image) if original_image.size > 0 else 1

    for u, v in graph.edges:
        x1, y1 = u  # Nodes are stored as (x, y) tuples
        x2, y2 = v  # Nodes are stored as (x, y) tuples
        # Bounds checking
        if (0 <= y1 < original_image.shape[0] and 0 <= x1 < original_image.shape[1]):
            overlay[y1, x1] = max_val
        if (0 <= y2 < original_image.shape[0] and 0 <= x2 < original_image.shape[1]):
            overlay[y2, x2] = max_val

    return overlay

def create_visualization_array(
    original_image: np.ndarray,
    graph: nx.Graph,
    mode: VisualizationMode
) -> np.ndarray:
    """
    Create visualization array based on the specified mode.

    Args:
        original_image: Original 2D image
        graph: NetworkX graph with traced neurites
        mode: Visualization mode

    Returns:
        2D array for visualization
    """
    if mode == VisualizationMode.NONE:
        # Return zeros array
        return np.zeros_like(original_image, dtype=original_image.dtype)

    elif mode == VisualizationMode.TRACE_ONLY:
        # Create binary mask with traced neurites
        trace_mask = np.zeros_like(original_image, dtype=np.uint8)
        for u, v in graph.edges:
            x1, y1 = u  # Nodes are stored as (x, y), not (y, x)
            x2, y2 = v  # Nodes are stored as (x, y), not (y, x)
            # Bounds checking
            if (0 <= y1 < original_image.shape[0] and 0 <= x1 < original_image.shape[1]):
                trace_mask[y1, x1] = 1
            if (0 <= y2 < original_image.shape[0] and 0 <= x2 < original_image.shape[1]):
                trace_mask[y2, x2] = 1
        return trace_mask

    elif mode == VisualizationMode.OVERLAY:
        # Use shared overlay function
        return create_overlay_from_graph(original_image, graph)

    else:
        raise ValueError(f"Unknown visualization mode: {mode}")

@special_outputs(("hmm_analysis", materialize_hmm_analysis), ("trace_visualizations", materialize_trace_visualizations))
@numpy
def trace_neurites_rrs_alva(
    image_stack: np.ndarray,
    seeding_method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    return_trace_visualizations: bool = False,
    trace_visualization_mode: VisualizationMode = VisualizationMode.TRACE_ONLY,
    chain_level: float = 1.05,
    node_r: Optional[int] = None,
    total_node: Optional[int] = None,
    line_length_min: int = 32,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02,
    normalize_image: bool = False,
    percentile: float = 99.9
) -> Tuple[np.ndarray, Dict[str, Any], List[np.ndarray]]:
    """
    Trace neurites using the alvahmm RRS (Random-Reaction-Seed) algorithm.

    This is the OpenHCS-compatible version of the original alvahmm implementation.
    Performs bidirectional HMM tracing with branching analysis to reconstruct
    complete neurite morphology.

    Args:
        image_stack: 3D array of shape (Z, Y, X) - input image stack
        seeding_method: Method for seed generation (RANDOM = paper default, BLOB_DETECTION = enhanced)
        return_trace_visualizations: Whether to generate trace visualizations as a special output
        trace_visualization_mode: How to visualize results (NONE, TRACE_ONLY, OVERLAY)
        chain_level: Validation threshold for HMM chains (default: 1.05)
        node_r: Path length between adjacent nodes (default: None, uses alvahmm default)
        total_node: Number of HMM nodes in chain (default: None, uses alvahmm default)
        line_length_min: Minimum line length for connection (default: 32)
        num_seeds: Number of seeds for random seeding method (default: 100)
        min_sigma: Minimum sigma for blob detection (default: 1.0)
        max_sigma: Maximum sigma for blob detection (default: 2.0)
        threshold: Threshold for blob detection (default: 0.02)
        normalize_image: Whether to apply percentile normalization (default: False; the paper does not use it)
        percentile: Percentile for normalization if enabled (default: 99.9)

    Returns:
        result_image: Original image stack unchanged (Z, Y, X)
        analysis_results: HMM analysis data structure with graph and metrics
        trace_visualizations: (Special output) List of visualization arrays if return_trace_visualizations=True
    """
    # Validate input is 3D
    if image_stack.ndim != 3:
        raise ValueError(f"Expected 3D array, got {image_stack.ndim}D")

    # Process each slice individually
    Z, Y, X = image_stack.shape
    all_graphs = []
    trace_visualizations = []

    for z in range(Z):
        im_axon = image_stack[z].astype(np.float64)

        # Optional normalization (disabled by default; the paper does not use it)
        if normalize_image:
            im_axon = normalize(im_axon, percentile=percentile)

        # Generate seeds using the selected method
        seed_xx, seed_yy = generate_seeds_by_method(
            im_axon,
            method=seeding_method,
            num_seeds=num_seeds,
            min_sigma=min_sigma,
            max_sigma=max_sigma,
            threshold=threshold
        )

        # Perform RRS tracing with bidirectional HMM chains
        root_tree_yy, root_tree_xx, root_tip_yy, root_tip_xx = selected_seeding(
            im_axon,
            seed_xx,
            seed_yy,
            chain_level=chain_level,
            node_r=node_r,
            total_node=total_node,
            line_length_min=line_length_min
        )

        # Extract graph representation for this slice
        graph = extract_graph(root_tree_xx, root_tree_yy)
        all_graphs.append(graph)

        # Create visualization for this slice if requested
        if return_trace_visualizations:
            visualization = create_visualization_array(im_axon, graph, trace_visualization_mode)
            trace_visualizations.append(visualization)

    # Combine all graphs (for compatibility, return the first one)
    combined_graph = all_graphs[0] if all_graphs else nx.Graph()

    # Compile analysis results
    analysis_results = _compile_hmm_analysis_results(
        combined_graph, all_graphs, image_stack.shape,
        seeding_method, trace_visualization_mode, chain_level,
        node_r, total_node, line_length_min
    )

    # Always return the original image, analysis results, and trace visualizations
    return image_stack, analysis_results, trace_visualizations
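
# Illustrative sketch (not in the original module): invoking the tracing function on a
# small synthetic stack. The helper name and synthetic data are hypothetical, and in a
# real OpenHCS pipeline the @special_outputs/@numpy decorators and the pipeline runner
# manage how the special outputs are routed and materialized, so direct-call semantics
# may differ.
def _example_trace_random_stack():
    """Run RRS tracing on a random (2, 64, 64) stack and report the total trace length."""
    rng = np.random.default_rng(0)
    stack = rng.integers(0, 65535, size=(2, 64, 64), dtype=np.uint16)
    image, results, visualizations = trace_neurites_rrs_alva(
        stack,
        seeding_method=SeedingMethod.RANDOM,        # paper's original seeding
        return_trace_visualizations=True,
        trace_visualization_mode=VisualizationMode.OVERLAY,
    )
    return results['summary']['total_trace_length'], len(visualizations)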

def _compile_hmm_analysis_results(
    combined_graph: nx.Graph,
    all_graphs: List[nx.Graph],
    image_shape: Tuple[int, int, int],
    seeding_method: SeedingMethod,
    visualization_mode: VisualizationMode,
    chain_level: float,
    node_r: Optional[int],
    total_node: Optional[int],
    line_length_min: int
) -> Dict[str, Any]:
    """Compile comprehensive HMM analysis results."""
    from datetime import datetime

    # Compute summary metrics from the graph
    num_nodes = combined_graph.number_of_nodes()
    num_edges = combined_graph.number_of_edges()

    # Calculate total trace length
    total_length = 0.0
    edge_lengths = []
    for u, v, data in combined_graph.edges(data=True):
        # Calculate Euclidean distance between nodes
        x1, y1 = u
        x2, y2 = v
        length = ((x2 - x1)**2 + (y2 - y1)**2)**0.5
        edge_lengths.append(length)
        total_length += length

    # Summary metrics
    summary = {
        'total_trace_length': float(total_length),
        'num_nodes': int(num_nodes),
        'num_edges': int(num_edges),
        'num_slices_processed': len(all_graphs),
        'mean_edge_length': float(sum(edge_lengths) / len(edge_lengths)) if edge_lengths else 0.0,
        'max_edge_length': float(max(edge_lengths)) if edge_lengths else 0.0,
        'graph_density': float(nx.density(combined_graph)) if num_nodes > 1 else 0.0,
        'num_connected_components': int(nx.number_connected_components(combined_graph)),
    }

    # Metadata
    metadata = {
        'algorithm': 'alvahmm_rrs',
        'seeding_method': seeding_method.value,
        'visualization_mode': visualization_mode.value,
        'chain_level': chain_level,
        'node_r': node_r,
        'total_node': total_node,
        'line_length_min': line_length_min,
        'image_shape': image_shape,
        'processing_timestamp': datetime.now().isoformat(),
    }

    return {
        'summary': summary,
        'graph': combined_graph,
        'metadata': metadata
    }

# Legacy file-based processing function (kept for reference)
def process_file_legacy(filename, input_folder, output_folder, **kwargs):
    """
    Legacy file-based processing function.

    This is kept for reference but should not be used in OpenHCS.
    Use trace_neurites_rrs_alva() instead for array-in/array-out processing.
    """
    raise NotImplementedError(
        "Legacy file-based processing not supported in OpenHCS. "
        "Use trace_neurites_rrs_alva() for array-in/array-out processing."
    )