Coverage for openhcs/processing/backends/assemblers/self_supervised_stitcher.py: 7.0%

267 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

from __future__ import annotations

import math
from typing import Any, Dict, List, Optional, Tuple

from openhcs.utils.import_utils import optional_import, create_placeholder_class
from openhcs.core.memory.decorators import torch as torch_backend_func

# Import torch modules as optional dependencies
torch = optional_import("torch")
nn = optional_import("torch.nn") if torch is not None else None
F = optional_import("torch.nn.functional") if torch is not None else None
models = optional_import("torchvision.models") if optional_import("torchvision") is not None else None

nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch"
)

# --- Helper Modules and Functions (Placeholders or Simplified) ---

class FeatureEncoder(nnModule):
    def __init__(self, pretrained: bool = True):
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if pretrained else None)
        # Modify ResNet for grayscale input (1 channel)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Adapt weights from RGB to grayscale for the first layer if possible.
            # Simple averaging of RGB weights:
            rgb_weights = resnet.conv1.weight.data
            self.conv1.weight.data = rgb_weights.mean(dim=1, keepdim=True)

        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        # self.layer4 = resnet.layer4  # Often excluded for smaller feature maps in stitching

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # x = self.layer4(x)
        return x
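
# Note (added for clarity, not executed at import time): truncated at layer3, a
# standard torchvision ResNet-18 backbone downsamples by 16x and outputs 256
# channels, so an input of shape [B, 1, H, W] yields features of roughly
# [B, 256, H // 16, W // 16]. For example:
#
#     encoder = FeatureEncoder(pretrained=False)
#     feats = encoder(torch.zeros(1, 1, 64, 64))   # -> [1, 256, 4, 4]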

class HomographyPredictionNet(nnModule):
    def __init__(self, feature_dim: int):
        super().__init__()
        # Placeholder: a simple network that predicts 8 parameters of a homography
        # (the ninth entry, h22, is fixed to 1).
        # Input is the concatenated feature vectors of two tiles.
        self.fc = nn.Sequential(
            nn.Linear(feature_dim * 2, 512),  # Assuming features are flattened and concatenated
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 8)  # 8 deltas from the identity homography (see forward())
        )
        # Initialize weights so the network outputs a near-identity transform initially:
        # zero biases and small weights keep the predicted deltas close to zero.
        for m in self.fc:
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)
                nn.init.xavier_uniform_(m.weight, gain=0.01)

    def forward(self, features1: torch.Tensor, features2: torch.Tensor) -> torch.Tensor:
        # features1, features2 are [B, C, Hf, Wf]; flatten and concatenate
        batch_size = features1.shape[0]
        flat_features1 = features1.mean(dim=[2, 3])  # Global average pooling as a simple feature vector
        flat_features2 = features2.mean(dim=[2, 3])

        combined_features = torch.cat((flat_features1, flat_features2), dim=1)
        params_8 = self.fc(combined_features)  # [B, 8]

        # Construct a 3x3 homography matrix from the 8 parameters:
        # H = [[h00, h01, h02], [h10, h11, h12], [h20, h21, 1]]
        # params_8 = [h00-1, h01, h02, h10, h11-1, h12, h20, h21] (delta from identity)
        homography = torch.eye(3, device=params_8.device).repeat(batch_size, 1, 1)
        homography[:, 0, 0] += params_8[:, 0]
        homography[:, 0, 1] = params_8[:, 1]
        homography[:, 0, 2] = params_8[:, 2]
        homography[:, 1, 0] = params_8[:, 3]
        homography[:, 1, 1] += params_8[:, 4]
        homography[:, 1, 2] = params_8[:, 5]
        homography[:, 2, 0] = params_8[:, 6]
        homography[:, 2, 1] = params_8[:, 7]
        # homography[:, 2, 2] is already 1
        return homography
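
# Illustrative usage (added note, not part of the module's API): with the
# initialization above, the predicted homography starts near the identity.
# feature_dim should match the encoder's channel count (256 for the
# layer3-truncated ResNet-18 above), e.g.:
#
#     net = HomographyPredictionNet(feature_dim=256)
#     H = net(feats1, feats2)   # feats*: [B, 256, Hf, Wf] -> H: [B, 3, 3], roughly eye(3)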

def barlow_twins_loss(z1: torch.Tensor, z2: torch.Tensor, lambda_coeff: float = 5e-3) -> torch.Tensor:
    # Placeholder for Barlow Twins loss
    # z1, z2 are [B, FeatureDim]
    # Normalize features
    z1_norm = (z1 - z1.mean(dim=0)) / (z1.std(dim=0) + 1e-5)
    z2_norm = (z2 - z2.mean(dim=0)) / (z2.std(dim=0) + 1e-5)

    N, D = z1_norm.shape
    c = (z1_norm.T @ z2_norm) / N  # Cross-correlation matrix

    on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
    off_diag = c.masked_fill(torch.eye(D, device=c.device, dtype=torch.bool), 0).pow_(2).sum()
    loss = on_diag + lambda_coeff * off_diag
    return loss

def geometry_consistency_loss(H_ab: torch.Tensor, H_ba: torch.Tensor) -> torch.Tensor:
    # H_ab: transform from b to a, H_ba: transform from a to b.
    # H_ab @ H_ba should be close to the identity.
    identity = torch.eye(3, device=H_ab.device).unsqueeze(0).repeat(H_ab.shape[0], 1, 1)
    product = H_ab @ H_ba
    loss = F.mse_loss(product, identity)
    return loss


def photometric_loss(tile_warped: torch.Tensor, tile_target: torch.Tensor) -> torch.Tensor:
    return F.l1_loss(tile_warped, tile_target)

def warp_tile_homography(tile: torch.Tensor, H: torch.Tensor, output_shape: Tuple[int, int]) -> torch.Tensor:
    """Warps a tile using a homography.
    tile: [1, 1, H_in, W_in]
    H: [1, 3, 3] homography matrix
    output_shape: (H_out, W_out)
    """
    H_out, W_out = output_shape
    # Create grid for sampling
    grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, H_out, device=tile.device),
                                    torch.linspace(-1, 1, W_out, device=tile.device), indexing='ij')
    grid = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=-1)  # [H_out, W_out, 3]
    grid = grid.view(1, H_out * W_out, 3)  # [1, H_out*W_out, 3]

    # We need to warp from the output grid back into the input tile's coordinate system,
    # which is what grid_sample expects: if H maps tile_src to tile_dst, then H_inv maps
    # the output grid (tile_dst coordinates) back to tile_src (input) coordinates.
    try:
        H_inv = torch.inverse(H)
    except RuntimeError:  # Singular matrix
        H_inv = H  # Fallback, or handle error
        print("Warning: Singular homography matrix encountered during inverse.")

    # Transform grid points
    grid_transformed = torch.bmm(grid, H_inv.transpose(1, 2))  # [1, H_out*W_out, 3]

    # Perspective divide (x' = x/w, y' = y/w), avoiding division by zero or very small w
    w_coords = grid_transformed[:, :, 2].unsqueeze(2)
    safe_w_coords = torch.where(torch.abs(w_coords) < 1e-6, torch.sign(w_coords) * 1e-6 + 1e-9, w_coords)

    grid_transformed_normalized = grid_transformed[:, :, :2] / safe_w_coords  # [1, H_out*W_out, 2]

    # Reshape for grid_sample
    sampling_grid = grid_transformed_normalized.view(1, H_out, W_out, 2)  # [B, H_out, W_out, 2] (x, y order)

    warped_tile = F.grid_sample(tile, sampling_grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    return warped_tile
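
# Illustrative usage (added note; values are approximate because grid_sample
# interpolates, so with align_corners=False even an identity homography only
# reproduces the input up to interpolation/border effects):
#
#     tile = torch.rand(1, 1, 64, 64)
#     H_identity = torch.eye(3).unsqueeze(0)                  # [1, 3, 3]
#     warped = warp_tile_homography(tile, H_identity, (64, 64))
#     # warped.shape == (1, 1, 64, 64); warped is approximately tile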

def get_adjacency_from_layout(layout_rows: int, layout_cols: int, num_tiles: int, device: torch.device) -> List[Tuple[int, int]]:
    """
    Generates a list of adjacent tile index pairs based on a grid layout.
    Considers 4-connectivity (up, down, left, right).
    """
    adjacency_pairs = []
    for r in range(layout_rows):
        for c in range(layout_cols):
            current_idx = r * layout_cols + c
            if current_idx >= num_tiles:
                continue

            # Check right neighbor
            if c + 1 < layout_cols:
                right_idx = r * layout_cols + (c + 1)
                if right_idx < num_tiles:
                    adjacency_pairs.append(tuple(sorted((current_idx, right_idx))))

            # Check bottom neighbor
            if r + 1 < layout_rows:
                bottom_idx = (r + 1) * layout_cols + c
                if bottom_idx < num_tiles:
                    adjacency_pairs.append(tuple(sorted((current_idx, bottom_idx))))

    # Remove duplicates that might arise from sorted tuples if order doesn't matter for pairs
    return sorted(list(set(adjacency_pairs)))
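
# Example (illustrative): for a 2x3 layout holding 6 tiles,
#
#     get_adjacency_from_layout(2, 3, 6, device)
#
# returns the 4-connected pairs
# [(0, 1), (0, 3), (1, 2), (1, 4), (2, 5), (3, 4), (4, 5)].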

def optimize_pose_graph(
    pairwise_homographies: Dict[Tuple[int, int], torch.Tensor],
    num_tiles: int,
    device: torch.device,
    initial_global_transforms: Optional[List[torch.Tensor]] = None
) -> List[torch.Tensor]:
    """
    Placeholder for pose graph optimization.
    Takes pairwise homographies and refines global tile transforms.
    Returns a list of [3,3] global homography matrices for each tile.
    """
    print("Pose Graph Optimization: SKIPPED (using initial/placeholder transforms).")
    # TODO: Implement actual pose graph optimization (e.g., using least squares on log-transforms, or a spring model).
    # For now, if initial_global_transforms are provided, use them; otherwise return identity.
    if initial_global_transforms:
        if len(initial_global_transforms) == num_tiles:
            return initial_global_transforms

    # Fallback to identity if no initial transforms or mismatch
    global_transforms = [torch.eye(3, device=device) for _ in range(num_tiles)]
    if num_tiles > 0:  # Anchor the first tile
        global_transforms[0] = torch.eye(3, device=device)

    # Simplistic chaining if pairwise_homographies were available (illustrative, not robust).
    # This part should be replaced by a real graph solver.
    # Example: if H_ij transforms tile j to tile i's frame.
    # for i in range(1, num_tiles):
    #     if (i-1, i) in pairwise_homographies:
    #         H_prev_curr = pairwise_homographies[(i-1, i)]  # H that maps tile i to tile i-1 frame
    #         global_transforms[i] = global_transforms[i-1] @ torch.inverse(H_prev_curr)
    #     elif (i, i-1) in pairwise_homographies:
    #         H_curr_prev = pairwise_homographies[(i, i-1)]  # H that maps tile i-1 to tile i frame
    #         global_transforms[i] = global_transforms[i-1] @ H_curr_prev

    return global_transforms
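
# --- Illustrative sketch (added for clarity, not wired into the pipeline) ---
# The TODO above asks for a real pose-graph solver. Below is a minimal, hedged
# sketch under the assumption that the predicted pairwise homographies are
# translation-dominant: it solves the linear least-squares problem
# x_j - x_i = t_ij (the translation part of H_ij, which maps tile j into tile
# i's frame) for global tile offsets, anchoring tile 0 at the origin. The
# helper name `_translation_pose_graph_sketch` is illustrative only.
def _translation_pose_graph_sketch(
    pairwise_homographies: Dict[Tuple[int, int], torch.Tensor],
    num_tiles: int,
    device: torch.device,
) -> List[torch.Tensor]:
    if num_tiles == 0:
        return []
    edges = [(i, j) for (i, j) in pairwise_homographies if i != j]
    if num_tiles == 1 or not edges:
        return [torch.eye(3, device=device) for _ in range(num_tiles)]

    # One row per edge; unknowns are the (x, y) offsets of tiles 1..num_tiles-1
    # (tile 0 is anchored at the origin and therefore has no columns).
    rows, rhs = [], []
    for (i, j) in edges:
        H_ij = pairwise_homographies[(i, j)]  # maps tile j coords into tile i's frame
        t_ij = H_ij[:2, 2]                    # where tile j's origin lands in tile i
        row = torch.zeros(num_tiles - 1, device=device)
        if j > 0:
            row[j - 1] += 1.0
        if i > 0:
            row[i - 1] -= 1.0
        rows.append(row)
        rhs.append(t_ij)

    A = torch.stack(rows)                        # [num_edges, num_tiles - 1]
    b = torch.stack(rhs)                         # [num_edges, 2] (x and y solved jointly)
    offsets = torch.linalg.lstsq(A, b).solution  # [num_tiles - 1, 2]

    transforms = [torch.eye(3, device=device)]
    for k in range(num_tiles - 1):
        T = torch.eye(3, device=device)
        T[0, 2] = offsets[k, 0]
        T[1, 2] = offsets[k, 1]
        transforms.append(T)
    return transforms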

# --- Main Stitcher Function ---
@torch_backend_func
def self_supervised_stitcher_func(
    tile_stack: torch.Tensor,  # shape: [Z, Y, X]
    *,
    tile_shape_override: Optional[Tuple[int, int]] = None,  # (tile_height, tile_width)
    layout_shape_override: Optional[Tuple[int, int]] = None,  # (rows, cols)
    learn: bool = False,
    num_train_iterations: int = 100,  # Only if learn=True
    overlap_percent: float = 0.1,  # For global transform normalization
    return_homographies: bool = False,
    # For pre-trained model paths
    encoder_path: Optional[str] = None,
    homography_net_path: Optional[str] = None
) -> torch.Tensor | Tuple[torch.Tensor, Tuple[int, int]] | Tuple[torch.Tensor, torch.Tensor, Tuple[int, int]]:
    """
    Self-supervised image stitching module.
    Learns relative alignment of tiles using unsupervised geometry matching.
    Infers pairwise transformations and composes them into global tile offsets (x, y).
    Returns:
        - position tensor [1, Z, 2]
        - (optionally) global homographies [Z, 3, 3]
        - canvas dimensions (canvas_H, canvas_W)
    """

    device = tile_stack.device
    Z, Y, X = tile_stack.shape

    if Z == 0:
        empty_positions = torch.empty((1, 0, 2), device=device, dtype=torch.float32)
        canvas_dims = (0, 0)
        if return_homographies:
            empty_homographies = torch.empty((0, 3, 3), device=device, dtype=torch.float32)
            return empty_positions, empty_homographies, canvas_dims
        return empty_positions, canvas_dims

    # 1. Input Description & Defaults
    tile_shape: Tuple[int, int] = tile_shape_override if tile_shape_override else (Y, X)
    H_tile, W_tile = tile_shape

    if layout_shape_override:
        layout_rows, layout_cols = layout_shape_override
        if layout_rows * layout_cols < Z:
            raise ValueError(f"Provided layout_shape {layout_shape_override} is too small for {Z} tiles.")
    else:
        # Infer layout_shape to be as square as possible
        layout_rows = math.ceil(math.sqrt(Z))
        layout_cols = math.ceil(Z / layout_rows)
        # Ensure the layout has at least Z slots
        while layout_rows * layout_cols < Z:
            layout_cols += 1  # or layout_rows, depending on preference for aspect ratio

    layout_shape: Tuple[int, int] = (layout_rows, layout_cols)

    print(f"Using tile_shape: {tile_shape}, layout_shape: {layout_shape} for {Z} tiles.")

    # 2. Tile Reshaping
    # Reshape [Z, Y, X] -> [Z, 1, H_tile, W_tile] for the CNN
    tiles_for_cnn = tile_stack.unsqueeze(1).float()  # Add channel dim

    # Create a 2D grid representation of tiles for adjacency.
    # Pad tile_stack if Z < layout_rows * layout_cols
    num_layout_slots = layout_rows * layout_cols
    if Z < num_layout_slots:
        padding_count = num_layout_slots - Z
        padding_tensor = torch.zeros(padding_count, Y, X, device=device, dtype=tile_stack.dtype)
        padded_tile_stack = torch.cat((tile_stack, padding_tensor), dim=0)
    else:
        padded_tile_stack = tile_stack

    tile_grid = padded_tile_stack.view(layout_rows, layout_cols, Y, X)

    # 3. Feature Encoder
    feature_encoder = FeatureEncoder().to(device)
    if learn:
        feature_encoder.train()
    else:
        feature_encoder.eval()

    # 4. Unsupervised Alignment (AltO-inspired)
    # This is a highly complex part. Placeholder logic:

    # Store pairwise homographies, e.g., from tile i to tile j
    pairwise_H_matrices: Dict[Tuple[int, int], torch.Tensor] = {}
    # global_transforms will be populated by pose graph optimization or fallback
    global_transforms: List[torch.Tensor] = [torch.eye(3, device=device) for _ in range(Z)]

    if learn:
        print("Starting learning phase for pairwise alignments...")
        dummy_features = feature_encoder(tiles_for_cnn[:1])
        feature_dim_encoder = dummy_features.shape[1]
        homography_net = HomographyPredictionNet(feature_dim_encoder).to(device)
        optimizer = torch.optim.Adam(list(feature_encoder.parameters()) + list(homography_net.parameters()), lr=1e-4)

        # Get adjacent pairs based on layout
        adjacent_tile_pairs = get_adjacency_from_layout(layout_rows, layout_cols, Z, device)

        if Z > 1 and not adjacent_tile_pairs:
            print("Warning: No adjacent pairs from layout for training, using sequential pairs as fallback.")
            for i_rand_pair in range(Z - 1):  # Create simple chain pairs
                adjacent_tile_pairs.append(tuple(sorted((i_rand_pair, i_rand_pair + 1))))
            adjacent_tile_pairs = sorted(list(set(adjacent_tile_pairs)))

        for iter_idx in range(num_train_iterations):
            optimizer.zero_grad()
            total_loss_iter = torch.tensor(0.0, device=device)

            if not adjacent_tile_pairs or Z < 2:
                print("Not enough tiles or pairs for training iteration.")
                break

            # Create batches from adjacent_tile_pairs.
            # For simplicity, process a few pairs per iteration or all if few.
            # Ensure num_pairs_in_batch is at least 1 if adjacent_tile_pairs is not empty.
            num_pairs_in_batch = min(len(adjacent_tile_pairs), 8) if adjacent_tile_pairs else 0
            if num_pairs_in_batch == 0:
                print("No pairs to process in this iteration.")
                continue

            # Use randint for memory efficiency (though len(adjacent_tile_pairs) is typically small)
            current_batch_pairs_indices = torch.randint(0, len(adjacent_tile_pairs), (num_pairs_in_batch,))

            batch_idx1_list = []
            batch_idx2_list = []
            for perm_idx in current_batch_pairs_indices:
                p1, p2 = adjacent_tile_pairs[perm_idx.item()]
                batch_idx1_list.append(p1)
                batch_idx2_list.append(p2)

            idx1 = torch.tensor(batch_idx1_list, device=device, dtype=torch.long)
            idx2 = torch.tensor(batch_idx2_list, device=device, dtype=torch.long)

            tiles1_batch = tiles_for_cnn[idx1]
            tiles2_batch = tiles_for_cnn[idx2]

            features1 = feature_encoder(tiles1_batch)
            features2 = feature_encoder(tiles2_batch)

            H_12 = homography_net(features1, features2)  # tile2 -> tile1
            H_21 = homography_net(features2, features1)  # tile1 -> tile2

            # Store these for later graph optimization
            for i_pair in range(idx1.shape[0]):
                p_idx1, p_idx2 = idx1[i_pair].item(), idx2[i_pair].item()
                pairwise_H_matrices[(p_idx1, p_idx2)] = H_12[i_pair].detach()  # H mapping p_idx2 to p_idx1 frame
                pairwise_H_matrices[(p_idx2, p_idx1)] = H_21[i_pair].detach()  # H mapping p_idx1 to p_idx2 frame

            loss_bt = barlow_twins_loss(features1.mean(dim=[2, 3]), features2.mean(dim=[2, 3]))
            loss_geom = geometry_consistency_loss(H_12, H_21)

            # TODO: For photometric loss, consider replacing grid_sample in warp_tile_homography
            # with GPU-efficient tensor tiling (block aggregation) for significant speedup,
            # especially if overlap regions are known or can be estimated.
            tiles2_warped_to_1 = torch.stack([
                warp_tile_homography(tiles2_batch[i].unsqueeze(0), H_12[i].unsqueeze(0), tile_shape)
                for i in range(tiles2_batch.shape[0])]).squeeze(1)
            loss_photo = photometric_loss(tiles2_warped_to_1, tiles1_batch)

            loss_total_batch = loss_bt + loss_geom + loss_photo
            loss_total_batch.backward()
            optimizer.step()
            total_loss_iter += loss_total_batch.item()

            if num_pairs_in_batch > 0 and (iter_idx + 1) % max(1, (num_train_iterations // 10)) == 0:
                print(f"Iter {iter_idx+1}/{num_train_iterations}, Loss: {total_loss_iter / num_pairs_in_batch:.4f}")

        print("Learning phase finished.")

        # 5. Graph-Based Global Alignment (after learning all pairwise)
        # Ensure all necessary pairs for the graph are computed if not covered by training batches
        with torch.no_grad():
            feature_encoder.eval()
            homography_net.eval()
            all_tile_features_final = feature_encoder(tiles_for_cnn[:Z])

            all_graph_pairs = get_adjacency_from_layout(layout_rows, layout_cols, Z, device)
            # Optionally add more pairs (e.g., random, next-nearest) for graph robustness
            # ... (logic for adding more pairs can be inserted here) ...

            for p1_g, p2_g in all_graph_pairs:
                # Only compute if not already in pairwise_H_matrices from training
                if (p1_g, p2_g) not in pairwise_H_matrices:
                    feat1_g = all_tile_features_final[p1_g].unsqueeze(0)
                    feat2_g = all_tile_features_final[p2_g].unsqueeze(0)
                    H_p1_p2 = homography_net(feat1_g, feat2_g).squeeze(0)  # p2 -> p1 frame
                    pairwise_H_matrices[(p1_g, p2_g)] = H_p1_p2
                if (p2_g, p1_g) not in pairwise_H_matrices:  # And the reverse
                    feat1_g = all_tile_features_final[p1_g].unsqueeze(0)
                    feat2_g = all_tile_features_final[p2_g].unsqueeze(0)
                    H_p2_p1 = homography_net(feat2_g, feat1_g).squeeze(0)  # p1 -> p2 frame
                    pairwise_H_matrices[(p2_g, p1_g)] = H_p2_p1

        # Initial global transforms for optimization (e.g., identity or grid-based estimate)
        initial_transforms_for_opt = [torch.eye(3, device=device) for _ in range(Z)]
        # A simple grid layout can be a better start than pure identity for all
        for i_init in range(Z):
            row_idx_init, col_idx_init = i_init // layout_cols, i_init % layout_cols
            dx_init = col_idx_init * W_tile * (1.0 - overlap_percent)
            dy_init = row_idx_init * H_tile * (1.0 - overlap_percent)
            translate_matrix_init = torch.eye(3, device=device)
            translate_matrix_init[0, 2] = dx_init
            translate_matrix_init[1, 2] = dy_init
            initial_transforms_for_opt[i_init] = translate_matrix_init

        global_transforms = optimize_pose_graph(pairwise_H_matrices, Z, device, initial_transforms_for_opt)

    else:  # learn=False
        print("Inference mode: Using placeholder global transforms (grid layout).")
        # TODO: Load pre-trained feature_encoder and homography_net
        # For now, use a simple grid layout for global transforms
        for i in range(Z):
            row_idx = i // layout_cols
            col_idx = i % layout_cols
            dx = col_idx * W_tile * (1.0 - overlap_percent)  # Assume some overlap
            dy = row_idx * H_tile * (1.0 - overlap_percent)

            translate_matrix = torch.eye(3, device=device)
            translate_matrix[0, 2] = dx
            translate_matrix[1, 2] = dy
            global_transforms[i] = translate_matrix

        # 5. Graph-Based Global Alignment
        # TODO: Implement a proper graph solver (e.g., spring model or least-squares on H_matrices).
        # This step would refine `global_transforms`. The current `global_transforms` are placeholders.
        print("Graph-Based Global Alignment: SKIPPED (using placeholder transforms).")

    # 6. Finalize Global Transforms and Extract Positions
    # The global_transforms list currently holds [3,3] homography matrices for each tile,
    # mapping its local coordinates to a common global frame.
    # We need to ensure the coordinate system's origin (0,0) is sensible,
    # e.g., the top-leftmost point of the stitched layout.

    # Transform corners of all tiles to the current global frame to find bounds
    all_corners_global_frame = []
    # Define corners of a tile in its local coordinate system (homogeneous):
    # (0,0), (W-1,0), (0,H-1), (W-1,H-1)
    tile_local_corners_homog = torch.tensor([
        [0, W_tile - 1, 0,          W_tile - 1],  # x-coordinates
        [0, 0,          H_tile - 1, H_tile - 1],  # y-coordinates
        [1, 1,          1,          1]            # w-coordinates
    ], dtype=torch.float32, device=device)  # Shape: [3, 4]

    for i in range(Z):
        H_global_i = global_transforms[i]  # Shape [3, 3]
        # Transform local corners to global frame: H_global @ local_corners
        corners_transformed_homog_i = H_global_i @ tile_local_corners_homog  # Shape [3, 4]

        # Perspective divide (x/w, y/w).
        # Avoid division by zero for w: if w is close to 0, it's problematic.
        # For affine transforms (like the translations used in the placeholder), w is always 1;
        # for full homographies, w can vary.
        w_coords_i = corners_transformed_homog_i[2, :]
        safe_w_coords_i = torch.where(torch.abs(w_coords_i) < 1e-6, torch.sign(w_coords_i) * 1e-6 + 1e-9, w_coords_i)

        corners_global_frame_i = corners_transformed_homog_i[:2, :] / safe_w_coords_i  # Shape [2, 4] (x, y)
        all_corners_global_frame.append(corners_global_frame_i)

    all_corners_stacked = torch.cat(all_corners_global_frame, dim=1)  # Shape [2, Z*4]

    # Find min x and min y to define the top-left of the bounding box of all tiles
    min_global_coords = torch.min(all_corners_stacked, dim=1).values  # Shape [2] (min_x, min_y)

    # Create an offset matrix to shift the entire layout so that min_global_coords becomes (0,0)
    offset_x_to_origin = -min_global_coords[0]
    offset_y_to_origin = -min_global_coords[1]

    normalization_offset_matrix = torch.eye(3, device=device)
    normalization_offset_matrix[0, 2] = offset_x_to_origin
    normalization_offset_matrix[1, 2] = offset_y_to_origin

    # Apply this normalization to all global transforms
    final_global_transforms_list = []
    for i in range(Z):
        final_H_i = normalization_offset_matrix @ global_transforms[i]
        final_global_transforms_list.append(final_H_i)

    # Extract (x, y) positions (top-left corner of each tile) from the final homographies.
    # The translation part of H (H[0,2], H[1,2]) gives the position of the tile's origin (0,0)
    # in the global (canvas) frame.
    tile_positions_xy = torch.stack(
        [H[0:2, 2] for H in final_global_transforms_list], dim=0
    )  # Shape [Z, 2]

    # Reshape to [1, Z, 2] as per output spec
    output_positions = tile_positions_xy.unsqueeze(0)

    # Canvas dimensions: extent of the normalized layout (all corners now start at (0, 0)),
    # so the docstring, return annotation, and empty-stack branch above stay consistent.
    max_global_coords = torch.max(all_corners_stacked, dim=1).values  # (max_x, max_y)
    canvas_W = int(torch.ceil(max_global_coords[0] - min_global_coords[0]).item()) + 1
    canvas_H = int(torch.ceil(max_global_coords[1] - min_global_coords[1]).item()) + 1
    canvas_dims = (canvas_H, canvas_W)

    if return_homographies:
        # Stack the list of [3,3] homography tensors into a single [Z, 3, 3] tensor
        output_homographies = torch.stack(final_global_transforms_list, dim=0)
        return output_positions, output_homographies, canvas_dims
    return output_positions, canvas_dims

if __name__ == '__main__':
    # Example Usage (for testing within this file)
    print("Running self_supervised_stitcher_func example...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create synthetic jittered tiles
    Z_tiles = 4
    tile_H, tile_W = 64, 64

    # Create a base image to extract tiles from
    base_img_H, base_img_W = tile_H * 2 + 20, tile_W * 2 + 20  # Larger base to allow jitter
    base_image = torch.zeros(base_img_H, base_img_W, device=device)
    # Add some features to the base image (e.g., a cross)
    base_image[base_img_H//2 - 10 : base_img_H//2 + 10, :] = 0.7
    base_image[:, base_img_W//2 - 10 : base_img_W//2 + 10] = 0.7
    base_image[base_img_H//4 : 3*base_img_H//4, base_img_W//4 : 3*base_img_W//4] += 0.3
    base_image = torch.clamp(base_image, 0, 1)

    synthetic_tiles_list = []
    overlap = 10  # pixels

    # Expected layout: 2x2
    # Tile 0: top-left, Tile 1: top-right, Tile 2: bottom-left, Tile 3: bottom-right
    # Define ideal top-left corners for each tile in the base image
    ideal_starts = [
        (5, 5),
        (5, 5 + tile_W - overlap),
        (5 + tile_H - overlap, 5),
        (5 + tile_H - overlap, 5 + tile_W - overlap)
    ]

    for i in range(Z_tiles):
        start_y, start_x = ideal_starts[i]

        # Add some random jitter to actual extraction
        jitter_y = torch.randint(-3, 4, (1,)).item()
        jitter_x = torch.randint(-3, 4, (1,)).item()

        current_start_y = start_y + jitter_y
        current_start_x = start_x + jitter_x

        tile = base_image[current_start_y : current_start_y + tile_H,
                          current_start_x : current_start_x + tile_W].clone()

        synthetic_tiles_list.append(tile)

    synthetic_tile_stack = torch.stack(synthetic_tiles_list).to(device)
    print(f"Synthetic tile stack shape: {synthetic_tile_stack.shape}")

    # Test with learn=True (will be slow and likely not converge well with few iterations)
    print("\nTesting with learn=True...")
    tile_positions_learn, homographies_learn, canvas_dims_learn = self_supervised_stitcher_func(
        synthetic_tile_stack.clone(),
        learn=True,
        num_train_iterations=10,  # Very few iterations for a quick test
        return_homographies=True,
        layout_shape_override=(2, 2)  # Explicit layout for test
    )
    print(f"Tile positions (learn=True) shape: {tile_positions_learn.shape}")
    print(f"Tile positions (learn=True, tile 0): {tile_positions_learn[0,0,:]}")
    print(f"Homographies (learn=True) shape: {homographies_learn.shape}")
    # print(f"Homography (learn=True, tile 0):\n{homographies_learn[0]}")

    # Test with learn=False (uses placeholder grid transforms)
    print("\nTesting with learn=False...")
    tile_positions_infer, canvas_dims_infer = self_supervised_stitcher_func(
        synthetic_tile_stack.clone(),
        learn=False,
        return_homographies=False,  # Test this path too
        layout_shape_override=(2, 2)
    )
    print(f"Tile positions (learn=False) shape: {tile_positions_infer.shape}")
    print(f"Tile positions (learn=False, tile 0): {tile_positions_infer[0,0,:]}")

    # Try to visualize the layout if possible
    try:
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        def plot_layout(ax, tile_stack_cpu, positions_cpu, title, H_tile, W_tile, homographies_cpu=None):
            ax.clear()
            ax.set_title(title)
            ax.set_aspect('equal', 'box')

            all_x = []
            all_y = []

            for i in range(positions_cpu.shape[0]):
                x, y = positions_cpu[i, 0], positions_cpu[i, 1]
                all_x.extend([x, x + W_tile])  # Approximate bounding box
                all_y.extend([y, y + H_tile])

                # Display the tile image at its position (approximate, as a homography might skew it).
                # For simplicity, just place the tile image's top-left at (x, y);
                # a more accurate plot would use the homography to warp the tile outline.
                ax.imshow(tile_stack_cpu[i], cmap='gray', alpha=0.7,
                          extent=(x, x + W_tile, y + H_tile, y))  # extent is (left, right, bottom, top)

                # Draw a rectangle border for the tile
                rect = Rectangle((x, y), W_tile, H_tile, linewidth=1, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
                ax.text(x, y, f"{i}", color='cyan', fontsize=8)

            if all_x and all_y:
                ax.set_xlim(min(all_x) - W_tile * 0.1, max(all_x) + W_tile * 0.1)
                ax.set_ylim(max(all_y) + H_tile * 0.1, min(all_y) - H_tile * 0.1)  # Flipped for imshow
            else:  # Default if no positions
                ax.set_xlim(0, W_tile * 2)
                ax.set_ylim(H_tile * 2, 0)

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        plot_layout(axes[0],
                    synthetic_tile_stack.cpu().numpy(),
                    tile_positions_learn.squeeze(0).cpu().detach().numpy(),
                    "Layout (learn=True, 10 iter)",
                    tile_H, tile_W,
                    homographies_learn.cpu().detach().numpy() if homographies_learn is not None else None)

        plot_layout(axes[1],
                    synthetic_tile_stack.cpu().numpy(),
                    tile_positions_infer.squeeze(0).cpu().detach().numpy(),
                    "Layout (learn=False, Grid)",
                    tile_H, tile_W)

        plt.tight_layout()
        plt.savefig("self_supervised_stitcher_layout_test_output.png")
        print("\nSaved test layout plot to self_supervised_stitcher_layout_test_output.png")

    except ImportError:
        print("\nMatplotlib not available. Skipping layout visualization.")
    except Exception as e:
        print(f"\nError during layout visualization: {e}")

    print("Self_supervised_stitcher example finished.")