Coverage for openhcs/processing/backends/analysis/straighten_object_3d.py: 5.0%
170 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1from __future__ import annotations
3import logging
4from typing import Optional, Tuple, Union
6from openhcs.core.utils import optional_import
7from openhcs.core.memory.decorators import torch as torch_func
# PyTorch is an optional dependency. `optional_import` presumably yields None when a
# module is unavailable (the None check on `torch` below relies on that) — TODO confirm.
torch = optional_import("torch")
F = None if torch is None else optional_import("torch.nn.functional")

logger = logging.getLogger(__name__)
15def _moving_average_1d_torch(data: torch.Tensor, window_size: int) -> torch.Tensor:
16 """Applies a 1D moving average filter along the last dimension."""
17 if window_size <= 0:
18 return data
19 if window_size % 2 == 0: # Ensure odd window size for centered average
20 window_size += 1
22 weights = torch.ones(window_size, device=data.device, dtype=data.dtype) / window_size
23 weights = weights.view(1, 1, -1) # Shape (out_channels, in_channels/groups, kW)
25 # data shape (N, L) -> (N, 1, L) for conv1d
26 data_reshaped = data.unsqueeze(1)
27 padding = window_size // 2
29 smoothed_data = F.conv1d(data_reshaped, weights, padding=padding)
30 return smoothed_data.squeeze(1)
def _z_axis_centerline(
    object_coords: torch.Tensor,
    mean_coord: torch.Tensor,
    spacing: float,
) -> torch.Tensor:
    """Return a straight z-parallel centerline, used when PCA cannot orient the object.

    The line runs from the smallest to the largest z of ``object_coords`` at the
    object's mean (y, x). It has one point every ``spacing`` voxels and never fewer
    than 2 points. Output shape is (N, 3) in (z, y, x) voxel coordinates.
    """
    device = object_coords.device
    z_min = object_coords[:, 0].min()
    z_max = object_coords[:, 0].max()
    n_points = max(2, int((z_max - z_min) / spacing) + 1)
    line = torch.zeros((n_points, 3), device=device)
    line[:, 0] = torch.linspace(z_min.item(), z_max.item(), n_points, device=device)
    line[:, 1] = mean_coord[1]  # Use mean y
    line[:, 2] = mean_coord[2]  # Use mean x
    return line


