Coverage for openhcs/processing/backends/analysis/self_supervised_segmentation_3d.py: 9.4%

183 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
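
"""Self-supervised 3D segmentation (PyTorch backend).

Trains a small 3D encoder/decoder on random patches using masked-voxel reconstruction
plus an NT-Xent contrastive objective, then clusters dense per-voxel features with
K-Means (_kmeans_torch) to produce a segmentation mask. PyTorch is an optional
dependency resolved at import time via optional_import.
"""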

from __future__ import annotations

import logging
from typing import Optional, Tuple, Union

from openhcs.utils.import_utils import optional_import, create_placeholder_class
from openhcs.core.memory.decorators import torch as torch_func

# Import torch modules as optional dependencies
torch = optional_import("torch")
nn = optional_import("torch.nn") if torch is not None else None
F = optional_import("torch.nn.functional") if torch is not None else None

nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch"
)

logger = logging.getLogger(__name__)

# --- PyTorch Models and Helper Functions ---

class Encoder3D(nnModule):
    """Small 3D CNN encoder producing an L2-normalized embedding per input patch."""

    def __init__(self, in_channels=1, features=(32, 64), embedding_dim=128):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, features[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(features[0], features[1], kernel_size=3, padding=1)
        # Stride in the conv layers could reduce spatial dimensions before pooling;
        # for simplicity, full spatial resolution is kept until pooling.
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(features[1], embedding_dim)
        self.features_channels_before_pool = features[1]  # Feature depth used later for K-Means

    def forward(self, x: torch.Tensor, return_features_before_pool: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # x: [B, 1, D, H, W]
        features_conv1 = F.relu(self.conv1(x))
        features_conv2 = F.relu(self.conv2(features_conv1))  # [B, features[1], D, H, W]

        pooled_features = self.pool(features_conv2)  # [B, features[1], 1, 1, 1]
        flattened_features = torch.flatten(pooled_features, 1)  # [B, features[1]]
        embedding = self.fc(flattened_features)  # [B, embedding_dim]
        normalized_embedding = F.normalize(embedding, p=2, dim=1)

        if return_features_before_pool:
            return normalized_embedding, features_conv2  # Also return features from the last conv layer
        return normalized_embedding
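
# Shape sketch for Encoder3D (illustrative only, assuming the default
# features=(32, 64), embedding_dim=128 and a single-channel input):
#   x                 [B, 1, D, H, W]
#   conv1 + ReLU   -> [B, 32, D, H, W]
#   conv2 + ReLU   -> [B, 64, D, H, W]
#   adaptive pool  -> [B, 64, 1, 1, 1] -> flatten -> [B, 64]
#   fc + L2 norm   -> [B, 128] unit-length embeddings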

class Decoder3D(nnModule):
    """Small transposed-conv decoder reconstructing a patch from an embedding."""

    def __init__(self, embedding_dim=128, features=(64, 32), out_channels=1, patch_size_dhw=(64, 64, 64)):
        super().__init__()
        self.patch_d, self.patch_h, self.patch_w = patch_size_dhw

        # Determine initial dimensions for unflattening. These depend on the encoder's
        # spatial reduction: with kernel=3, stride=1, padding=1 convolutions the encoder
        # preserves spatial size, so the decoder reconstructs the patch from a smaller
        # latent grid. A common strategy is to project the embedding to
        # features * (D/s) * (H/s) * (W/s), where s is the total downsampling factor.
        # Here the latent grid is fixed at one quarter of the patch size per axis.
        self.init_d = max(1, patch_size_dhw[0] // 4)
        self.init_h = max(1, patch_size_dhw[1] // 4)
        self.init_w = max(1, patch_size_dhw[2] // 4)

        self.fc = nn.Linear(embedding_dim, features[0] * self.init_d * self.init_h * self.init_w)
        self.unflatten_channels = features[0]

        # Upsample to patch_size/2
        self.upconv1 = nn.ConvTranspose3d(features[0], features[1], kernel_size=4, stride=2, padding=1)
        # Upsample to patch_size
        self.upconv2 = nn.ConvTranspose3d(features[1], out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, embedding_dim]
        x = self.fc(x)
        x = x.view(-1, self.unflatten_channels, self.init_d, self.init_h, self.init_w)
        x = F.relu(self.upconv1(x))
        x = self.upconv2(x)  # Output reconstruction
        return x
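
# Note (illustrative): the two stride-2 transposed convolutions quadruple each latent
# axis, so the reconstruction size is max(1, p // 4) * 4 along a patch dimension p.
# The decoder output therefore matches the requested patch exactly only when every
# patch dimension is a multiple of 4 (and at least 4); callers should pick patch
# sizes accordingly so the masked reconstruction loss can index both tensors alike.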

def _extract_random_patches(
    volume: torch.Tensor, patch_size_dhw: Tuple[int, int, int], num_patches: int
) -> torch.Tensor:  # volume: [1, 1, D, H, W] -> output: [num_patches, 1, pD, pH, pW]
    _, _, D, H, W = volume.shape
    pD, pH, pW = patch_size_dhw
    patches = torch.empty((num_patches, 1, pD, pH, pW), device=volume.device, dtype=volume.dtype)
    for i in range(num_patches):
        d_start = torch.randint(0, D - pD + 1, (1,)).item() if D > pD else 0
        h_start = torch.randint(0, H - pH + 1, (1,)).item() if H > pH else 0
        w_start = torch.randint(0, W - pW + 1, (1,)).item() if W > pW else 0
        patches[i] = volume[0, :, d_start:d_start+pD, h_start:h_start+pH, w_start:w_start+pW]
    return patches

def _affine_augment_patch(patch: torch.Tensor) -> torch.Tensor:  # patch: [1, pD, pH, pW]
    # Random flips
    if torch.rand(1).item() > 0.5: patch = torch.flip(patch, dims=[1])  # D
    if torch.rand(1).item() > 0.5: patch = torch.flip(patch, dims=[2])  # H
    if torch.rand(1).item() > 0.5: patch = torch.flip(patch, dims=[3])  # W

    # Random 90-degree rotations in the D-H plane. Only applied when the patch is
    # square in those dimensions; otherwise an odd number of quarter turns would
    # change the patch shape and it could no longer be stacked with the other patches.
    if patch.shape[1] == patch.shape[2] and torch.rand(1).item() > 0.5:
        k = torch.randint(0, 4, (1,)).item()
        patch = torch.rot90(patch, k, dims=[1, 2])
    return patch

def _nt_xent_loss(z_i: torch.Tensor, z_j: torch.Tensor, temperature: float) -> torch.Tensor:
    batch_size = z_i.shape[0]
    # z_i, z_j are already normalized by the encoder

    # Concatenate embeddings from the two views
    z = torch.cat([z_i, z_j], dim=0)  # Shape: [2*B, E]

    # Similarity matrix (cosine similarity, since the embeddings are unit length)
    sim_matrix = torch.matmul(z, z.T) / temperature  # Shape: [2*B, 2*B]

    # Mask to exclude self-similarity (i-th sample with itself)
    identity_mask = torch.eye(2 * batch_size, device=z_i.device, dtype=torch.bool)

    # Positive pairs are (i, i+B) and (i+B, i)
    pos_mask = torch.zeros_like(sim_matrix, dtype=torch.bool)
    pos_mask[torch.arange(batch_size), torch.arange(batch_size) + batch_size] = True
    pos_mask[torch.arange(batch_size) + batch_size, torch.arange(batch_size)] = True

    # Numerator: similarity of the positive pair in each row.
    # Shape: [2*B] (each of the B positive pairs appears twice, once per ordering)
    numerator = torch.exp(sim_matrix[pos_mask])

    # Denominator: for each row, sum exp(sim) over all columns except the diagonal (self-similarity)
    exp_sim_no_self = torch.exp(sim_matrix.masked_fill(identity_mask, -float('inf')))
    denominator = exp_sim_no_self.sum(dim=1)  # Shape: [2*B]

    # Log probability of the positive pair for each row
    # (every row contains exactly one positive, so pos_mask.any(dim=1) selects all rows)
    log_probs = torch.log(numerator / denominator[pos_mask.any(dim=1)])

    # Loss is the negative mean of these log probabilities
    loss = -log_probs.mean()
    return loss
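
# NT-Xent reference (SimCLR-style contrastive loss), shown here for orientation:
#   l(i, j) = -log( exp(sim(z_i, z_j) / tau) / sum_{k != i} exp(sim(z_i, z_k) / tau) )
# where sim(., .) is cosine similarity and tau is the temperature. Because the encoder
# L2-normalizes its output, the dot products in sim_matrix above are already cosine
# similarities, so the implementation evaluates this expression row by row.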

def _kmeans_torch(X: torch.Tensor, K: int, n_iters: int = 20) -> Tuple[torch.Tensor, torch.Tensor]:
    N, D_feat = X.shape
    if N == 0:
        return torch.empty(0, dtype=torch.long, device=X.device), torch.empty((K, D_feat), device=X.device, dtype=X.dtype)
    if N < K:
        K = N

    # Use randint for memory efficiency (though N is typically small for K-Means)
    centroids = X[torch.randint(0, N, (K,), device=X.device)]

    for _ in range(n_iters):
        dists_sq = torch.sum((X[:, None, :] - centroids[None, :, :])**2, dim=2)
        labels = torch.argmin(dists_sq, dim=1)

        new_centroids = torch.zeros_like(centroids)
        for k_idx in range(K):
            assigned_points = X[labels == k_idx]
            if assigned_points.shape[0] > 0:
                new_centroids[k_idx] = assigned_points.mean(dim=0)
            else:
                new_centroids[k_idx] = X[torch.randint(0, N, (1,)).item()] if N > 0 else centroids[k_idx]

        if torch.allclose(centroids, new_centroids, atol=1e-5):
            break
        centroids = new_centroids
    return labels, centroids
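
# Usage sketch for _kmeans_torch (illustrative values, not taken from the pipeline):
#   feats = torch.randn(1000, 64)             # 1000 voxels, 64-dim features
#   labels, centroids = _kmeans_torch(feats, K=2)
#   labels.shape    -> torch.Size([1000])     # cluster index per voxel
#   centroids.shape -> torch.Size([2, 64])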

# --- Main Segmentation Function ---

@torch_func
def self_supervised_segmentation_3d(
    image_volume: torch.Tensor,
    apply_segmentation: bool = True,
    min_val: float = 0.0,
    max_val: float = 1.0,
    patch_size: Optional[Tuple[int, int, int]] = None,
    n_epochs: int = 500,
    embedding_dim: int = 128,
    temperature: float = 0.1,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    reconstruction_weight: float = 1.0,
    contrastive_weight: float = 1.0,
    cluster_k: int = 2,
    mask_fraction: float = 0.01,
    sigma_noise: float = 0.2,
    lambda_bound: float = 0.1,
    **kwargs
) -> torch.Tensor:
    """Train an encoder/decoder with masked-voxel reconstruction and a contrastive
    objective, then K-Means-cluster dense features into a voxel-wise segmentation mask."""

    if not isinstance(image_volume, torch.Tensor):
        raise TypeError(f"Input image_volume must be a PyTorch Tensor. Got {type(image_volume)}")

    device = image_volume.device
    original_input_shape_len = image_volume.ndim
    original_dtype = image_volume.dtype

    img_vol_proc = image_volume.float()
    if img_vol_proc.ndim == 3:
        Z_orig, H_orig, W_orig = img_vol_proc.shape
        img_vol_proc = img_vol_proc.unsqueeze(0).unsqueeze(0)
    elif img_vol_proc.ndim == 4:
        _, Z_orig, H_orig, W_orig = img_vol_proc.shape
        img_vol_proc = img_vol_proc.unsqueeze(1)
    elif img_vol_proc.ndim == 5:
        _, _, Z_orig, H_orig, W_orig = img_vol_proc.shape
    else:
        raise ValueError(f"image_volume must be 3D, 4D or 5D. Got {image_volume.ndim}D")

    min_val_norm = float(min_val)
    max_val_norm = float(max_val)

    img_min_orig, img_max_orig = torch.min(img_vol_proc), torch.max(img_vol_proc)
    if img_max_orig > img_min_orig:
        img_vol_norm = (img_vol_proc - img_min_orig) / (img_max_orig - img_min_orig)
        img_vol_norm = img_vol_norm * (max_val_norm - min_val_norm) + min_val_norm
    else:
        img_vol_norm = torch.full_like(img_vol_proc, min_val_norm)

    # Use the provided patch_size or compute a default (at least 16 voxels per axis)
    if patch_size is None:
        patch_size_dhw = (max(16, Z_orig // 8), max(16, H_orig // 8), max(16, W_orig // 8))
    else:
        patch_size_dhw = patch_size

    patch_size_dhw = (min(patch_size_dhw[0], Z_orig), min(patch_size_dhw[1], H_orig), min(patch_size_dhw[2], W_orig))
    if any(p <= 0 for p in patch_size_dhw):
        raise ValueError(f"Patch dimensions must be positive. Got {patch_size_dhw} for volume {(Z_orig, H_orig, W_orig)}")

    encoder = Encoder3D(in_channels=1, embedding_dim=embedding_dim).to(device)
    decoder = Decoder3D(embedding_dim=embedding_dim, patch_size_dhw=patch_size_dhw).to(device)
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

    for epoch in range(n_epochs):
        encoder.train()
        decoder.train()

        patches_orig_batch = _extract_random_patches(img_vol_norm, patch_size_dhw, batch_size)

        patches_mvm_list = []
        patches_affine_list = []
        masks_for_loss_list = []

        for i in range(batch_size):
            current_patch_orig = patches_orig_batch[i]  # [1, pD, pH, pW]

            # Masked-voxel view: replace a random fraction of voxels with clamped noise
            mask_mvm = (torch.rand_like(current_patch_orig) < mask_fraction).bool()
            masks_for_loss_list.append(mask_mvm.clone())  # Store for the reconstruction loss
            noise = (torch.randn_like(current_patch_orig) * sigma_noise).clamp(min_val_norm, max_val_norm)
            patch_mvm = torch.where(mask_mvm, noise, current_patch_orig)
            patches_mvm_list.append(patch_mvm)

            # Affine-augmented view for the contrastive objective
            patch_affine = _affine_augment_patch(current_patch_orig.clone())
            patches_affine_list.append(patch_affine)

        patches_mvm_batch = torch.stack(patches_mvm_list)
        patches_affine_batch = torch.stack(patches_affine_list)
        masks_batch = torch.stack(masks_for_loss_list)  # [B, 1, pD, pH, pW]

        emb_mvm = encoder(patches_mvm_batch)
        emb_affine = encoder(patches_affine_batch)
        reconstructed_patches = decoder(emb_mvm)

        loss_rec = torch.tensor(0.0, device=device)
        if masks_batch.any():  # Only compute if there are masked voxels
            # Masked selection; shapes of reconstruction and originals must match
            masked_reconstruction = reconstructed_patches[masks_batch]
            masked_original = patches_orig_batch[masks_batch]
            if masked_reconstruction.numel() > 0:  # If any elements were actually masked and selected
                loss_rec = F.mse_loss(masked_reconstruction, masked_original)

        loss_contrastive = _nt_xent_loss(emb_mvm, emb_affine, temperature)
        # Penalize reconstructions that leave the [min_val, max_val] range
        loss_bound = (torch.relu(reconstructed_patches - max_val_norm) +
                      torch.relu(min_val_norm - reconstructed_patches)).mean()

        total_loss = (reconstruction_weight * loss_rec +
                      contrastive_weight * loss_contrastive +
                      lambda_bound * loss_bound)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if epoch % (n_epochs // 10 if n_epochs >= 10 else 1) == 0:
            logger.info(f"Epoch {epoch+1}/{n_epochs}, Loss: {total_loss.item():.4f} "
                        f"(Rec: {loss_rec.item():.4f}, Contr: {loss_contrastive.item():.4f}, Bound: {loss_bound.item():.4f})")

    if not apply_segmentation:
        return image_volume

    encoder.eval()
    with torch.no_grad():
        _, dense_features_full_vol = encoder(img_vol_norm, return_features_before_pool=True)

        if dense_features_full_vol.shape[2:] != (Z_orig, H_orig, W_orig):
            dense_features_upsampled = F.interpolate(
                dense_features_full_vol, size=(Z_orig, H_orig, W_orig),
                mode='trilinear', align_corners=False
            )
        else:
            dense_features_upsampled = dense_features_full_vol

        # [1, C, Z, H, W] -> [Z*H*W, C] so each voxel is one K-Means sample
        features_for_kmeans = dense_features_upsampled.squeeze(0).permute(1, 2, 3, 0).reshape(-1, encoder.features_channels_before_pool)

        if features_for_kmeans.shape[0] == 0:
            logger.warning("No features extracted for K-Means, returning empty segmentation.")
            return torch.zeros((Z_orig, H_orig, W_orig), dtype=torch.long, device=device)

        voxel_labels_flat, _ = _kmeans_torch(features_for_kmeans, cluster_k)
        segmentation_mask = voxel_labels_flat.reshape(Z_orig, H_orig, W_orig)

    if original_input_shape_len == 3:
        return segmentation_mask.to(original_dtype)
    elif original_input_shape_len == 4:
        return segmentation_mask.unsqueeze(0).to(original_dtype)
    else:
        return segmentation_mask.unsqueeze(0).unsqueeze(0).to(original_dtype)
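

if __name__ == "__main__":
    # Minimal smoke-test sketch. All values below (volume size, patch size, epoch count)
    # are illustrative only, not OpenHCS defaults, and in practice this function is
    # normally invoked through an OpenHCS pipeline step rather than called directly.
    if torch is None:
        raise SystemExit("PyTorch is required to run this example.")
    volume = torch.rand(32, 64, 64)  # synthetic [Z, H, W] volume already in [0, 1]
    mask = self_supervised_segmentation_3d(
        volume,
        patch_size=(16, 16, 16),  # multiples of 4, so the decoder output matches the patch
        n_epochs=10,
        batch_size=2,
        cluster_k=2,
    )
    print("segmentation mask:", tuple(mask.shape), "labels:", mask.unique().tolist())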