Coverage for openhcs/processing/backends/analysis/straighten_object_3d.py: 5.0%

170 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

from __future__ import annotations

import logging
from typing import Optional, Tuple, Union

from openhcs.core.utils import optional_import
from openhcs.core.memory.decorators import torch as torch_func

# Import torch modules as optional dependencies
torch = optional_import("torch")
F = optional_import("torch.nn.functional") if torch is not None else None

logger = logging.getLogger(__name__)


def _moving_average_1d_torch(data: torch.Tensor, window_size: int) -> torch.Tensor:
    """Applies a 1D moving average filter along the last dimension."""
    if window_size <= 0:
        return data
    if window_size % 2 == 0:  # Ensure an odd window size for a centered average
        window_size += 1

    weights = torch.ones(window_size, device=data.device, dtype=data.dtype) / window_size
    weights = weights.view(1, 1, -1)  # Shape (out_channels, in_channels/groups, kW)

    # data shape (N, L) -> (N, 1, L) for conv1d
    data_reshaped = data.unsqueeze(1)
    padding = window_size // 2

    smoothed_data = F.conv1d(data_reshaped, weights, padding=padding)
    return smoothed_data.squeeze(1)
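
# Illustrative behaviour of the helper (a sketch, assuming torch is available):
#   >>> x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
#   >>> _moving_average_1d_torch(x, window_size=3)
#   tensor([[1., 2., 3., 4., 3.]])
# The first and last samples are biased toward zero because F.conv1d pads with zeros.
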

@torch_func
def straighten_object_3d(
    image_volume: torch.Tensor,  # Expected (Z, H, W) or (1, Z, H, W)
    min_voxel_threshold: float,  # Required parameter
    patch_radius: Optional[int] = None,
    sampling_spacing: float = 1.0,
    max_components: int = 1,
    return_grid: bool = False,
    spline_smoothness: Optional[float] = None,
    **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Identifies and straightens the largest continuous 3D object in a preprocessed
    input volume using PyTorch for GPU-native operations.
    """
    if not isinstance(image_volume, torch.Tensor):
        raise TypeError(f"Input image_volume must be a PyTorch Tensor. Got {type(image_volume)}")

    device = image_volume.device
    original_dtype = image_volume.dtype

    # --- Input Shape Handling ---
    img_vol_proc = image_volume.float()  # Work with float32
    if img_vol_proc.ndim == 4:  # (1, Z, H, W)
        if img_vol_proc.shape[0] != 1:
            raise ValueError("If 4D, first dimension (batch) must be 1.")
        img_vol_proc = img_vol_proc.squeeze(0)  # (Z, H, W)
    elif img_vol_proc.ndim != 3:  # (Z, H, W)
        raise ValueError(f"image_volume must be 3D (Z,H,W) or 4D (1,Z,H,W). Got {image_volume.ndim}D")

    Z_orig, H_orig, W_orig = img_vol_proc.shape

    # --- Parse Parameters & Compute Defaults ---
    min_voxel_threshold = float(min_voxel_threshold)  # Already validated as required

    patch_radius_val = patch_radius
    if patch_radius_val is None:
        patch_radius_val = min(H_orig, W_orig) // 10
    patch_radius_val = int(patch_radius_val)
    if patch_radius_val <= 0:
        patch_radius_val = 1  # Ensure a positive patch radius

    sampling_spacing_val = float(sampling_spacing)
    if sampling_spacing_val <= 0:
        sampling_spacing_val = 1.0

    max_components_val = int(max_components)
    if max_components_val != 1:
        # Full 3D connected-component labelling on GPU without SciPy is beyond the scope of this function.
        raise NotImplementedError("max_components > 1 is not implemented due to GPU CC complexity constraints.")

    return_grid_val = bool(return_grid)

    # spline_smoothness default depends on curve length; computed later.
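
    # Illustrative default (sketch): for a 1024 x 1024 field of view,
    # patch_radius defaults to min(1024, 1024) // 10 = 102, giving
    # (2 * 102 + 1) = 205 x 205 pixel cross-sections.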

    # --- 1. Thresholding + Masking (Simplified for largest object) ---
    binary_mask = img_vol_proc > min_voxel_threshold

    object_coords_idx = torch.nonzero(binary_mask, as_tuple=False)  # (N_voxels, 3) -> (z, y, x)
    if object_coords_idx.shape[0] == 0:
        logger.warning("No object found above threshold. Returning empty tensor(s).")
        patch_dim = 2 * patch_radius_val + 1
        empty_vol = torch.empty((0, patch_dim, patch_dim), device=device, dtype=original_dtype)
        empty_grid = torch.empty((0, patch_dim, patch_dim, 3), device=device, dtype=torch.float32)
        return (empty_vol, empty_grid) if return_grid_val else empty_vol

    object_coords = object_coords_idx.float()

    # --- 2. Centerline Fitting ---
    # PCA for initial orientation
    mean_coord = torch.mean(object_coords, dim=0)
    centered_coords = object_coords - mean_coord

    # Ensure there's enough variance for SVD
    if centered_coords.shape[0] < 2 or torch.allclose(centered_coords, torch.zeros_like(centered_coords)):
        logger.warning("Not enough distinct object coordinates for PCA. Using a simplified centerline.")
        # Fallback: use the z-axis if the object is too small or flat for PCA.
        # This is a simplification; a more robust fallback might be needed.
        min_z, max_z = object_coords[:, 0].min(), object_coords[:, 0].max()
        num_fallback_points = max(2, int((max_z - min_z) / sampling_spacing_val) + 1)
        centerline_points_smooth = torch.zeros((num_fallback_points, 3), device=device)
        centerline_points_smooth[:, 0] = torch.linspace(min_z, max_z, num_fallback_points, device=device)
        centerline_points_smooth[:, 1] = mean_coord[1]  # Use mean y
        centerline_points_smooth[:, 2] = mean_coord[2]  # Use mean x
    else:
        # Use try/except because SVD can fail on ill-conditioned matrices.
        try:
            # PCA via SVD of the centered coordinates: the rows of Vh are the
            # principal directions (equivalent to an eigendecomposition of the
            # covariance matrix), so Vh[0] is the principal axis.
            U_pca, S_pca, V_pca_transpose = torch.linalg.svd(centered_coords, full_matrices=False)
            principal_axis = V_pca_transpose[0, :]  # First right singular vector
        except Exception as e:
            logger.warning(f"PCA (SVD) failed: {e}. Using simplified z-axis centerline.")
            min_z, max_z = object_coords[:, 0].min(), object_coords[:, 0].max()
            num_fallback_points = max(2, int((max_z - min_z) / sampling_spacing_val) + 1)
            centerline_points_smooth = torch.zeros((num_fallback_points, 3), device=device)
            centerline_points_smooth[:, 0] = torch.linspace(min_z, max_z, num_fallback_points, device=device)
            centerline_points_smooth[:, 1] = mean_coord[1]
            centerline_points_smooth[:, 2] = mean_coord[2]
        else:
            projected_scalar = torch.matmul(centered_coords, principal_axis)
            sorted_indices = torch.argsort(projected_scalar)
            sorted_coords_on_axis = object_coords[sorted_indices]

            # Estimate curve length for the spline_smoothness default
            segment_lengths_est = torch.norm(sorted_coords_on_axis[1:] - sorted_coords_on_axis[:-1], dim=1)
            curve_length_est = torch.sum(segment_lengths_est)

            spline_smoothness_val = float(spline_smoothness if spline_smoothness is not None else 0.01 * curve_length_est.item())

            # Smooth each coordinate with a moving average. spline_smoothness_val
            # (default 0.01 * estimated curve length) is interpreted as a smoothing
            # window in voxel units along the curve: with roughly
            # num_sorted_pts / curve_length_est points per voxel of arc length, the
            # window size in points is spline_smoothness_val * num_sorted_pts / curve_length_est.
            # This is a heuristic; a true smoothing spline would use the parameter differently.
            num_sorted_pts = sorted_coords_on_axis.shape[0]
            if curve_length_est > 1e-3:  # Avoid division by zero
                window_size = int(max(3, spline_smoothness_val * (num_sorted_pts / curve_length_est.item())))
            else:
                window_size = 3
            window_size = min(window_size, num_sorted_pts // 2 if num_sorted_pts > 4 else 1)  # Cap window size
            if window_size < 1:
                window_size = 1
            if window_size % 2 == 0:
                window_size += 1  # Must be odd for a centered average
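            # Worked example of the heuristic (illustrative numbers only): 500
            # object voxels spanning an estimated curve length of 100 voxels with
            # the default smoothness 0.01 * 100 = 1.0 gives
            # window_size = max(3, 1.0 * 500 / 100) = 5 points.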

            if num_sorted_pts > window_size and window_size > 1:  # Only smooth with enough points and a window > 1
                centerline_points_smooth_z = _moving_average_1d_torch(sorted_coords_on_axis[:, 0].unsqueeze(0), window_size).squeeze(0)
                centerline_points_smooth_y = _moving_average_1d_torch(sorted_coords_on_axis[:, 1].unsqueeze(0), window_size).squeeze(0)
                centerline_points_smooth_x = _moving_average_1d_torch(sorted_coords_on_axis[:, 2].unsqueeze(0), window_size).squeeze(0)
                centerline_points_smooth = torch.stack([centerline_points_smooth_z, centerline_points_smooth_y, centerline_points_smooth_x], dim=1)
            else:
                centerline_points_smooth = sorted_coords_on_axis

    # Resample smoothed centerline
    segment_lengths = torch.norm(centerline_points_smooth[1:] - centerline_points_smooth[:-1], p=2, dim=1)
    if segment_lengths.numel() == 0:  # Single-point object after smoothing/PCA
        logger.warning("Centerline reduced to a single point. Cannot straighten.")
        # Return a single slice of size patch_dim, centred on the object's mean coordinate.
        patch_dim = 2 * patch_radius_val + 1
        single_slice = img_vol_proc[
            int(mean_coord[0]),
            max(0, int(mean_coord[1]) - patch_radius_val):min(H_orig, int(mean_coord[1]) + patch_radius_val + 1),
            max(0, int(mean_coord[2]) - patch_radius_val):min(W_orig, int(mean_coord[2]) + patch_radius_val + 1)
        ]
        # Pad if necessary to patch_dim x patch_dim
        padded_slice = torch.zeros((patch_dim, patch_dim), device=device, dtype=original_dtype)
        h_s, w_s = single_slice.shape
        y_start, x_start = (patch_dim - h_s) // 2, (patch_dim - w_s) // 2
        padded_slice[y_start:y_start + h_s, x_start:x_start + w_s] = single_slice[:patch_dim, :patch_dim].to(original_dtype)

        final_volume = padded_slice.unsqueeze(0)  # (1, patch_dim, patch_dim)
        final_grid = torch.zeros((1, patch_dim, patch_dim, 3), device=device, dtype=torch.float32)  # Dummy grid
        return (final_volume, final_grid) if return_grid_val else final_volume

    cum_lengths = torch.cat((torch.tensor([0.0], device=device), torch.cumsum(segment_lengths, dim=0)))
    total_curve_length = cum_lengths[-1]

    if total_curve_length < 1e-3:  # Effectively a point
        num_samples_L = 1
    else:
        num_samples_L = int(torch.ceil(total_curve_length / sampling_spacing_val).item())
        if num_samples_L < 2:
            num_samples_L = 2  # Need at least 2 points for tangents

    target_cum_lengths = torch.linspace(0, total_curve_length, num_samples_L, device=device)

    resampled_centerline = torch.empty((num_samples_L, 3), device=device, dtype=torch.float32)

    # Interpolation for resampling:
    # for each target_cum_length, find its bracketing segment in the original cum_lengths.
    indices = torch.searchsorted(cum_lengths, target_cum_lengths, right=True) - 1
    indices = torch.clamp(indices, 0, cum_lengths.shape[0] - 2)  # Ensure a valid index range

    # Alpha for interpolation: (target - prev_cum) / (next_cum - prev_cum)
    len_prev_cum = cum_lengths[indices]
    len_next_cum = cum_lengths[indices + 1]

    # Avoid division by zero if a segment length is zero
    segment_len_for_alpha = len_next_cum - len_prev_cum
    alpha = torch.where(
        segment_len_for_alpha > 1e-6,
        (target_cum_lengths - len_prev_cum) / segment_len_for_alpha,
        torch.zeros_like(target_cum_lengths)  # If the segment length is zero, alpha is 0
    )
    alpha = alpha.unsqueeze(1)  # For broadcasting with coordinates

    pt_prev = centerline_points_smooth[indices]
    pt_next = centerline_points_smooth[indices + 1]
    resampled_centerline = pt_prev * (1.0 - alpha) + pt_next * alpha
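
    # Worked example (illustrative numbers only): with cum_lengths = [0, 2, 5] and a
    # target arc length of 3.5, searchsorted(...) - 1 gives index 1 and
    # alpha = (3.5 - 2) / (5 - 2) = 0.5, so the resampled point is the midpoint of
    # centerline points 1 and 2.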

    # --- 3. Plane Sampling ---
    L = resampled_centerline.shape[0]
    patch_dim = 2 * patch_radius_val + 1

    sampling_grid_slices = torch.empty((L, patch_dim, patch_dim, 3), device=device, dtype=torch.float32)

    # Create local patch coordinates (relative to the plane centre)
    u_coords = torch.linspace(-patch_radius_val, patch_radius_val, patch_dim, device=device)
    v_coords = torch.linspace(-patch_radius_val, patch_radius_val, patch_dim, device=device)
    grid_v_local, grid_u_local = torch.meshgrid(v_coords, u_coords, indexing='ij')  # (patch_dim, patch_dim)

    for i in range(L):
        current_point = resampled_centerline[i]  # (z_c, y_c, x_c)

        if L == 1:
            tangent_vec = torch.zeros(3, device=device)  # Degenerate; handled by the norm check below
        elif i == 0:
            tangent_vec = resampled_centerline[i + 1] - resampled_centerline[i]
        elif i == L - 1:
            tangent_vec = resampled_centerline[i] - resampled_centerline[i - 1]
        else:
            tangent_vec = (resampled_centerline[i + 1] - resampled_centerline[i - 1]) / 2.0

        if torch.norm(tangent_vec) < 1e-6:  # Degenerate tangent
            tangent = torch.tensor([1.0, 0.0, 0.0], device=device)  # Default to a z-axis direction
        else:
            tangent = F.normalize(tangent_vec, p=2, dim=0)  # This is the new Z' (z_new_axis)

        # Define orthogonal in-plane vectors (new X', new Y')
        # Choose an arbitrary vector not parallel to the tangent
        if torch.abs(tangent[0]) < 0.9:  # Tangent is not mostly along the Z axis
            arbitrary_vec = torch.tensor([1.0, 0.0, 0.0], device=device)
        else:  # Tangent is mostly along Z, pick X or Y
            arbitrary_vec = torch.tensor([0.0, 1.0, 0.0], device=device)

        vec_u_prime_unnorm = torch.cross(tangent, arbitrary_vec)
        if torch.norm(vec_u_prime_unnorm) < 1e-6:  # The arbitrary vector was parallel
            arbitrary_vec = torch.tensor([0.0, 0.0, 1.0], device=device)  # Try another
            if torch.abs(torch.dot(tangent, arbitrary_vec)) > 0.99:  # If the tangent is the Z axis
                arbitrary_vec = torch.tensor([0.0, 1.0, 0.0], device=device)
            vec_u_prime_unnorm = torch.cross(tangent, arbitrary_vec)

        vec_u_prime = F.normalize(vec_u_prime_unnorm, p=2, dim=0)  # New X' (x_new_axis)
        vec_v_prime = F.normalize(torch.cross(tangent, vec_u_prime), p=2, dim=0)  # New Y' (y_new_axis)

        # Transform local grid points to world coordinates, per component:
        #   plane_points_world = current_point + grid_u_local[..., None] * vec_u_prime + grid_v_local[..., None] * vec_v_prime
        # where grid_u_local / grid_v_local are (patch_dim, patch_dim) and
        # vec_u_prime / vec_v_prime are length-3 direction vectors.
        plane_points_world_z = current_point[0] + grid_u_local * vec_u_prime[0] + grid_v_local * vec_v_prime[0]
        plane_points_world_y = current_point[1] + grid_u_local * vec_u_prime[1] + grid_v_local * vec_v_prime[1]
        plane_points_world_x = current_point[2] + grid_u_local * vec_u_prime[2] + grid_v_local * vec_v_prime[2]

        # Normalize coordinates for grid_sample (expects x, y, z order in [-1, 1])
        # Original volume dimensions: Z_orig, H_orig, W_orig
        norm_coords_x = 2.0 * (plane_points_world_x / (W_orig - 1)) - 1.0 if W_orig > 1 else torch.zeros_like(plane_points_world_x)
        norm_coords_y = 2.0 * (plane_points_world_y / (H_orig - 1)) - 1.0 if H_orig > 1 else torch.zeros_like(plane_points_world_y)
        norm_coords_z = 2.0 * (plane_points_world_z / (Z_orig - 1)) - 1.0 if Z_orig > 1 else torch.zeros_like(plane_points_world_z)
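        # This maps voxel index 0 to -1 and index (dim - 1) to +1, i.e. the
        # align_corners=True convention of grid_sample.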

        sampling_grid_slices[i] = torch.stack((norm_coords_x, norm_coords_y, norm_coords_z), dim=-1)

    final_sampling_grid = sampling_grid_slices  # Shape (L, patch_dim, patch_dim, 3)

    # Prepare image_volume for grid_sample: (N, C, D_in, H_in, W_in)
    img_vol_for_sampling = img_vol_proc.unsqueeze(0).unsqueeze(0)  # (1, 1, Z_orig, H_orig, W_orig)

    # Reshape grid for grid_sample: (N, D_out, H_out, W_out, 3)
    # Here, D_out = L, H_out = patch_dim, W_out = patch_dim
    grid_for_sampling = final_sampling_grid.unsqueeze(0)  # (1, L, patch_dim, patch_dim, 3)

    aligned_volume_slices = F.grid_sample(
        img_vol_for_sampling,
        grid_for_sampling,
        mode='bilinear',
        padding_mode='zeros',  # or 'border'
        align_corners=True  # Matches the 2 * x / (dim - 1) - 1 normalization used above
    )
    # Output shape: (N, C, D_out, H_out, W_out) -> (1, 1, L, patch_dim, patch_dim)

    aligned_volume = aligned_volume_slices.squeeze(0).squeeze(0)  # (L, patch_dim, patch_dim)
    aligned_volume = aligned_volume.to(original_dtype)

    if return_grid_val:
        return aligned_volume, final_sampling_grid
    else:
        return aligned_volume
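

# --- Usage sketch (illustrative only; assumes torch is installed and that the
# @torch_func decorator dispatches the call without changing the signature) ---
if __name__ == "__main__":
    if torch is not None:
        Z, H, W = 32, 64, 64
        volume = torch.zeros((Z, H, W))
        # Hypothetical test object: a bright tube drifting diagonally through the stack.
        for z in range(Z):
            y, x = 16 + z, 16 + z // 2
            volume[z, y - 2:y + 3, x - 2:x + 3] = 1.0
        straightened, grid = straighten_object_3d(
            volume,
            min_voxel_threshold=0.5,
            patch_radius=8,
            sampling_spacing=1.0,
            return_grid=True,
        )
        print(straightened.shape, grid.shape)  # Expected: (L, 17, 17) and (L, 17, 17, 3)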