Coverage for openhcs/processing/backends/analysis/self_supervised_segmentation_3d.py: 9.4%
183 statements
coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
from __future__ import annotations

import logging
from typing import Optional, Tuple, Union

from openhcs.utils.import_utils import optional_import, create_placeholder_class
from openhcs.core.memory.decorators import torch as torch_func

# Import torch modules as optional dependencies
torch = optional_import("torch")
nn = optional_import("torch.nn") if torch is not None else None
F = optional_import("torch.nn.functional") if torch is not None else None

nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch"
)

logger = logging.getLogger(__name__)

# --- PyTorch Models and Helper Functions ---

class Encoder3D(nnModule):
    def __init__(self, in_channels=1, features=(32, 64), embedding_dim=128):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, features[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(features[0], features[1], kernel_size=3, padding=1)
        # Stride in the conv layers could reduce spatial dimensions before pooling;
        # for simplicity, full spatial resolution is kept until pooling.
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(features[1], embedding_dim)
        self.features_channels_before_pool = features[1]  # Feature depth needed later for K-Means

    def forward(self, x: torch.Tensor, return_features_before_pool: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # x: [B, 1, D, H, W]
        features_conv1 = F.relu(self.conv1(x))
        features_conv2 = F.relu(self.conv2(features_conv1))  # [B, features[1], D, H, W]

        pooled_features = self.pool(features_conv2)  # [B, features[1], 1, 1, 1]
        flattened_features = torch.flatten(pooled_features, 1)  # [B, features[1]]
        embedding = self.fc(flattened_features)  # [B, embedding_dim]
        normalized_embedding = F.normalize(embedding, p=2, dim=1)

        if return_features_before_pool:
            return normalized_embedding, features_conv2  # Also return the last conv layer's features
        return normalized_embedding

class Decoder3D(nnModule):
    def __init__(self, embedding_dim=128, features=(64, 32), out_channels=1, patch_size_dhw=(64, 64, 64)):
        super().__init__()
        self.patch_d, self.patch_h, self.patch_w = patch_size_dhw

        # Determine the initial dimensions for unflattening. These depend on the encoder's
        # spatial reduction: with kernel=3, stride=1, padding=1 convolutions the encoder
        # preserves spatial size, so the decoder simply reconstructs the patch from a smaller
        # latent grid. A common strategy is to project the embedding to
        # features * (D/s) * (H/s) * (W/s), where s is the total downsampling factor; here the
        # latent grid is patch_size // 4 and the two transposed convolutions upsample by 4.
        # If the encoder had pooling or strides, these dimensions would instead need to match
        # its smallest feature map.
        self.init_d = max(1, patch_size_dhw[0] // 4)
        self.init_h = max(1, patch_size_dhw[1] // 4)
        self.init_w = max(1, patch_size_dhw[2] // 4)

        self.fc = nn.Linear(embedding_dim, features[0] * self.init_d * self.init_h * self.init_w)
        self.unflatten_channels = features[0]

        # Upsample to patch_size / 2
        self.upconv1 = nn.ConvTranspose3d(features[0], features[1], kernel_size=4, stride=2, padding=1)
        # Upsample to patch_size
        self.upconv2 = nn.ConvTranspose3d(features[1], out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, embedding_dim]
        x = self.fc(x)
        x = x.view(-1, self.unflatten_channels, self.init_d, self.init_h, self.init_w)
        x = F.relu(self.upconv1(x))
        x = self.upconv2(x)  # Output reconstruction
        return x

def _extract_random_patches(
    volume: torch.Tensor, patch_size_dhw: Tuple[int, int, int], num_patches: int
) -> torch.Tensor:  # volume: [1, 1, D, H, W] -> output: [num_patches, 1, pD, pH, pW]
    _, _, D, H, W = volume.shape
    pD, pH, pW = patch_size_dhw
    patches = torch.empty((num_patches, 1, pD, pH, pW), device=volume.device, dtype=volume.dtype)
    for i in range(num_patches):
        d_start = torch.randint(0, D - pD + 1, (1,)).item() if D > pD else 0
        h_start = torch.randint(0, H - pH + 1, (1,)).item() if H > pH else 0
        w_start = torch.randint(0, W - pW + 1, (1,)).item() if W > pW else 0
        patches[i] = volume[0, :, d_start:d_start + pD, h_start:h_start + pH, w_start:w_start + pW]
    return patches

def _affine_augment_patch(patch: torch.Tensor) -> torch.Tensor:  # patch: [1, pD, pH, pW]
    # Random flips along each spatial axis
    if torch.rand(1).item() > 0.5:
        patch = torch.flip(patch, dims=[1])  # D
    if torch.rand(1).item() > 0.5:
        patch = torch.flip(patch, dims=[2])  # H
    if torch.rand(1).item() > 0.5:
        patch = torch.flip(patch, dims=[3])  # W

    # Random 90-degree rotation in the D-H plane. Guarded so the rotation is only applied
    # when the D and H extents match; otherwise a 90-degree rotation would change the patch
    # shape and break the later torch.stack over the batch.
    if patch.shape[1] == patch.shape[2] and torch.rand(1).item() > 0.5:
        k = torch.randint(0, 4, (1,)).item()
        patch = torch.rot90(patch, k, dims=[1, 2])  # Rotate D-H plane
    return patch

def _nt_xent_loss(z_i: torch.Tensor, z_j: torch.Tensor, temperature: float) -> torch.Tensor:
    batch_size = z_i.shape[0]
    # z_i, z_j are already L2-normalized by the encoder

    # Concatenate embeddings from the two views
    z = torch.cat([z_i, z_j], dim=0)  # Shape: [2*B, E]

    # Similarity matrix scaled by temperature
    sim_matrix = torch.matmul(z, z.T) / temperature  # Shape: [2*B, 2*B]

    # Mask to exclude self-similarity (i-th sample with itself)
    identity_mask = torch.eye(2 * batch_size, device=z_i.device, dtype=torch.bool)

    # Positive pairs are (i, i+B) and (i+B, i)
    pos_mask = torch.zeros_like(sim_matrix, dtype=torch.bool)
    pos_mask[torch.arange(batch_size), torch.arange(batch_size) + batch_size] = True
    pos_mask[torch.arange(batch_size) + batch_size, torch.arange(batch_size)] = True

    # Numerator: similarity of each row's positive pair (exactly one positive per row,
    # selected in row-major order)
    numerator = torch.exp(sim_matrix[pos_mask])  # Shape: [2*B]

    # Denominator: for each row, sum exp(sim) over all columns except the diagonal (self-similarity)
    exp_sim_no_self = torch.exp(sim_matrix.masked_fill(identity_mask, -float('inf')))
    denominator = exp_sim_no_self.sum(dim=1)  # Shape: [2*B]

    # Log-probability of the positive pair for each row. Since every row has exactly one
    # positive, the boolean selection keeps all 2*B denominators in row order.
    log_probs = torch.log(numerator / denominator[pos_mask.any(dim=1)])

    # Loss is the negative mean of these log probabilities
    loss = -log_probs.mean()
    return loss
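
# Reference note (illustrative, not from the original file): the loss above follows the
# SimCLR NT-Xent formulation. For a positive pair (i, j) with temperature tau,
#     l(i, j) = -log( exp(sim(z_i, z_j) / tau) / sum_{k != i} exp(sim(z_i, z_k) / tau) )
# and the batch loss is the mean of l over all 2*B ordered positive pairs, which is what
# the row-wise numerator/denominator construction computes.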

def _kmeans_torch(X: torch.Tensor, K: int, n_iters: int = 20) -> Tuple[torch.Tensor, torch.Tensor]:
    N, D_feat = X.shape
    if N == 0:
        return (torch.empty(0, dtype=torch.long, device=X.device),
                torch.empty((K, D_feat), device=X.device, dtype=X.dtype))
    if N < K:
        K = N

    # Initialize centroids from randomly chosen points (randint avoids materializing a full permutation)
    centroids = X[torch.randint(0, N, (K,), device=X.device)]

    for _ in range(n_iters):
        dists_sq = torch.sum((X[:, None, :] - centroids[None, :, :]) ** 2, dim=2)
        labels = torch.argmin(dists_sq, dim=1)

        new_centroids = torch.zeros_like(centroids)
        for k_idx in range(K):
            assigned_points = X[labels == k_idx]
            if assigned_points.shape[0] > 0:
                new_centroids[k_idx] = assigned_points.mean(dim=0)
            else:
                # Re-seed an empty cluster from a random point
                new_centroids[k_idx] = X[torch.randint(0, N, (1,)).item()] if N > 0 else centroids[k_idx]

        if torch.allclose(centroids, new_centroids, atol=1e-5):
            break
        centroids = new_centroids
    return labels, centroids
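
# Usage sketch (illustrative shapes only):
#     labels, centroids = _kmeans_torch(torch.rand(1000, 64), K=2)
#     # labels: [1000] int64 cluster assignments, centroids: [2, 64]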

# --- Main Segmentation Function ---
@torch_func
def self_supervised_segmentation_3d(
    image_volume: torch.Tensor,
    apply_segmentation: bool = True,
    min_val: float = 0.0,
    max_val: float = 1.0,
    patch_size: Optional[Tuple[int, int, int]] = None,
    n_epochs: int = 500,
    embedding_dim: int = 128,
    temperature: float = 0.1,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    reconstruction_weight: float = 1.0,
    contrastive_weight: float = 1.0,
    cluster_k: int = 2,
    mask_fraction: float = 0.01,
    sigma_noise: float = 0.2,
    lambda_bound: float = 0.1,
    **kwargs
) -> torch.Tensor:
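    """Self-supervised 3D segmentation via masked-voxel reconstruction and contrastive learning.

    Random patches of ``image_volume`` are used to train a small encoder/decoder with a
    masked-voxel reconstruction loss, an NT-Xent contrastive loss between a masked view and
    an affine-augmented view, and a bound penalty keeping reconstructions within
    [``min_val``, ``max_val``]. If ``apply_segmentation`` is False the original volume is
    returned unchanged; otherwise the encoder's dense features are clustered with K-Means
    into ``cluster_k`` labels and a label volume matching the input's dimensionality is returned.
    """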
    if not isinstance(image_volume, torch.Tensor):
        raise TypeError(f"Input image_volume must be a PyTorch Tensor. Got {type(image_volume)}")

    device = image_volume.device
    original_input_shape_len = image_volume.ndim
    original_dtype = image_volume.dtype

    img_vol_proc = image_volume.float()
    if img_vol_proc.ndim == 3:
        Z_orig, H_orig, W_orig = img_vol_proc.shape
        img_vol_proc = img_vol_proc.unsqueeze(0).unsqueeze(0)
    elif img_vol_proc.ndim == 4:
        _, Z_orig, H_orig, W_orig = img_vol_proc.shape
        img_vol_proc = img_vol_proc.unsqueeze(1)
    elif img_vol_proc.ndim == 5:
        _, _, Z_orig, H_orig, W_orig = img_vol_proc.shape
    else:
        raise ValueError(f"image_volume must be 3D, 4D or 5D. Got {image_volume.ndim}D")

    min_val_norm = float(min_val)
    max_val_norm = float(max_val)

    img_min_orig, img_max_orig = torch.min(img_vol_proc), torch.max(img_vol_proc)
    if img_max_orig > img_min_orig:
        img_vol_norm = (img_vol_proc - img_min_orig) / (img_max_orig - img_min_orig)
        img_vol_norm = img_vol_norm * (max_val_norm - min_val_norm) + min_val_norm
    else:
        img_vol_norm = torch.full_like(img_vol_proc, min_val_norm)

    # Use the provided patch_size or compute a default, then clamp it to the volume extent
    if patch_size is None:
        patch_size_dhw = (max(16, Z_orig // 8), max(16, H_orig // 8), max(16, W_orig // 8))  # Ensure a minimum size
    else:
        patch_size_dhw = patch_size

    patch_size_dhw = (min(patch_size_dhw[0], Z_orig), min(patch_size_dhw[1], H_orig), min(patch_size_dhw[2], W_orig))
    if any(p <= 0 for p in patch_size_dhw):
        raise ValueError(f"Patch dimensions must be positive. Got {patch_size_dhw} for volume {Z_orig, H_orig, W_orig}")

    encoder = Encoder3D(in_channels=1, embedding_dim=embedding_dim).to(device)
    decoder = Decoder3D(embedding_dim=embedding_dim, patch_size_dhw=patch_size_dhw).to(device)
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

    for epoch in range(n_epochs):
        encoder.train()
        decoder.train()

        patches_orig_batch = _extract_random_patches(img_vol_norm, patch_size_dhw, batch_size)

        patches_mvm_list = []
        patches_affine_list = []
        masks_for_loss_list = []

        for i in range(batch_size):
            current_patch_orig = patches_orig_batch[i]  # [1, pD, pH, pW]

            # Masked-voxel view: replace a random fraction of voxels with clamped Gaussian noise
            mask_mvm = (torch.rand_like(current_patch_orig) < mask_fraction).bool()
            masks_for_loss_list.append(mask_mvm.clone())  # Store for the reconstruction loss
            noise = (torch.randn_like(current_patch_orig) * sigma_noise).clamp(min_val_norm, max_val_norm)
            patch_mvm = torch.where(mask_mvm, noise, current_patch_orig)
            patches_mvm_list.append(patch_mvm)

            # Affine-augmented view for the contrastive objective
            patch_affine = _affine_augment_patch(current_patch_orig.clone())
            patches_affine_list.append(patch_affine)

        patches_mvm_batch = torch.stack(patches_mvm_list)
        patches_affine_batch = torch.stack(patches_affine_list)
        masks_batch = torch.stack(masks_for_loss_list)  # [B, 1, pD, pH, pW]

        emb_mvm = encoder(patches_mvm_batch)
        emb_affine = encoder(patches_affine_batch)
        reconstructed_patches = decoder(emb_mvm)

        loss_rec = torch.tensor(0.0, device=device)
        if masks_batch.any():  # Only compute if there are masked voxels
            # Shapes match as long as the decoder output equals the patch size
            masked_reconstruction = reconstructed_patches[masks_batch]
            masked_original = patches_orig_batch[masks_batch]
            if masked_reconstruction.numel() > 0:  # If any elements were actually masked and selected
                loss_rec = F.mse_loss(masked_reconstruction, masked_original)

        loss_contrastive = _nt_xent_loss(emb_mvm, emb_affine, temperature)
        loss_bound = (torch.relu(reconstructed_patches - max_val_norm) +
                      torch.relu(min_val_norm - reconstructed_patches)).mean()

        total_loss = (reconstruction_weight * loss_rec +
                      contrastive_weight * loss_contrastive +
                      lambda_bound * loss_bound)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if epoch % (n_epochs // 10 if n_epochs >= 10 else 1) == 0:
            logger.info(f"Epoch {epoch+1}/{n_epochs}, Loss: {total_loss.item():.4f} "
                        f"(Rec: {loss_rec.item():.4f}, Contr: {loss_contrastive.item():.4f}, Bound: {loss_bound.item():.4f})")

    if not apply_segmentation:
        return image_volume

    encoder.eval()
    with torch.no_grad():
        _, dense_features_full_vol = encoder(img_vol_norm, return_features_before_pool=True)

        if dense_features_full_vol.shape[2:] != (Z_orig, H_orig, W_orig):
            dense_features_upsampled = F.interpolate(
                dense_features_full_vol, size=(Z_orig, H_orig, W_orig),
                mode='trilinear', align_corners=False
            )
        else:
            dense_features_upsampled = dense_features_full_vol

        # One feature vector per voxel: [Z*H*W, C]
        features_for_kmeans = dense_features_upsampled.squeeze(0).permute(1, 2, 3, 0).reshape(-1, encoder.features_channels_before_pool)

        if features_for_kmeans.shape[0] == 0:
            logger.warning("No features extracted for K-Means, returning empty segmentation.")
            return torch.zeros((Z_orig, H_orig, W_orig), dtype=torch.long, device=device)

        voxel_labels_flat, _ = _kmeans_torch(features_for_kmeans, cluster_k)
        segmentation_mask = voxel_labels_flat.reshape(Z_orig, H_orig, W_orig)

    # Restore the input's dimensionality and dtype
    if original_input_shape_len == 3:
        return segmentation_mask.to(original_dtype)
    elif original_input_shape_len == 4:
        return segmentation_mask.unsqueeze(0).to(original_dtype)
    else:
        return segmentation_mask.unsqueeze(0).unsqueeze(0).to(original_dtype)
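
# Minimal usage sketch (illustrative only; it assumes PyTorch is installed and simply
# exercises the public function above on a small synthetic volume).
if __name__ == "__main__":
    if torch is not None:
        demo_volume = torch.rand(32, 64, 64)  # [Z, H, W] synthetic volume
        demo_labels = self_supervised_segmentation_3d(
            demo_volume,
            apply_segmentation=True,
            patch_size=(16, 16, 16),
            n_epochs=10,
            batch_size=2,
            cluster_k=2,
        )
        logger.info("Segmentation labels shape: %s", tuple(demo_labels.shape))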