Coverage for openhcs/processing/backends/enhance/self_supervised_3d_deconvolution.py: 13.5%

148 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1from __future__ import annotations 

2import logging 

3from typing import Optional, Tuple, Tuple 

4 

5# Import torch decorator and optional_import utility 

6from openhcs.utils.import_utils import optional_import, create_placeholder_class 

7from openhcs.core.memory.decorators import torch as torch_func 

8 

9# --- PyTorch Imports as optional dependencies --- 

10torch = optional_import("torch") 

11nn = optional_import("torch.nn") if torch is not None else None 

12F = optional_import("torch.nn.functional") if torch is not None else None 

13if torch is not None: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true

14 from torch.fft import irfftn, rfftn 

15else: 

16 irfftn = None 

17 rfftn = None 

18 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Base class for the model classes defined below.  ``nn.Module`` is passed
# through when ``torch.nn`` could be imported, otherwise ``None`` is passed.
# NOTE(review): presumably ``create_placeholder_class`` returns the real base
# class when it is available and a stand-in otherwise - confirm in
# openhcs.utils.import_utils.
nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    required_library="PyTorch",
    base_class=(nn.Module if nn else None),
)

26 

27# --- PyTorch Specific Models and Helpers --- 

class _Simple3DCNN_torch(nnModule):
    """Small stack of 3D convolutions used as the deconvolution network.

    The stack is ``Conv3d -> ReLU -> Conv3d -> ReLU -> Conv3d``.  Every
    convolution uses a 3x3x3 kernel with padding 1, so the spatial size of the
    input is preserved.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(48, 96)):
        """Build the layer stack.

        Args:
            in_channels: Channel count of the input tensor.
            out_channels: Channel count of the output tensor.
            features: Widths of the two hidden layers; only ``features[0]``
                and ``features[1]`` are read.  (Original note: the paper uses
                48 initial features for 3D.)
        """
        super().__init__()
        first_width = features[0]
        second_width = features[1]
        layers = [
            nn.Conv3d(in_channels, first_width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(first_width, second_width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(second_width, out_channels, kernel_size=3, padding=1),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` (laid out as (B, C, D, H, W)) through the layer stack."""
        return self.conv_block(x)

42 

