Coverage for openhcs/processing/backends/enhance/self_supervised_2d_deconvolution.py: 14.0%

134 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1from __future__ import annotations 

2import logging 

3from typing import Optional, Tuple 

4 

5# Import torch decorator and optional_import utility 

6from openhcs.utils.import_utils import optional_import, create_placeholder_class 

7from openhcs.core.memory.decorators import torch as torch_func 

8 

9# --- PyTorch Imports as optional dependencies --- 

# --- PyTorch imports as optional dependencies ---
# Every torch-derived name is None when PyTorch itself cannot be imported.
torch = optional_import("torch")
if torch is None:
    nn = None
    F = None
    irfft2 = None
    rfft2 = None
else:
    nn = optional_import("torch.nn")
    F = optional_import("torch.nn.functional")
    from torch.fft import irfft2, rfft2

logger = logging.getLogger(__name__)

# Base class for the models below: built from torch.nn.Module when it is
# available; otherwise create_placeholder_class presumably supplies a stand-in
# named "Module" that reports PyTorch as the missing library.
nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch",
)

# --- PyTorch Specific Models and Helpers for 2D ---

class _Simple2DCNN_torch(nnModule):
    """Small fully convolutional network f(x) used for 2D deconvolution.

    Two 3x3 convolution + ReLU stages of widths ``features[0]`` and
    ``features[1]`` (96 and 192 by default, following the paper's 2D setup) are
    followed by a 3x3 projection to ``out_channels``. Every convolution uses
    padding 1, so the spatial size of the input is preserved.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(96, 192)):  # Paper: 96 initial features for 2D
        super().__init__()
        first_width = features[0]
        second_width = features[1]
        layers = [
            nn.Conv2d(in_channels, first_width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(first_width, second_width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(second_width, out_channels, kernel_size=3, padding=1),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, C, H, W) batch to (B, out_channels, H, W)."""
        return self.conv_block(x)

41 

class _LearnedBlur2D_torch(nnModule):
    """Trainable single-channel 2D blur used as the forward model g(.) in 'learned' mode.

    A bias-free ``kernel_size x kernel_size`` convolution. Its weights start as
    a normalised box filter (all taps equal, summing to 1) and are then
    optimised together with the deconvolution network.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        # padding="same" preserves H and W for both odd and even kernel sizes.
        # A fixed kernel_size // 2 pad grows the output by one pixel for even
        # sizes, and the reconstruction loss against the patch then fails with a
        # shape mismatch. For odd sizes the two are identical.
        self.blur_conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding="same", bias=False)
        if kernel_size > 0:
            # Start from a uniform (box) kernel normalised to sum to 1.
            weights = torch.ones(kernel_size, kernel_size)
            weights = weights / weights.sum()
            # Copy in place so the Parameter object (and its dtype/device) is kept.
            with torch.no_grad():
                self.blur_conv.weight.copy_(weights.reshape(1, 1, kernel_size, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blur a (B, 1, H, W) batch; the output has the same shape."""
        return self.blur_conv(x)

55 

def _gaussian_kernel_2d_torch(shape: Tuple[int, int], sigma: Tuple[float, float], device) -> torch.Tensor:
    """Build a separable 2D Gaussian kernel of size ``shape`` that sums to 1.

    The kernel is centred on the middle of the grid (half-pixel offset for even
    sizes). Each axis uses its own standard deviation from ``sigma``. Values
    are float32 on ``device``.
    """

    def _axis_profile(length: int, std: float) -> torch.Tensor:
        # Unnormalised 1D Gaussian sampled at offsets from the grid centre.
        centre = (length - 1) / 2.0
        offsets = torch.arange(length, dtype=torch.float32, device=device) - centre
        return torch.exp(-offsets**2 / (2 * std**2))

    rows = _axis_profile(shape[0], sigma[0])
    cols = _axis_profile(shape[1], sigma[1])
    # Outer product via broadcasting: kernel[i, j] = rows[i] * cols[j].
    kernel = rows[:, None] * cols[None, :]
    return kernel / kernel.sum()

66 