def _smoothing_window(smoothness: float, curve_length: float, n_points: int) -> int:
    """Turn a smoothing length (voxels along the curve) into an odd moving-average window.

    The window is ``smoothness * n_points / curve_length`` points, with a minimum of 3.
    It is capped at ``n_points // 2`` (or at 1 for four points or fewer) and then
    bumped to the next odd value.
    """
    if curve_length > 1e-3:  # Avoid division by zero
        # Points per voxel along the curve is roughly n_points / curve_length.
        window = int(max(3, smoothness * (n_points / curve_length)))
    else:
        window = 3
    window = min(window, n_points // 2 if n_points > 4 else 1)
    window = max(window, 1)
    if window % 2 == 0:  # Must be odd to stay centered.
        window += 1
    return window


def _fit_centerline(
    object_coords: torch.Tensor,
    mean_coord: torch.Tensor,
    spacing: float,
    spline_smoothness: Optional[float],
) -> torch.Tensor:
    """Fit an ordered (N >= 2, 3) centerline through the object voxels, in (z, y, x).

    Voxels are sorted by their projection on the first principal axis, taken from the
    SVD of the centred coordinates. Each coordinate is then smoothed with a moving
    average. If PCA is impossible, a z-parallel line is returned instead. That happens
    with fewer than 2 voxels, when all voxels coincide, or when the SVD fails.
    """
    centered = object_coords - mean_coord
    if centered.shape[0] < 2 or torch.allclose(centered, torch.zeros_like(centered)):
        logger.warning("Not enough distinct object coordinates for PCA. Using a simplified centerline.")
        return _z_axis_centerline(object_coords, mean_coord, spacing)

    try:
        # Rows of Vh are the principal directions; Vh[0] carries the largest variance.
        _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    except RuntimeError as err:  # torch.linalg.LinAlgError derives from RuntimeError
        logger.warning("PCA (SVD) failed: %s. Using simplified z-axis centerline.", err)
        return _z_axis_centerline(object_coords, mean_coord, spacing)

    ordered = object_coords[torch.argsort(centered @ vh[0])]
    n_points = ordered.shape[0]

    # Length of the polyline through the ordered voxels: a rough curve-length estimate.
    curve_length = torch.norm(ordered[1:] - ordered[:-1], dim=1).sum().item()
    if spline_smoothness is not None:
        smoothness = float(spline_smoothness)
    else:
        smoothness = 0.01 * curve_length

    window = _smoothing_window(smoothness, curve_length, n_points)
    if window > 1 and n_points > window:
        # Smooth z, y and x independently: (N, 3) -> (3, N) rows -> back to (N, 3).
        return _moving_average_1d_torch(ordered.T, window).T
    return ordered


def _resample_polyline(points: torch.Tensor, spacing: float) -> torch.Tensor:
    """Resample an (N >= 2, 3) polyline at uniform arc-length steps by linear interpolation.

    Returns ``max(2, ceil(length / spacing))`` points spread evenly from the first point
    to the last. If the total length is below 1e-3, it returns a single point near the
    start instead.
    """
    segment_lengths = torch.norm(points[1:] - points[:-1], p=2, dim=1)
    cum_lengths = torch.cat((segment_lengths.new_zeros(1), torch.cumsum(segment_lengths, dim=0)))
    total_length = cum_lengths[-1]

    if total_length < 1e-3:  # Effectively a point
        n_samples = 1
    else:
        n_samples = max(2, int(torch.ceil(total_length / spacing).item()))
    targets = torch.linspace(0.0, total_length.item(), n_samples, device=points.device)

    # Segment holding each target: cum_lengths[idx] <= target < cum_lengths[idx + 1].
    idx = torch.searchsorted(cum_lengths, targets, right=True) - 1
    idx = idx.clamp(0, cum_lengths.shape[0] - 2)
    seg_start = cum_lengths[idx]
    seg_len = cum_lengths[idx + 1] - seg_start
    # Zero-length segments interpolate with alpha = 0 instead of dividing by zero.
    alpha = torch.where(seg_len > 1e-6, (targets - seg_start) / seg_len, torch.zeros_like(targets))
    alpha = alpha.unsqueeze(1)
    return points[idx] * (1.0 - alpha) + points[idx + 1] * alpha


def _normal_plane_grid(
    centerline: torch.Tensor,
    patch_radius: int,
    volume_shape: Tuple[int, int, int],
) -> torch.Tensor:
    """Build the ``grid_sample`` grid for square patches perpendicular to a centerline.

    Args:
        centerline: (L, 3) points in (z, y, x) voxel coordinates, L >= 1.
        patch_radius: Patch half-size; each patch is P x P with P = 2 * patch_radius + 1.
        volume_shape: (Z, H, W) of the volume that will be sampled.

    Returns:
        Float (L, P, P, 3) grid whose last axis is (x, y, z), normalized to [-1, 1]
        with the align_corners=True convention. Patch rows follow the v axis and
        patch columns follow the u axis.
    """
    device = centerline.device
    n_points = centerline.shape[0]

    # Tangents use one-sided differences at the ends and central differences in between.
    # A lone point keeps a zero tangent and so gets the default direction below.
    tangent_vec = torch.zeros_like(centerline)
    if n_points >= 2:
        tangent_vec[0] = centerline[1] - centerline[0]
        tangent_vec[-1] = centerline[-1] - centerline[-2]
        tangent_vec[1:-1] = (centerline[2:] - centerline[:-2]) / 2.0

    axis_z = torch.tensor([1.0, 0.0, 0.0], device=device)  # index 0 is z in (z, y, x)
    axis_y = torch.tensor([0.0, 1.0, 0.0], device=device)
    degenerate = torch.norm(tangent_vec, p=2, dim=1, keepdim=True) < 1e-6
    tangent = torch.where(degenerate, axis_z, F.normalize(tangent_vec, p=2, dim=1))

    # The reference vector for the in-plane basis is z, unless the tangent is mostly
    # along z, in which case it is y. For a unit tangent this keeps
    # |tangent x reference| >= sqrt(0.19), so the cross products never degenerate.
    reference = torch.where(tangent[:, :1].abs() < 0.9, axis_z, axis_y)
    u_axis = F.normalize(torch.cross(tangent, reference, dim=1), p=2, dim=1)
    v_axis = F.normalize(torch.cross(tangent, u_axis, dim=1), p=2, dim=1)

    offsets = torch.linspace(-patch_radius, patch_radius, 2 * patch_radius + 1, device=device)
    grid_v, grid_u = torch.meshgrid(offsets, offsets, indexing="ij")  # (P, P) each

    # Broadcast (L,1,1,3) + (1,P,P,1) * (L,1,1,3) -> (L,P,P,3), still in (z, y, x) voxels.
    world = (
        centerline[:, None, None, :]
        + grid_u[None, :, :, None] * u_axis[:, None, None, :]
        + grid_v[None, :, :, None] * v_axis[:, None, None, :]
    )

    # Map voxel index 0 -> -1 and size-1 -> +1; singleton axes map to 0.
    sizes = torch.tensor(volume_shape, dtype=world.dtype, device=device)
    normalized = torch.where(
        sizes > 1,
        2.0 * (world / (sizes - 1).clamp_min(1.0)) - 1.0,
        torch.zeros_like(world),
    )
    # grid_sample expects the last axis ordered (x, y, z).
    return normalized.flip(-1)


def _axial_patch(
    volume: torch.Tensor,
    center: torch.Tensor,
    patch_radius: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Crop a (1, P, P) patch from z-slice ``int(center[0])`` of a (Z, H, W) volume.

    The crop is centred on ``(int(center[1]), int(center[2]))`` and clipped to the
    volume. It is then zero-padded, keeping it centred, back to P = 2 * patch_radius + 1.
    """
    patch_dim = 2 * patch_radius + 1
    _, height, width = volume.shape
    zc, yc, xc = (int(c) for c in center)
    crop = volume[
        zc,
        max(0, yc - patch_radius):min(height, yc + patch_radius + 1),
        max(0, xc - patch_radius):min(width, xc + patch_radius + 1),
    ]
    patch = torch.zeros((patch_dim, patch_dim), device=volume.device, dtype=dtype)
    crop_h, crop_w = crop.shape
    y0 = (patch_dim - crop_h) // 2
    x0 = (patch_dim - crop_w) // 2
    patch[y0:y0 + crop_h, x0:x0 + crop_w] = crop.to(dtype)
    return patch.unsqueeze(0)


@torch_func
def straighten_object_3d(
    image_volume: torch.Tensor,  # Expected (Z, H, W) or (1, Z, H, W)
    min_voxel_threshold: float,  # Required parameter
    patch_radius: Optional[int] = None,
    sampling_spacing: float = 1.0,
    max_components: int = 1,
    return_grid: bool = False,
    spline_smoothness: Optional[float] = None,
    **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Straighten the thresholded object in a 3D volume by resampling it along its centerline.

    Every voxel brighter than ``min_voxel_threshold`` is treated as part of ONE object;
    no connected-component labelling is performed. The function then proceeds in three
    steps:

    1. Fit a centerline through those voxels, using PCA ordering plus moving-average
       smoothing, or a z-parallel line as a fallback.
    2. Resample the centerline every ``sampling_spacing`` voxels of arc length.
    3. At each resampled point, sample a square patch from the plane perpendicular to
       the local tangent, using trilinear interpolation.

    Args:
        image_volume: Tensor of shape (Z, H, W) or (1, Z, H, W).
        min_voxel_threshold: Voxels strictly greater than this value form the object.
        patch_radius: Half-size of each output patch in voxels. Defaults to
            ``min(H, W) // 10``; values < 1 are raised to 1.
        sampling_spacing: Arc-length step between output slices, in voxels.
            Non-positive values fall back to 1.0.
        max_components: Only 1 is supported.
        return_grid: Also return the sampling grid that was used.
        spline_smoothness: Smoothing length in voxels along the curve, used to size
            the moving-average window. Defaults to 1% of the estimated curve length.
        **kwargs: Accepted and ignored.

    Returns:
        An (L, P, P) tensor in the input dtype, with P = 2 * patch_radius + 1.
        If ``return_grid`` is set, a tuple of that tensor and the float32 (L, P, P, 3)
        grid. The grid's last axis holds (x, y, z) coordinates normalized to [-1, 1]
        for ``grid_sample(..., align_corners=True)``. L is 0 when no voxel passes the
        threshold, and 1 when the centerline is shorter than 1e-3 voxels.

    Raises:
        TypeError: If ``image_volume`` is not a torch.Tensor.
        ValueError: If ``image_volume`` is not 3D, or 4D with a leading size-1 axis.
        NotImplementedError: If ``max_components != 1``.
    """
    if not isinstance(image_volume, torch.Tensor):
        raise TypeError(f"Input image_volume must be a PyTorch Tensor. Got {type(image_volume)}")

    device = image_volume.device
    original_dtype = image_volume.dtype

    # --- Input shape handling: work on a float32 (Z, H, W) copy ---
    volume = image_volume.float()
    if volume.ndim == 4:
        if volume.shape[0] != 1:
            raise ValueError("If 4D, first dimension (batch) must be 1.")
        volume = volume.squeeze(0)
    elif volume.ndim != 3:
        raise ValueError(f"image_volume must be 3D (Z,H,W) or 4D (1,Z,H,W). Got {image_volume.ndim}D")
    z_size, h_size, w_size = volume.shape

    # --- Parameters & defaults ---
    min_voxel_threshold = float(min_voxel_threshold)
    radius = min(h_size, w_size) // 10 if patch_radius is None else patch_radius
    radius = max(1, int(radius))  # Ensure positive patch radius
    patch_dim = 2 * radius + 1

    spacing = float(sampling_spacing)
    if spacing <= 0:
        spacing = 1.0

    if int(max_components) != 1:
        # Full 3D connected-component labelling on the GPU is out of scope here.
        raise NotImplementedError("max_components > 1 is not implemented due to GPU CC complexity constraints.")

    return_grid_val = bool(return_grid)

    # --- 1. Thresholding: every supra-threshold voxel belongs to the object ---
    voxel_idx = torch.nonzero(volume > min_voxel_threshold, as_tuple=False)  # (N, 3) as (z, y, x)
    if voxel_idx.shape[0] == 0:
        logger.warning("No object found above threshold. Returning empty tensor(s).")
        empty_vol = torch.empty((0, patch_dim, patch_dim), device=device, dtype=original_dtype)
        empty_grid = torch.empty((0, patch_dim, patch_dim, 3), device=device, dtype=torch.float32)
        return (empty_vol, empty_grid) if return_grid_val else empty_vol

    object_coords = voxel_idx.float()
    mean_coord = object_coords.mean(dim=0)

    # --- 2. Centerline fitting & arc-length resampling ---
    centerline = _fit_centerline(object_coords, mean_coord, spacing, spline_smoothness)
    if centerline.shape[0] < 2:
        # Defensive: the fitting helpers above always return at least 2 points.
        logger.warning("Centerline reduced to a single point. Cannot straighten.")
        final_volume = _axial_patch(volume, mean_coord, radius, original_dtype)
        final_grid = torch.zeros((1, patch_dim, patch_dim, 3), device=device, dtype=torch.float32)  # Dummy grid
        return (final_volume, final_grid) if return_grid_val else final_volume

    # A zero-length centerline (e.g. a single voxel) resamples to a single point. The
    # plane sampler handles that point with a default tangent.
    resampled = _resample_polyline(centerline, spacing)

    # --- 3. Plane sampling ---
    sampling_grid = _normal_plane_grid(resampled, radius, (z_size, h_size, w_size))
    # The grid uses the align_corners=True convention (index 0 -> -1, size-1 -> +1), so
    # the sampler must use it too; otherwise every sample is shifted by about half a voxel.
    sampled = F.grid_sample(
        volume[None, None],   # (1, 1, Z, H, W)
        sampling_grid[None],  # (1, L, P, P, 3)
        mode="bilinear",      # trilinear for 5-D input
        padding_mode="zeros",
        align_corners=True,
    )
    aligned_volume = sampled[0, 0].to(original_dtype)  # (L, P, P)
    return (aligned_volume, sampling_grid) if return_grid_val else aligned_volume