class _LearnedBlur3D_torch(nnModule):
    """Trainable blur: one bias-free, single-channel ``Conv3d``."""

    def __init__(self, kernel_size=3):
        """Create the convolution and initialise it as a normalised box filter.

        Args:
            kernel_size: Edge length of the cubic kernel; the convolution is
                padded by ``kernel_size // 2`` on every side.
        """
        super().__init__()
        self.blur_conv = nn.Conv3d(
            1,
            1,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        if kernel_size <= 0:
            return
        # A simple averaging kernel stands in for a Gaussian-like starting
        # point: every tap has the same weight and the taps sum to one.
        box = torch.ones((1, 1, kernel_size, kernel_size, kernel_size))
        self.blur_conv.weight.data = box / box.sum()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blur ``x`` (laid out as (B, 1, D, H, W)) with the learned kernel."""
        return self.blur_conv(x)

56 

def _gaussian_kernel_3d_torch(shape: Tuple[int, int, int], sigma: Tuple[float, float, float], device) -> torch.Tensor:
    """Build a normalised, separable 3D Gaussian kernel.

    Args:
        shape: Kernel size along (depth, height, width).
        sigma: Standard deviation along (depth, height, width).
        device: Device on which the kernel is created.

    Returns:
        A float32 tensor of size ``shape`` whose entries sum to one.
    """
    profiles = []
    for axis in range(3):
        length = shape[axis]
        # Sample positions centred on zero, e.g. [-1, 0, 1] for length 3.
        offsets = torch.arange(length, dtype=torch.float32, device=device) - (length - 1) / 2.0
        profiles.append(torch.exp(-offsets**2 / (2 * sigma[axis]**2)))

    along_d, along_h, along_w = profiles
    # Outer product of the three 1D profiles via broadcasting.
    kernel = along_d.view(-1, 1, 1) * along_h.view(1, -1, 1) * along_w.view(1, 1, -1)
    return kernel / kernel.sum()

69 

def _blur_fft_torch(volume: torch.Tensor, kernel: torch.Tensor, device) -> torch.Tensor:
    """Blur ``volume`` with ``kernel`` by circular convolution in Fourier space.

    The kernel is zero-padded to the spatial size of the volume and then
    rolled so that its centre tap (index ``k // 2`` on each axis) sits at
    index 0.  Circular convolution treats index 0 as the kernel origin, so
    without that roll the result would come out circularly shifted by roughly
    half the volume size along every axis.

    Args:
        volume: Tensor of shape (B, C, D, H, W); the transform is taken over
            the last three axes.
        kernel: Tensor of shape (kD, kH, kW).
        device: Device the padded kernel is moved to before the transform.

    Returns:
        Tensor with the same (B, C, D, H, W) shape as ``volume``.
    """
    _, _, D, H, W = volume.shape
    kD, kH, kW = kernel.shape

    # Zero-pad the kernel to the volume size.  F.pad takes the pad amounts for
    # the last axis first, hence the (W, H, D) ordering.
    lead_d, lead_h, lead_w = (D - kD) // 2, (H - kH) // 2, (W - kW) // 2
    kernel_padded = F.pad(kernel, (
        lead_w, (W - kW + 1) // 2,
        lead_h, (H - kH + 1) // 2,
        lead_d, (D - kD + 1) // 2,
    ))

    # Move the kernel centre from (lead + k // 2) to index 0 on every axis.
    kernel_padded = torch.roll(
        kernel_padded,
        shifts=(-(lead_d + kD // 2), -(lead_h + kH // 2), -(lead_w + kW // 2)),
        dims=(0, 1, 2),
    )
    kernel_padded = kernel_padded.unsqueeze(0).unsqueeze(0)  # (1, 1, D, H, W)

    vol_fft = rfftn(volume, dim=(-3, -2, -1))
    ker_fft = rfftn(kernel_padded.to(device), dim=(-3, -2, -1))

    # Pointwise product in Fourier space == circular convolution in real space.
    blurred_fft = vol_fft * ker_fft
    # Pass the sizes explicitly so odd-length axes are reconstructed correctly.
    return irfftn(blurred_fft, s=(D, H, W), dim=(-3, -2, -1))

89 

def _blur_gaussian_conv_torch(volume: torch.Tensor, sigma_spatial: float, sigma_depth: float, kernel_size: int, device) -> torch.Tensor:
    """Blur ``volume`` with a separable Gaussian using a direct 3D convolution.

    Args:
        volume: Input passed to ``F.conv3d``; the weight built here has one
            input and one output channel.
        sigma_spatial: Standard deviation used along height and width.
        sigma_depth: Standard deviation used along depth.
        kernel_size: Edge length of the cubic kernel; the convolution is
            padded by ``kernel_size // 2``.
        device: Device on which the kernel is built.

    Returns:
        The result of ``F.conv3d`` on ``volume``.
    """
    k = kernel_size
    # Each call yields a normalised 1D profile already shaped to broadcast
    # along its own axis: (k, 1, 1), (1, k, 1) and (1, 1, k).
    depth_profile = _gaussian_kernel_3d_torch((k, 1, 1), (sigma_depth, 1, 1), device)
    height_profile = _gaussian_kernel_3d_torch((1, k, 1), (1, sigma_spatial, 1), device)
    width_profile = _gaussian_kernel_3d_torch((1, 1, k), (1, 1, sigma_spatial), device)

    separable = depth_profile * height_profile * width_profile
    weight = separable[None, None].to(device)  # (1, 1, kD, kH, kW)

    return F.conv3d(volume, weight, padding=k // 2)

100 

def _extract_random_patches_torch(
    volume_single_batch_channel: torch.Tensor,  # (D, H, W)
    patch_size_dhw: Tuple[int, int, int],
    num_patches: int,
    device
) -> torch.Tensor:  # (num_patches, 1, pD, pH, pW)
    """Crop ``num_patches`` randomly positioned sub-volumes out of a 3D volume.

    Args:
        volume_single_batch_channel: Source volume of shape (D, H, W).
        patch_size_dhw: Patch size along (depth, height, width).
        num_patches: Number of patches to draw.
        device: Device used for the output buffer and the random draws.

    Returns:
        Tensor of shape (num_patches, 1, pD, pH, pW) in the default dtype.
    """
    depth, height, width = volume_single_batch_channel.shape
    patch_d, patch_h, patch_w = patch_size_dhw
    out = torch.empty((num_patches, 1, patch_d, patch_h, patch_w), device=device)

    # Exclusive upper bounds for the start index on each axis.
    start_limits = (depth - patch_d + 1, height - patch_h + 1, width - patch_w + 1)

    for idx in range(num_patches):
        # Random draws happen on ``device`` itself - no CPU fallback.
        # Order of draws is depth, height, width.
        z0, y0, x0 = (
            torch.randint(0, limit, (1,), device=device).item()
            for limit in start_limits
        )
        out[idx, 0] = volume_single_batch_channel[
            z0:z0 + patch_d, y0:y0 + patch_h, x0:x0 + patch_w
        ]
    return out

120 

121# --- Main Deconvolution Function --- 

@torch_func
def self_supervised_3d_deconvolution(
    image_volume: torch.Tensor,  # Expected (1, Z, H, W) or (Z,H,W)
    apply_deconvolution: bool = True,
    n_epochs: int = 10,  # Reduced default for quick test (was 100)
    patch_size: Tuple[int, int, int] = (16, 32, 32),  # Reduced for small test images (paper: 64x64x64)
    mask_fraction: float = 0.005,  # Paper: 0.5%
    sigma_noise: float = 0.2,
    lambda_rec: float = 1.0,
    lambda_inv: float = 2.0,  # Paper: reconvolved invariance for 3D
    lambda_bound: float = 0.0,  # Paper: λbound = 0 for 3D
    min_val: float = 0.0,
    max_val: float = 1.0,
    learning_rate: float = 4e-4,
    blur_mode: str = "gaussian",  # 'fft', 'gaussian', 'learned'
    blur_sigma_spatial: float = 1.5,
    blur_sigma_depth: float = 1.5,
    blur_kernel_size: int = 5,  # For gaussian/learned conv blur
    **kwargs
) -> torch.Tensor:
    """Train a small 3D CNN on the given volume and return its output on it.

    The volume is rescaled to ``[min_val, max_val]``.  For ``n_epochs`` steps
    one random patch is drawn, a randomly masked copy is made (masked voxels
    are replaced by clamped Gaussian noise), both are passed through the
    network and then through the blur, and the weighted sum of the
    reconstruction, invariance and bound losses is minimised with Adam.
    Finally the network is applied to the whole volume and the result is
    mapped back to the original intensity range.

    Args:
        image_volume: CUDA tensor with 3, 4 or 5 dimensions: (D, H, W),
            (1, D, H, W) or (1, 1, D, H, W).
        apply_deconvolution: When False the input is returned unchanged.
        n_epochs: Number of training steps (one patch per step).
        patch_size: Training patch size along (depth, height, width).
        mask_fraction: Probability that a voxel of the patch is masked.
        sigma_noise: Scale of the Gaussian noise written into masked voxels.
        lambda_rec: Weight of the reconstruction loss.
        lambda_inv: Weight of the invariance loss on masked voxels.
        lambda_bound: Weight of the out-of-range penalty; the penalty is only
            computed when this is greater than zero.
        min_val: Lower end of the normalised intensity range.
        max_val: Upper end of the normalised intensity range.
        learning_rate: Adam learning rate.
        blur_mode: ``"gaussian"`` (fixed kernel, direct convolution),
            ``"fft"`` (fixed kernel, Fourier-domain convolution) or
            ``"learned"`` (trainable kernel optimised jointly with the CNN).
        blur_sigma_spatial: Gaussian sigma along height and width.
        blur_sigma_depth: Gaussian sigma along depth.
        blur_kernel_size: Edge length of the cubic blur kernel.
        **kwargs: Accepted but not read by this function.

    Returns:
        A float tensor with the same number of dimensions as the input, or
        ``image_volume`` itself when ``apply_deconvolution`` is False.

    Raises:
        TypeError: If ``image_volume`` is not a ``torch.Tensor``.
        RuntimeError: If ``image_volume`` is not on a CUDA device.
        ValueError: If ``image_volume`` has an unsupported number of
            dimensions or ``blur_mode`` is not one of the supported modes.
    """
    if not isinstance(image_volume, torch.Tensor):
        raise TypeError(f"Input image_volume must be a PyTorch Tensor. Got {type(image_volume)}")

    patch_size_dhw = tuple(patch_size)  # Convert to tuple for consistency

    if not apply_deconvolution:
        return image_volume

    # --- PyTorch Backend Implementation ---
    device = image_volume.device

    # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
    if device.type != "cuda":
        raise RuntimeError(f"@torch_func requires CUDA tensor, got device: {device}")

    # Ensure input is (1, 1, D, H, W)
    if image_volume.ndim == 3:  # (D, H, W)
        img_vol_norm = image_volume.unsqueeze(0).unsqueeze(0).float()
    elif image_volume.ndim == 4:  # (1, D, H, W)
        img_vol_norm = image_volume.unsqueeze(1).float()
    elif image_volume.ndim == 5:  # (1, 1, D, H, W)
        img_vol_norm = image_volume.float()
    else:
        raise ValueError(f"Unsupported image_volume ndim: {image_volume.ndim}")

    # Normalize to [min_val, max_val] (typically [0,1])
    img_min_orig, img_max_orig = torch.min(img_vol_norm), torch.max(img_vol_norm)
    if img_max_orig > img_min_orig:
        img_vol_norm = (img_vol_norm - img_min_orig) / (img_max_orig - img_min_orig)  # to [0,1]
        img_vol_norm = img_vol_norm * (max_val - min_val) + min_val  # to [min_val, max_val]
    else:  # Constant image
        img_vol_norm = torch.full_like(img_vol_norm, min_val)

    # Reject an unknown blur mode before any model or optimizer is allocated,
    # so the error is also raised when n_epochs == 0.
    if blur_mode not in ("learned", "fft", "gaussian"):
        raise ValueError(f"Unknown blur_mode: {blur_mode}")

    f_model = _Simple3DCNN_torch().to(device)

    g_model_blur: Optional[nn.Module] = None
    fixed_blur_kernel: Optional[torch.Tensor] = None
    conv_kernel: Optional[torch.Tensor] = None

    if blur_mode == "learned":
        # The blur kernel is trained together with the network.
        g_model_blur = _LearnedBlur3D_torch(kernel_size=blur_kernel_size).to(device)
        optimizer = torch.optim.Adam(list(f_model.parameters()) + list(g_model_blur.parameters()), lr=learning_rate)
    else:  # fft or gaussian (fixed kernel)
        optimizer = torch.optim.Adam(f_model.parameters(), lr=learning_rate)
        fixed_blur_kernel = _gaussian_kernel_3d_torch(
            (blur_kernel_size, blur_kernel_size, blur_kernel_size),
            (blur_sigma_depth, blur_sigma_spatial, blur_sigma_spatial), device
        )
        if blur_mode == "gaussian":
            # For conv3d, kernel needs to be (out_c, in_c/groups, kD, kH, kW).
            # Built once here because it does not change between epochs.
            conv_kernel = fixed_blur_kernel.unsqueeze(0).unsqueeze(0).to(device)

    pad_size = blur_kernel_size // 2
    # Log roughly ten times over the run, or every epoch for short runs.
    log_every = n_epochs // 10 if n_epochs >= 10 else 1

    # Training Loop
    for epoch in range(n_epochs):
        f_model.train()
        if g_model_blur is not None:
            g_model_blur.train()

        # Extract one patch for this step (batch_size=1 for patches)
        # Input to _extract_random_patches_torch is (D,H,W)
        current_patch_orig = _extract_random_patches_torch(
            img_vol_norm.squeeze(0).squeeze(0), patch_size_dhw, 1, device
        )
        # current_patch_orig shape: (1, 1, pD, pH, pW)

        # Create masked variant: masked voxels are replaced by clamped noise
        mask = (torch.rand_like(current_patch_orig) < mask_fraction).bool()
        noise = (torch.randn_like(current_patch_orig) * sigma_noise).clamp(min_val, max_val)
        current_patch_masked = torch.where(mask, noise, current_patch_orig)

        # Forward pass f(x)
        f_x_orig = f_model(current_patch_orig).clamp(min_val, max_val)
        f_x_masked = f_model(current_patch_masked).clamp(min_val, max_val)

        # Apply blur g(f(x))
        if g_model_blur is not None:  # blur_mode == "learned"
            g_f_x_orig = g_model_blur(f_x_orig)
            g_f_x_masked = g_model_blur(f_x_masked)
        elif blur_mode == "fft":
            g_f_x_orig = _blur_fft_torch(f_x_orig, fixed_blur_kernel, device)
            g_f_x_masked = _blur_fft_torch(f_x_masked, fixed_blur_kernel, device)
        else:  # blur_mode == "gaussian": conv based Gaussian
            g_f_x_orig = F.conv3d(f_x_orig, conv_kernel, padding=pad_size)
            g_f_x_masked = F.conv3d(f_x_masked, conv_kernel, padding=pad_size)

        # Losses - Paper's optimal Loss (3) for 3D: reconvolved invariance
        loss_rec = F.mse_loss(g_f_x_masked, current_patch_orig)

        # Reconvolved invariance loss (after PSF) - optimal for 3D per paper
        loss_inv = torch.tensor(0.0, device=device)
        if mask.sum() > 0:  # Only if some pixels were masked
            loss_inv = F.mse_loss(g_f_x_orig[mask], g_f_x_masked[mask])

        # Boundary loss on reconvolved output (λbound = 0 for 3D per paper)
        loss_bound = torch.tensor(0.0, device=device)
        if lambda_bound > 0:
            loss_bound_f_masked = (torch.relu(g_f_x_masked - max_val) + torch.relu(min_val - g_f_x_masked)).mean()
            loss_bound_f_orig = (torch.relu(g_f_x_orig - max_val) + torch.relu(min_val - g_f_x_orig)).mean()
            loss_bound = (loss_bound_f_masked + loss_bound_f_orig) / 2.0

        total_loss = lambda_rec * loss_rec + lambda_inv * loss_inv + lambda_bound * loss_bound

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if epoch % log_every == 0:
            logger.info(f"3D Deconv Epoch {epoch}/{n_epochs}, Loss: {total_loss.item():.4f} "
                        f"(Rec: {loss_rec.item():.4f}, Inv: {loss_inv.item():.4f}, Bound: {loss_bound.item():.4f})")

    # Inference
    f_model.eval()
    with torch.no_grad():
        # Full volume inference - may need patching for large volumes on limited GPU
        # img_vol_norm is (1, 1, D, H, W) at this point
        deconvolved_norm = f_model(img_vol_norm).clamp(min_val, max_val)

    # Denormalize from [min_val, max_val] back to the original intensity range
    if img_max_orig > img_min_orig:
        deconvolved_final = (deconvolved_norm - min_val) / (max_val - min_val)  # to [0,1]
        deconvolved_final = deconvolved_final * (img_max_orig - img_min_orig) + img_min_orig
    else:  # Constant image
        deconvolved_final = torch.full_like(deconvolved_norm, img_min_orig)

    # Return in the original input shape format if it was (1,D,H,W) or (D,H,W)
    if image_volume.ndim == 3:
        return deconvolved_final.squeeze(0).squeeze(0)
    if image_volume.ndim == 4:
        return deconvolved_final.squeeze(1)
    return deconvolved_final  # (1,1,D,H,W)