Coverage for openhcs/processing/backends/analysis/hmm_axon.py: 3.8%

241 statements  

coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

# BACKUP: hmm_axon_backup.py created before OpenHCS conversion
"""
OpenHCS-compatible neurite tracing using alvahmm RRS algorithm.

Converted from file-based processing to pure array-in/array-out functions
following OpenHCS patterns.
"""

import numpy as np
import networkx as nx
import skimage
import math
from enum import Enum
from typing import Tuple, Dict, List, Optional, Any
from skimage.feature import canny, blob_dog as local_max
from skimage.filters import median, threshold_li
from skimage.morphology import skeletonize
from openhcs.core.memory.decorators import numpy
from openhcs.core.pipeline.function_contracts import special_outputs

# Import alvahmm from the GitHub dependency
from alva_machinery.markov import aChain as alva_MCMC
from alva_machinery.branching import aWay as alva_branch


def materialize_hmm_analysis(
    hmm_analysis_data: Dict[str, Any],
    path: str,
    filemanager,
    backend: str,
    **kwargs
) -> str:
    """
    Materialize HMM neurite tracing analysis results to disk.

    Creates multiple output files:
    - JSON file with summary metrics and metadata
    - GraphML file with the NetworkX graph
    - CSV file with edge data

    Args:
        hmm_analysis_data: The HMM analysis results dictionary
        path: Base path for output files (from special output path)
        filemanager: FileManager instance for consistent I/O
        backend: Backend to use for materialization
        **kwargs: Additional materialization options

    Returns:
        str: Path to the primary output file (JSON summary)
    """
    import json
    import networkx as nx
    from pathlib import Path
    from openhcs.constants.constants import Backend

    # Generate output file paths
    base_path = path.replace('.pkl', '')
    json_path = f"{base_path}.json"
    graphml_path = f"{base_path}_graph.graphml"
    csv_path = f"{base_path}_edges.csv"

    # Ensure output directory exists for the disk backend
    output_dir = Path(json_path).parent
    if backend == Backend.DISK.value:
        filemanager.ensure_directory(str(output_dir), backend)

    # 1. Save summary and metadata as JSON (primary output)
    summary_data = {
        'analysis_type': 'hmm_neurite_tracing',
        'summary': hmm_analysis_data['summary'],
        'metadata': hmm_analysis_data['metadata']
    }
    json_content = json.dumps(summary_data, indent=2, default=str)
    filemanager.save(json_content, json_path, backend)

    # 2. Save NetworkX graph as GraphML
    graph = hmm_analysis_data['graph']
    if graph and graph.number_of_nodes() > 0:
        # Use direct file I/O for GraphML (NetworkX doesn't support string I/O)
        # Note: NetworkX requires an actual file path, so this is not compatible with the OMERO backend
        if backend == Backend.DISK.value:
            nx.write_graphml(graph, graphml_path)

        # 3. Save edge data as CSV
        if graph.number_of_edges() > 0:
            import pandas as pd
            edge_data = []
            for u, v, data in graph.edges(data=True):
                edge_info = {
                    'source_x': u[0], 'source_y': u[1],
                    'target_x': v[0], 'target_y': v[1],
                    **data  # Include any edge attributes
                }
                edge_data.append(edge_info)

            edge_df = pd.DataFrame(edge_data)
            csv_content = edge_df.to_csv(index=False)
            filemanager.save(csv_content, csv_path, backend)

    return json_path

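# Illustrative sketch (not from the original source) of how the path rewriting above maps
# a special-output path to the files written. The example path is hypothetical:
#
#   path = "/results/A01/hmm_analysis.pkl"
#   ->  /results/A01/hmm_analysis.json            (summary + metadata, primary output)
#   ->  /results/A01/hmm_analysis_graph.graphml   (NetworkX graph, disk backend only)
#   ->  /results/A01/hmm_analysis_edges.csv       (per-edge coordinates and attributes)
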
def materialize_trace_visualizations(data: List[np.ndarray], path: str, filemanager, backend: str) -> str:
    """Materialize trace visualizations as individual TIFF files."""
    from openhcs.constants.constants import Backend

    if not data:
        # Create an empty summary file to indicate no visualizations were generated
        summary_path = path.replace('.pkl', '_trace_summary.txt')
        summary_content = "No trace visualizations generated (return_trace_visualizations=False)\n"
        filemanager.save(summary_content, summary_path, backend)
        return summary_path

    # Generate output file paths based on the input path
    base_path = path.replace('.pkl', '')

    # Save each visualization as a separate TIFF file
    for i, visualization in enumerate(data):
        viz_filename = f"{base_path}_slice_{i:03d}.tif"

        # Convert visualization to an appropriate dtype for saving (uint16 to match input images)
        if visualization.dtype != np.uint16:
            # Normalize to the uint16 range if needed
            if visualization.max() <= 1.0:
                viz_uint16 = (visualization * 65535).astype(np.uint16)
            else:
                viz_uint16 = visualization.astype(np.uint16)
        else:
            viz_uint16 = visualization

        # Save using filemanager (visualizations are always written via the disk backend)
        filemanager.save(viz_uint16, viz_filename, Backend.DISK.value)

    # Write a summary file describing what was saved
    summary_path = f"{base_path}_trace_summary.txt"
    summary_content = f"Trace visualizations saved: {len(data)} files\n"
    summary_content += f"Base filename pattern: {base_path}_slice_XXX.tif\n"
    summary_content += f"Visualization dtype: {data[0].dtype}\n"
    summary_content += f"Visualization shape: {data[0].shape}\n"

    filemanager.save(summary_content, summary_path, Backend.DISK.value)

    return summary_path

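# Illustrative sketch (hypothetical path) of the files written for a 3-slice stack:
#
#   path = "/results/A01/trace_visualizations.pkl"
#   ->  /results/A01/trace_visualizations_slice_000.tif
#   ->  /results/A01/trace_visualizations_slice_001.tif
#   ->  /results/A01/trace_visualizations_slice_002.tif
#   ->  /results/A01/trace_visualizations_trace_summary.txt
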
class SeedingMethod(Enum):
    """Seeding methods for neurite tracing."""
    RANDOM = "random"              # Paper's original method - random seeds across the entire image
    BLOB_DETECTION = "blob"        # Enhanced method - seeds on detected blob structures
    CANNY_EDGES = "canny"          # Alternative - seeds on Canny edge detection
    GROWTH_CONES = "growth_cones"  # Alternative - seeds on detected growth cones


class VisualizationMode(Enum):
    """Visualization modes for trace output."""
    NONE = "none"         # Return zeros array (no visualization)
    TRACE_ONLY = "trace"  # Show only traced neurites (binary mask)
    OVERLAY = "overlay"   # Show original image with traced neurites overlaid


class OutputMode(Enum):
    """Output visualization modes."""
    TRACE_ONLY = "trace_only"  # Binary mask of traced neurites only
    OVERLAY = "overlay"        # Original image with traces overlaid
    NONE = "none"              # Return original image unchanged


def normalize(img, percentile=99.9):
    percentile_value = np.percentile(img, percentile)
    img = img / percentile_value  # Scale the image so the nth-percentile value maps to 1.0
    img = np.clip(img, 0, 100)    # Clip negatives and extreme outliers (use an upper bound of 1 to cap at the percentile value)
    #img = img - img.min()
    #img = img / img.max()
    return img

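# Worked example for normalize() with illustrative values (not from the paper):
#
#   img = np.array([[0.0, 500.0, 1000.0, 2000.0]])
#   normalize(img, percentile=100)   # 100th percentile is 2000 -> [[0.0, 0.25, 0.5, 1.0]]
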
def boundary_masking_canny(image):
    # Canny edge map with a 2-pixel border cleared to suppress boundary artifacts
    bool_im_axon_edit = canny(image)
    bool_im_axon_edit[:, :2] = False
    bool_im_axon_edit[:, -2:] = False
    bool_im_axon_edit[:2, :] = False
    bool_im_axon_edit[-2:, :] = False
    return np.array(bool_im_axon_edit, dtype=np.int64)


def boundary_masking_threshold(image, threshold=threshold_li, min_size=2):
    # Threshold the image, clear a 2-pixel border, then skeletonize the foreground mask
    threshed = threshold(image)
    bool_image = image > threshed
    bool_image[:, :2] = False
    bool_image[:, -2:] = False
    bool_image[:2, :] = False
    bool_image[-2:, :] = False
    cleaned_bool_im_axon_edit = skeletonize(bool_image)
    return np.array(cleaned_bool_im_axon_edit, dtype=np.int64)


def boundary_masking_blob(image, min_sigma=1, max_sigma=2, threshold=0.02):
    # Fall back to defaults if callers pass None explicitly
    if min_sigma is None:
        min_sigma = 1
    if max_sigma is None:
        max_sigma = 2
    if threshold is None:
        threshold = 0.02

    # Median-filter the image, then mark blob (local maximum) centres in a binary mask
    image_median = median(image)
    galaxy = local_max(image_median, min_sigma=min_sigma, max_sigma=max_sigma, threshold=threshold)
    yy = np.int64(galaxy[:, 0])
    xx = np.int64(galaxy[:, 1])
    boundary_mask = np.copy(image) * 0
    boundary_mask[yy, xx] = 1
    return boundary_mask


def random_seed_by_edge_map(edge_map):
    """Generate random seeds from detected edge/blob locations."""
    yy, xx = edge_map.nonzero()
    if len(xx) == 0:
        # No edges detected, fall back to random seeding
        return generate_random_seeds(edge_map.shape, num_seeds=100)
    # Sample (with replacement) as many seeds as there are candidate pixels
    seed_index = np.random.choice(len(xx), len(xx))
    seed_xx = xx[seed_index]
    seed_yy = yy[seed_index]
    return seed_xx, seed_yy


def generate_random_seeds(image_shape: Tuple[int, int], num_seeds: int = 100):
    """Generate completely random seeds across the entire image (paper's original method)."""
    height, width = image_shape
    seed_xx = np.random.randint(0, width, num_seeds)
    seed_yy = np.random.randint(0, height, num_seeds)
    return seed_xx, seed_yy


def generate_seeds_by_method(
    image: np.ndarray,
    method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate seeds using the specified method.

    Args:
        image: Input image for seed generation
        method: Seeding method to use
        num_seeds: Number of seeds for the random method
        min_sigma: Min sigma for blob detection
        max_sigma: Max sigma for blob detection
        threshold: Threshold for blob detection

    Returns:
        seed_xx, seed_yy: Arrays of seed coordinates
    """
    if method == SeedingMethod.RANDOM:
        # Paper's original method - pure random seeding
        return generate_random_seeds(image.shape, num_seeds)

    elif method == SeedingMethod.BLOB_DETECTION:
        # Enhanced method - seeds on detected blobs
        edge_map = boundary_masking_blob(image, min_sigma, max_sigma, threshold)
        return random_seed_by_edge_map(edge_map)

    elif method == SeedingMethod.CANNY_EDGES:
        # Alternative - seeds on Canny edges
        edge_map = boundary_masking_canny(image)
        return random_seed_by_edge_map(edge_map)

    elif method == SeedingMethod.GROWTH_CONES:
        # Alternative - seeds on detected growth cones
        return get_growth_cone_positions(image)

    else:
        raise ValueError(f"Unknown seeding method: {method}")

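# Minimal usage sketch for a single 2D slice `im` (parameter values are examples only):
#
#   sxx, syy = generate_seeds_by_method(im, SeedingMethod.BLOB_DETECTION, min_sigma=1.0, max_sigma=2.0)
#   sxx, syy = generate_seeds_by_method(im, SeedingMethod.RANDOM, num_seeds=200)
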
def get_growth_cone_positions(image):
    """
    Detect growth cone positions using morphological operations (OpenHCS-compatible).

    Args:
        image: Input image for growth cone detection

    Returns:
        seed_xx, seed_yy: Arrays of growth cone center coordinates
    """
    # Threshold the image to create a binary mask
    mask = image > skimage.filters.threshold_otsu(image)

    # Use morphological closing to fill in small gaps in the mask
    mask = skimage.morphology.closing(mask, skimage.morphology.disk(3))
    labeled = skimage.measure.label(mask)
    props = skimage.measure.regionprops(labeled)

    seed_xx = []
    seed_yy = []
    for prop in props:
        seed_xx.append(prop.centroid[1])  # x coordinate
        seed_yy.append(prop.centroid[0])  # y coordinate

    return np.array(seed_xx), np.array(seed_yy)


def selected_seeding(image, seed_xx, seed_yy, chain_level=1.05, total_node=8, node_r=None, line_length_min=32):
    im_copy = np.copy(image)
    alva_HMM = alva_MCMC.AlvaHmm(im_copy,
                                 total_node=total_node,
                                 total_path=None,
                                 node_r=node_r,
                                 node_angle_max=None,)
    chain_HMM_1st, pair_chain_HMM, pair_seed_xx, pair_seed_yy = alva_HMM.pair_HMM_chain(seed_xx=seed_xx,
                                                                                        seed_yy=seed_yy,
                                                                                        chain_level=chain_level,)
    # Unpack the forward and paired (reverse) chains; the unpacked values are kept for
    # reference even though only the combined chain image is used below.
    for chain_i in [0, 1]:
        chain_HMM = [chain_HMM_1st, pair_chain_HMM][chain_i]
        real_chain_ii, real_chain_aa, real_chain_xx, real_chain_yy = chain_HMM[0:4]
        seed_node_xx, seed_node_yy = chain_HMM[4:6]

    chain_im_fine = alva_HMM.chain_image(chain_HMM_1st, pair_chain_HMM,)
    return alva_branch.connect_way(chain_im_fine,
                                   line_length_min=line_length_min,
                                   free_zone_from_y0=None,)


def euclidian_distance(x1, y1, x2, y2):
    distance = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
    return distance


def extract_graph(root_tree_xx, root_tree_yy):
    """Build a weighted graph from traced paths; edge weights are Euclidean segment lengths."""
    graph = nx.Graph()
    for path_x, path_y in zip(root_tree_xx, root_tree_yy):
        for x, y in zip(path_x, path_y):
            graph.add_node((x, y))
        for i in range(len(path_x) - 1):
            distance = euclidian_distance(path_x[i], path_y[i], path_x[i + 1], path_y[i + 1])
            graph.add_edge((path_x[i], path_y[i]), (path_x[i + 1], path_y[i + 1]), weight=distance)
    return graph


def graph_to_length(graph):
    """Sum the edge weights of a trace graph to get the total traced length in pixels."""
    total_distance = 0
    for u, v, data in graph.edges(data=True):
        total_distance += data['weight']
    return total_distance


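# Worked example (illustrative coordinates): one traced path with three points becomes
# two weighted edges, and graph_to_length() sums their Euclidean lengths.
#
#   g = extract_graph([[0, 3, 3]], [[0, 4, 8]])   # nodes (0,0), (3,4), (3,8)
#   graph_to_length(g)                            # 5.0 + 4.0 = 9.0
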
def create_overlay_from_graph(original_image: np.ndarray, graph: nx.Graph) -> np.ndarray:
    """
    Create overlay visualization with traces on original image.

    Args:
        original_image: Original input image (numpy array)
        graph: NetworkX graph containing trace coordinates as (x, y) tuples

    Returns:
        Overlay array with traces highlighted on original image
    """
    overlay = original_image.copy()
    # Set trace pixels to maximum intensity for visibility
    max_val = np.max(original_image) if original_image.size > 0 else 1

    for u, v in graph.edges:
        x1, y1 = u  # Nodes are stored as (x, y) tuples
        x2, y2 = v  # Nodes are stored as (x, y) tuples
        # Bounds checking
        if (0 <= y1 < original_image.shape[0] and 0 <= x1 < original_image.shape[1]):
            overlay[y1, x1] = max_val
        if (0 <= y2 < original_image.shape[0] and 0 <= x2 < original_image.shape[1]):
            overlay[y2, x2] = max_val

    return overlay


def create_visualization_array(
    original_image: np.ndarray,
    graph: nx.Graph,
    mode: VisualizationMode
) -> np.ndarray:
    """
    Create visualization array based on the specified mode.

    Args:
        original_image: Original 2D image
        graph: NetworkX graph with traced neurites
        mode: Visualization mode

    Returns:
        2D array for visualization
    """
    if mode == VisualizationMode.NONE:
        # Return zeros array
        return np.zeros_like(original_image, dtype=original_image.dtype)

    elif mode == VisualizationMode.TRACE_ONLY:
        # Create binary mask with traced neurites
        trace_mask = np.zeros_like(original_image, dtype=np.uint8)
        for u, v in graph.edges:
            x1, y1 = u  # Nodes are stored as (x, y), not (y, x)
            x2, y2 = v  # Nodes are stored as (x, y), not (y, x)
            # Bounds checking
            if (0 <= y1 < original_image.shape[0] and 0 <= x1 < original_image.shape[1]):
                trace_mask[y1, x1] = 1
            if (0 <= y2 < original_image.shape[0] and 0 <= x2 < original_image.shape[1]):
                trace_mask[y2, x2] = 1
        return trace_mask

    elif mode == VisualizationMode.OVERLAY:
        # Use shared overlay function
        return create_overlay_from_graph(original_image, graph)

    else:
        raise ValueError(f"Unknown visualization mode: {mode}")

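# Illustrative sketch for a 2D slice `im` and trace graph `g` (hypothetical names):
#
#   create_visualization_array(im, g, VisualizationMode.NONE)        # zeros, same shape/dtype as im
#   create_visualization_array(im, g, VisualizationMode.TRACE_ONLY)  # uint8 mask, 1 at traced pixels
#   create_visualization_array(im, g, VisualizationMode.OVERLAY)     # copy of im with traced pixels set to im.max()
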
@special_outputs(("hmm_analysis", materialize_hmm_analysis), ("trace_visualizations", materialize_trace_visualizations))
@numpy
def trace_neurites_rrs_alva(
    image_stack: np.ndarray,
    seeding_method: SeedingMethod = SeedingMethod.BLOB_DETECTION,
    return_trace_visualizations: bool = False,
    trace_visualization_mode: VisualizationMode = VisualizationMode.TRACE_ONLY,
    chain_level: float = 1.05,
    node_r: Optional[int] = None,
    total_node: Optional[int] = None,
    line_length_min: int = 32,
    num_seeds: int = 100,
    min_sigma: float = 1.0,
    max_sigma: float = 2.0,
    threshold: float = 0.02,
    normalize_image: bool = False,
    percentile: float = 99.9
) -> Tuple[np.ndarray, Dict[str, Any], List[np.ndarray]]:
    """
    Trace neurites using the alvahmm RRS (Random-Reaction-Seed) algorithm.

    This is the OpenHCS-compatible version of the original alvahmm implementation.
    Performs bidirectional HMM tracing with branching analysis to reconstruct
    complete neurite morphology.

    Args:
        image_stack: 3D array of shape (Z, Y, X) - input image stack
        seeding_method: Method for seed generation (RANDOM = paper default, BLOB_DETECTION = enhanced)
        return_trace_visualizations: Whether to generate trace visualizations as a special output
        trace_visualization_mode: How to visualize results (NONE, TRACE_ONLY, OVERLAY)
        chain_level: Validation threshold for HMM chains (default: 1.05)
        node_r: Path length between adjacent nodes (default: None, uses alvahmm default)
        total_node: Number of HMM nodes in a chain (default: None, uses alvahmm default)
        line_length_min: Minimum line length for connection (default: 32)
        num_seeds: Number of seeds for the random seeding method (default: 100)
        min_sigma: Minimum sigma for blob detection (default: 1.0)
        max_sigma: Maximum sigma for blob detection (default: 2.0)
        threshold: Threshold for blob detection (default: 0.02)
        normalize_image: Whether to apply percentile normalization (default: False; the paper does not normalize)
        percentile: Percentile for normalization if enabled (default: 99.9)

    Returns:
        result_image: Original image stack unchanged (Z, Y, X)
        analysis_results: HMM analysis data structure with graph and metrics
        trace_visualizations: (Special output) List of visualization arrays if return_trace_visualizations=True
    """
    # Validate input is 3D
    if image_stack.ndim != 3:
        raise ValueError(f"Expected 3D array, got {image_stack.ndim}D")

    # Process each slice individually
    Z, Y, X = image_stack.shape
    all_graphs = []
    trace_visualizations = []

    for z in range(Z):
        im_axon = image_stack[z].astype(np.float64)

        # Optional normalization (disabled by default; the paper does not normalize)
        if normalize_image:
            im_axon = normalize(im_axon, percentile=percentile)

        # Generate seeds using the selected method
        seed_xx, seed_yy = generate_seeds_by_method(
            im_axon,
            method=seeding_method,
            num_seeds=num_seeds,
            min_sigma=min_sigma,
            max_sigma=max_sigma,
            threshold=threshold
        )

        # Perform RRS tracing with bidirectional HMM chains
        root_tree_yy, root_tree_xx, root_tip_yy, root_tip_xx = selected_seeding(
            im_axon,
            seed_xx,
            seed_yy,
            chain_level=chain_level,
            node_r=node_r,
            total_node=total_node,
            line_length_min=line_length_min
        )

        # Extract graph representation for this slice
        graph = extract_graph(root_tree_xx, root_tree_yy)
        all_graphs.append(graph)

        # Create visualization for this slice if requested
        if return_trace_visualizations:
            visualization = create_visualization_array(im_axon, graph, trace_visualization_mode)
            trace_visualizations.append(visualization)

    # Combine all graphs (for compatibility, return the first one)
    combined_graph = all_graphs[0] if all_graphs else nx.Graph()

    # Compile analysis results
    analysis_results = _compile_hmm_analysis_results(
        combined_graph, all_graphs, image_stack.shape,
        seeding_method, trace_visualization_mode, chain_level,
        node_r, total_node, line_length_min
    )

    # Always return original image, analysis results, and trace visualizations
    return image_stack, analysis_results, trace_visualizations


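# Minimal usage sketch of the array-in/array-out contract (direct call on a synthetic
# stack; in an OpenHCS pipeline the decorated function is normally invoked by the
# pipeline runner, and the decorators may route the special outputs differently).
# Parameter values are examples only:
#
#   stack = np.random.rand(2, 256, 256)
#   image_out, analysis, viz = trace_neurites_rrs_alva(
#       stack,
#       seeding_method=SeedingMethod.RANDOM,
#       return_trace_visualizations=True,
#       trace_visualization_mode=VisualizationMode.OVERLAY,
#   )
#   analysis['summary']['total_trace_length']   # total traced length in pixels
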
def _compile_hmm_analysis_results(
    combined_graph: nx.Graph,
    all_graphs: List[nx.Graph],
    image_shape: Tuple[int, int, int],
    seeding_method: SeedingMethod,
    visualization_mode: VisualizationMode,
    chain_level: float,
    node_r: Optional[int],
    total_node: Optional[int],
    line_length_min: int
) -> Dict[str, Any]:
    """Compile comprehensive HMM analysis results."""
    from datetime import datetime

    # Compute summary metrics from the graph
    num_nodes = combined_graph.number_of_nodes()
    num_edges = combined_graph.number_of_edges()

    # Calculate total trace length
    total_length = 0.0
    edge_lengths = []
    for u, v, data in combined_graph.edges(data=True):
        # Calculate Euclidean distance between nodes
        x1, y1 = u
        x2, y2 = v
        length = ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5
        edge_lengths.append(length)
        total_length += length

    # Summary metrics
    summary = {
        'total_trace_length': float(total_length),
        'num_nodes': int(num_nodes),
        'num_edges': int(num_edges),
        'num_slices_processed': len(all_graphs),
        'mean_edge_length': float(sum(edge_lengths) / len(edge_lengths)) if edge_lengths else 0.0,
        'max_edge_length': float(max(edge_lengths)) if edge_lengths else 0.0,
        'graph_density': float(nx.density(combined_graph)) if num_nodes > 1 else 0.0,
        'num_connected_components': int(nx.number_connected_components(combined_graph)),
    }

    # Metadata
    metadata = {
        'algorithm': 'alvahmm_rrs',
        'seeding_method': seeding_method.value,
        'visualization_mode': visualization_mode.value,
        'chain_level': chain_level,
        'node_r': node_r,
        'total_node': total_node,
        'line_length_min': line_length_min,
        'image_shape': image_shape,
        'processing_timestamp': datetime.now().isoformat(),
    }

    return {
        'summary': summary,
        'graph': combined_graph,
        'metadata': metadata
    }

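# Shape of the returned structure (keys taken from the code above):
#
#   {
#       'summary':  {'total_trace_length': ..., 'num_nodes': ..., 'num_edges': ..., ...},
#       'graph':    <networkx.Graph with (x, y) nodes and weighted edges>,
#       'metadata': {'algorithm': 'alvahmm_rrs', 'seeding_method': ..., ...},
#   }
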
# Legacy file-based processing function (kept for reference)
def process_file_legacy(filename, input_folder, output_folder, **kwargs):
    """
    Legacy file-based processing function.

    This is kept for reference but should not be used in OpenHCS.
    Use trace_neurites_rrs_alva() instead for array-in/array-out processing.
    """
    raise NotImplementedError(
        "Legacy file-based processing not supported in OpenHCS. "
        "Use trace_neurites_rrs_alva() for array-in/array-out processing."
    )