def _blur_fft_2d_torch(image: torch.Tensor, kernel: torch.Tensor, device) -> torch.Tensor:
    """Blur 2D images with ``kernel`` using FFT-based (circular) convolution.

    Args:
        image: Tensor of shape (B, C, H, W); any leading dims are accepted and
            every trailing (H, W) plane is blurred independently.
        kernel: 2D kernel of shape (kH, kW), with kH <= H and kW <= W. Its centre
            tap is taken to be at index (kH // 2, kW // 2).
        device: Device on which the zero-padded kernel is built.

    Returns:
        Real tensor with the same shape as ``image``. Borders wrap around
        (periodic boundary), as is inherent to FFT convolution.

    Raises:
        ValueError: If the kernel is larger than the image in either dimension.
    """
    H, W = image.shape[-2:]
    kH, kW = kernel.shape
    if kH > H or kW > W:
        raise ValueError(
            f"Blur kernel of size {(kH, kW)} is larger than the image {(H, W)}"
        )

    # Embed the kernel in an H x W grid with its centre tap at the origin.
    # In the DFT, index (0, 0) means "no shift". A kernel left in the middle of
    # the grid (plain centred zero-padding) would translate the blurred output
    # by roughly half the image size.
    kernel_padded = torch.zeros((H, W), dtype=kernel.dtype, device=device)
    kernel_padded[:kH, :kW] = kernel.to(device)
    kernel_padded = torch.roll(kernel_padded, shifts=(-(kH // 2), -(kW // 2)), dims=(0, 1))

    img_fft = torch.fft.rfft2(image, dim=(-2, -1))
    ker_fft = torch.fft.rfft2(kernel_padded, dim=(-2, -1))  # broadcasts over leading dims
    return torch.fft.irfft2(img_fft * ker_fft, s=(H, W), dim=(-2, -1))

86 

def _extract_random_patches_2d_torch(
    image_single_batch_channel: torch.Tensor,  # (H, W)
    patch_size_hw: Tuple[int, int],
    num_patches: int,
    device
) -> torch.Tensor:  # (num_patches, 1, pH, pW)
    """Crop ``num_patches`` uniformly random (pH, pW) windows from a 2D image.

    Args:
        image_single_batch_channel: 2D tensor of shape (H, W).
        patch_size_hw: Patch size (pH, pW); must not exceed (H, W).
        num_patches: Number of patches to draw (with replacement).
        device: Device for the offsets and the output tensor.

    Returns:
        Tensor of shape (num_patches, 1, pH, pW) with the image's dtype.

    Raises:
        ValueError: If the patch is larger than the image in either dimension.
    """
    H, W = image_single_batch_channel.shape
    pH, pW = patch_size_hw
    if pH > H or pW > W:
        # Without this check torch.randint fails on an empty range with an
        # error message that does not mention the patch size.
        raise ValueError(f"Patch size {(pH, pW)} exceeds image size {(H, W)}")

    patches = torch.empty(
        (num_patches, 1, pH, pW), dtype=image_single_batch_channel.dtype, device=device
    )
    # Draw all offsets on the target device in one call each and copy them to
    # the host once. The previous version did one .item() (a device sync) per
    # coordinate per patch.
    h_starts = torch.randint(0, H - pH + 1, (num_patches,), device=device).tolist()
    w_starts = torch.randint(0, W - pW + 1, (num_patches,), device=device).tolist()
    for i, (h0, w0) in enumerate(zip(h_starts, w_starts)):
        patches[i, 0] = image_single_batch_channel[h0:h0 + pH, w0:w0 + pW]
    return patches

106 

# --- Main 2D Deconvolution Function ---

@torch_func
def self_supervised_2d_deconvolution(
    image: torch.Tensor,  # Expected (H, W), (N, H, W) or (N, 1, H, W)
    apply_deconvolution: bool = True,
    n_epochs: int = 10,  # Reduced for testing
    patch_size_hw: Tuple[int, int] = (128, 128),  # Paper: 128x128 for 2D
    mask_fraction: float = 0.005,  # Paper: 0.5%
    sigma_noise: float = 0.2,
    lambda_rec: float = 1.0,
    lambda_inv_d: float = 2.0,  # Paper: deconvolved invariance for 2D
    lambda_bound_d: float = 0.1,  # Paper: boundary loss for 2D
    min_val: float = 0.0,
    max_val: float = 1.0,
    learning_rate: float = 4e-4,  # Paper: Adam 4e-4
    blur_mode: str = "gaussian",  # 'fft', 'gaussian', 'learned'
    blur_sigma_spatial: float = 1.5,
    blur_kernel_size: int = 5,
    **kwargs
) -> torch.Tensor:
    """
    Self-supervised 2D deconvolution optimized for 2D imaging data.

    A fresh CNN ``f`` is trained on random patches of ``image``. The training
    goal is that a blur ``g`` applied to ``f(masked patch)`` reconstructs the
    unmasked patch. ``f`` is then applied to every slice of the input.

    Based on the paper's 2D configuration:
    - 96 initial features (vs 48 for 3D)
    - 128x128 patches (vs 64x64x64 for 3D)
    - Loss (4) with deconvolved invariance (vs reconvolved for 3D)
    The paper trains with batch size 16 (vs 4 for 3D). This implementation
    draws a single patch per epoch.

    Args:
        image: CUDA tensor of shape (H, W), (N, H, W) or (N, 1, H, W). All N
            slices are normalised jointly. Each epoch trains on a patch taken
            from one randomly chosen slice.
        apply_deconvolution: If False, ``image`` is returned unchanged.
        n_epochs: Number of optimisation steps (one patch per step).
        patch_size_hw: Training patch size; clamped to the image size.
        mask_fraction: Probability that a patch pixel is replaced by noise.
        sigma_noise: Std of the replacement noise (clamped to [min_val, max_val]).
        lambda_rec, lambda_inv_d, lambda_bound_d: Weights of the reconstruction,
            deconvolved-invariance and boundary loss terms.
        min_val, max_val: Range the image is normalised to during training.
        learning_rate: Adam learning rate.
        blur_mode: 'gaussian' (direct convolution with a fixed Gaussian),
            'fft' (circular FFT convolution with the same Gaussian) or
            'learned' (trainable kernel).
        blur_sigma_spatial: Gaussian sigma for the fixed blur modes.
        blur_kernel_size: Side length of the square blur kernel.
        **kwargs: Accepted and ignored.

    Returns:
        Float tensor with the same shape as ``image``, in the input's original
        intensity range.

    Raises:
        TypeError: If ``image`` is not a tensor.
        RuntimeError: If ``image`` is not on a CUDA device.
        ValueError: On an unknown ``blur_mode``, ``max_val <= min_val``, or an
            unsupported input shape.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image must be a PyTorch Tensor. Got {type(image)}")

    if not apply_deconvolution:
        return image

    # --- PyTorch Backend Implementation ---
    device = image.device

    # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
    if device.type != "cuda":
        raise RuntimeError(f"@torch_func requires CUDA tensor, got device: {device}")

    # Validate configuration before doing any work. Previously an unknown
    # blur_mode was only caught inside the training loop (never when
    # n_epochs == 0), and max_val == min_val divided by zero when denormalising.
    if blur_mode not in ("gaussian", "fft", "learned"):
        raise ValueError(f"Unknown blur_mode: {blur_mode}")
    if not (max_val > min_val):
        raise ValueError(f"max_val ({max_val}) must be greater than min_val ({min_val})")

    img_norm = _as_slice_stack(image)  # (N, 1, H, W), float32
    num_slices, _, height, width = img_norm.shape

    img_min_orig, img_max_orig = torch.min(img_norm), torch.max(img_norm)
    if not (img_max_orig > img_min_orig):
        # Constant (or NaN-containing) input: the result is that constant no
        # matter what the network learns, so training is skipped.
        constant = torch.full_like(img_norm, img_min_orig.item())
        return _reshape_like_input(constant, image.ndim)

    # Normalize to [min_val, max_val]
    img_norm = (img_norm - img_min_orig) / (img_max_orig - img_min_orig)
    img_norm = img_norm * (max_val - min_val) + min_val

    # A training patch cannot exceed the image size (random cropping would fail
    # on an empty offset range), so shrink it to fit.
    patch_h = min(int(patch_size_hw[0]), height)
    patch_w = min(int(patch_size_hw[1]), width)
    if patch_h < patch_size_hw[0] or patch_w < patch_size_hw[1]:
        logger.warning(
            "Patch size %s exceeds image size %s; using (%d, %d)",
            tuple(patch_size_hw), (height, width), patch_h, patch_w,
        )

    # Create 2D model with paper's optimal architecture
    f_model = _Simple2DCNN_torch().to(device)

    g_model_blur: Optional[nn.Module] = None
    fixed_blur_kernel: Optional[torch.Tensor] = None
    if blur_mode == "learned":
        g_model_blur = _LearnedBlur2D_torch(kernel_size=blur_kernel_size).to(device)
        trainable = list(f_model.parameters()) + list(g_model_blur.parameters())
    else:  # "gaussian" or "fft": fixed Gaussian PSF
        trainable = list(f_model.parameters())
        fixed_blur_kernel = _gaussian_kernel_2d_torch(
            (blur_kernel_size, blur_kernel_size),
            (blur_sigma_spatial, blur_sigma_spatial), device
        )
    optimizer = torch.optim.Adam(trainable, lr=learning_rate)

    def apply_blur(x: torch.Tensor) -> torch.Tensor:
        """Apply the forward blur model g(.) to a (B, 1, h, w) batch."""
        if g_model_blur is not None:
            return g_model_blur(x)
        if blur_mode == "fft":
            return _blur_fft_2d_torch(x, fixed_blur_kernel, device)
        # padding="same" keeps (h, w) for odd and even kernel sizes. A
        # kernel_size // 2 pad adds a row and a column when the size is even,
        # which breaks the MSE against the patch.
        return F.conv2d(x, fixed_blur_kernel[None, None], padding="same")

    # Training mode does not change inside the loop, so set it once.
    f_model.train()
    if g_model_blur is not None:
        g_model_blur.train()

    log_every = max(n_epochs // 10, 1)
    for epoch in range(n_epochs):
        # Extract 2D patch (for stacks, from a random slice each epoch)
        if num_slices == 1:
            slice_idx = 0
        else:
            slice_idx = int(torch.randint(0, num_slices, (1,), device=device).item())
        current_patch_orig = _extract_random_patches_2d_torch(
            img_norm[slice_idx, 0], (patch_h, patch_w), 1, device
        )

        # Create masked variant
        mask = torch.rand_like(current_patch_orig) < mask_fraction
        # NOTE(review): zero-mean noise clamped to [min_val, max_val]; with the
        # defaults about half the masked pixels become exactly min_val.
        noise = (torch.randn_like(current_patch_orig) * sigma_noise).clamp(min_val, max_val)
        current_patch_masked = torch.where(mask, noise, current_patch_orig)

        # Forward pass f(x). Raw outputs are kept for the boundary loss.
        f_x_orig_raw = f_model(current_patch_orig)
        f_x_masked_raw = f_model(current_patch_masked)
        f_x_orig = f_x_orig_raw.clamp(min_val, max_val)
        f_x_masked = f_x_masked_raw.clamp(min_val, max_val)

        # Losses - Paper's optimal Loss (4) for 2D: deconvolved invariance.
        # Only g(f(masked)) enters the reconstruction term, so g(f(orig)) is
        # not computed.
        loss_rec = F.mse_loss(apply_blur(f_x_masked), current_patch_orig)

        # Deconvolved invariance loss (before PSF) on the masked pixels
        loss_inv_d = torch.zeros((), device=device)
        if mask.any():
            loss_inv_d = F.mse_loss(f_x_orig[mask], f_x_masked[mask])

        # Boundary loss: penalise out-of-range values of the *unclamped*
        # outputs. Computed on clamped tensors it is identically zero and
        # contributes no gradient.
        loss_bound_d = (
            (torch.relu(f_x_masked_raw - max_val) + torch.relu(min_val - f_x_masked_raw)).mean()
            + (torch.relu(f_x_orig_raw - max_val) + torch.relu(min_val - f_x_orig_raw)).mean()
        ) / 2.0

        total_loss = lambda_rec * loss_rec + lambda_inv_d * loss_inv_d + lambda_bound_d * loss_bound_d

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if epoch % log_every == 0:
            logger.info(f"2D Deconv Epoch {epoch}/{n_epochs}, Loss: {total_loss.item():.4f} "
                        f"(Rec: {loss_rec.item():.4f}, Inv_d: {loss_inv_d.item():.4f}, Bound_d: {loss_bound_d.item():.4f})")

    # Inference, one slice at a time to bound peak activation memory
    f_model.eval()
    with torch.no_grad():
        deconvolved_norm = torch.cat(
            [f_model(one_slice).clamp(min_val, max_val) for one_slice in img_norm.split(1, dim=0)],
            dim=0,
        )

    # Denormalize back to the input's original intensity range
    deconvolved_final = (deconvolved_norm - min_val) / (max_val - min_val)
    deconvolved_final = deconvolved_final * (img_max_orig - img_min_orig) + img_min_orig
    return _reshape_like_input(deconvolved_final, image.ndim)


def _as_slice_stack(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` as a float32 (N, 1, H, W) stack of single-channel slices."""
    if image.ndim == 2:  # (H, W)
        return image.unsqueeze(0).unsqueeze(0).float()
    if image.ndim == 3:  # (N, H, W)
        return image.unsqueeze(1).float()
    if image.ndim == 4:  # (N, 1, H, W)
        if image.shape[1] != 1:
            raise ValueError(
                f"Expected a single channel for 4D input, got shape {tuple(image.shape)}"
            )
        return image.float()
    raise ValueError(f"Unsupported image ndim: {image.ndim}")


def _reshape_like_input(result: torch.Tensor, input_ndim: int) -> torch.Tensor:
    """Map an (N, 1, H, W) result back to the rank of the original input."""
    if input_ndim == 2:
        return result.squeeze(0).squeeze(0)
    if input_ndim == 3:
        return result.squeeze(1)
    return result