Coverage for openhcs/processing/backends/assemblers/self_supervised_stitcher.py: 7.0%
267 statements
from __future__ import annotations

import math
from typing import Any, Dict, List, Optional, Tuple

from openhcs.utils.import_utils import optional_import, create_placeholder_class
from openhcs.core.memory.decorators import torch as torch_backend_func

# Import torch modules as optional dependencies
torch = optional_import("torch")
nn = optional_import("torch.nn") if torch is not None else None
F = optional_import("torch.nn.functional") if torch is not None else None
models = optional_import("torchvision.models") if optional_import("torchvision") is not None else None

nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch"
)

# --- Helper Modules and Functions (Placeholders or Simplified) ---

class FeatureEncoder(nnModule):
    def __init__(self, pretrained: bool = True):
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if pretrained else None)
        # Modify ResNet for grayscale input (1 channel)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Adapt first-layer weights from RGB to grayscale by simple
            # averaging across the three input channels.
            rgb_weights = resnet.conv1.weight.data
            self.conv1.weight.data = rgb_weights.mean(dim=1, keepdim=True)

        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        # self.layer4 = resnet.layer4  # Often excluded for smaller feature maps in stitching

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # x = self.layer4(x)
        return x

class HomographyPredictionNet(nnModule):
    def __init__(self, feature_dim: int):
        super().__init__()
        # Placeholder: a simple network predicting the 8 free parameters of a
        # homography (the ninth entry is fixed to 1). Input is the
        # concatenated pooled features of two tiles.
        self.fc = nn.Sequential(
            nn.Linear(feature_dim * 2, 512),  # Features are pooled and concatenated
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 8)  # Deltas from identity for the 8 homography parameters
        )
        # Initialize weights so the net outputs a near-identity transform
        # initially: zero biases and small weights throughout.
        for m in self.fc:
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)
                nn.init.xavier_uniform_(m.weight, gain=0.01)

    def forward(self, features1: torch.Tensor, features2: torch.Tensor) -> torch.Tensor:
        # features1, features2: [B, C, Hf, Wf]
        batch_size = features1.shape[0]
        # Global average pooling as a simple feature vector
        flat_features1 = features1.mean(dim=[2, 3])
        flat_features2 = features2.mean(dim=[2, 3])

        combined_features = torch.cat((flat_features1, flat_features2), dim=1)
        params_8 = self.fc(combined_features)  # [B, 8]

        # Construct a 3x3 homography matrix from the 8 parameters:
        # H = [[h00, h01, h02], [h10, h11, h12], [h20, h21, 1]]
        # params_8 = [h00-1, h01, h02, h10, h11-1, h12, h20, h21] (delta from identity)
        homography = torch.eye(3, device=params_8.device).repeat(batch_size, 1, 1)
        homography[:, 0, 0] += params_8[:, 0]
        homography[:, 0, 1] = params_8[:, 1]
        homography[:, 0, 2] = params_8[:, 2]
        homography[:, 1, 0] = params_8[:, 3]
        homography[:, 1, 1] += params_8[:, 4]
        homography[:, 1, 2] = params_8[:, 5]
        homography[:, 2, 0] = params_8[:, 6]
        homography[:, 2, 1] = params_8[:, 7]
        # homography[:, 2, 2] is already 1
        return homography
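
# Illustrative sanity check added as a hedged sketch (the helper name
# `_sanity_check_homography_net` is ours, not part of the original module):
# with the near-zero initialization above, a freshly constructed net should
# predict homographies close to identity for arbitrary feature inputs.
def _sanity_check_homography_net(feature_dim: int = 256) -> None:
    net = HomographyPredictionNet(feature_dim)
    f1 = torch.randn(2, feature_dim, 8, 8)
    f2 = torch.randn(2, feature_dim, 8, 8)
    H = net(f1, f2)  # [2, 3, 3]
    assert torch.allclose(H, torch.eye(3).expand(2, 3, 3), atol=1e-1)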

def barlow_twins_loss(z1: torch.Tensor, z2: torch.Tensor, lambda_coeff: float = 5e-3) -> torch.Tensor:
    # Placeholder for the Barlow Twins loss.
    # z1, z2: [B, FeatureDim]
    # Standardize each feature dimension across the batch.
    z1_norm = (z1 - z1.mean(dim=0)) / (z1.std(dim=0) + 1e-5)
    z2_norm = (z2 - z2.mean(dim=0)) / (z2.std(dim=0) + 1e-5)

    N, D = z1_norm.shape
    c = (z1_norm.T @ z2_norm) / N  # Cross-correlation matrix

    # Out-of-place ops: the previous in-place add_/pow_ mutated the diagonal
    # of c before the off-diagonal term was computed.
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = c.masked_fill(torch.eye(D, device=c.device, dtype=torch.bool), 0).pow(2).sum()
    loss = on_diag + lambda_coeff * off_diag
    return loss
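
# Hedged property check (our illustrative helper, not original API): when the
# two embeddings are identical and the batch is large, the cross-correlation
# matrix approaches identity and the loss approaches zero.
def _sanity_check_barlow_twins(batch: int = 2048, dim: int = 32) -> None:
    z = torch.randn(batch, dim)
    loss = barlow_twins_loss(z, z.clone())
    # Finite-batch noise keeps the loss slightly above zero.
    assert loss.item() < 0.05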

def geometry_consistency_loss(H_ab: torch.Tensor, H_ba: torch.Tensor) -> torch.Tensor:
    # H_ab: transform from b to a; H_ba: transform from a to b.
    # Their product H_ab @ H_ba should be close to the identity.
    identity = torch.eye(3, device=H_ab.device).unsqueeze(0).repeat(H_ab.shape[0], 1, 1)
    product = H_ab @ H_ba
    loss = F.mse_loss(product, identity)
    return loss


def photometric_loss(tile_warped: torch.Tensor, tile_target: torch.Tensor) -> torch.Tensor:
    return F.l1_loss(tile_warped, tile_target)

def warp_tile_homography(tile: torch.Tensor, H: torch.Tensor, output_shape: Tuple[int, int]) -> torch.Tensor:
    """Warps a tile using a homography.

    tile: [1, 1, H_in, W_in]
    H: [1, 3, 3] homography matrix
    output_shape: (H_out, W_out)
    """
    H_out, W_out = output_shape
    # Create grid for sampling
    grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, H_out, device=tile.device),
                                    torch.linspace(-1, 1, W_out, device=tile.device), indexing='ij')
    grid = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=-1)  # [H_out, W_out, 3]
    grid = grid.view(1, H_out * W_out, 3)  # [1, H_out*W_out, 3]

    # grid_sample needs sampling locations in the *input* tile's coordinates.
    # If H maps the source tile to the destination, then H_inv maps the output
    # grid (destination) back to the source, which is what we sample with.
    try:
        H_inv = torch.inverse(H)
    except RuntimeError:  # Singular matrix
        H_inv = H  # Fallback; a caller may prefer to raise instead.
        print("Warning: Singular homography matrix encountered during inverse.")

    # Transform grid points: [1, H_out*W_out, 3]
    grid_transformed = torch.bmm(grid, H_inv.transpose(1, 2))

    # Perspective divide (x' = x/w, y' = y/w), guarding against w close to zero.
    w_coords = grid_transformed[:, :, 2].unsqueeze(2)
    safe_w_coords = torch.where(torch.abs(w_coords) < 1e-6, torch.sign(w_coords) * 1e-6 + 1e-9, w_coords)

    grid_transformed_normalized = grid_transformed[:, :, :2] / safe_w_coords  # [1, H_out*W_out, 2]

    # Reshape for grid_sample: [B, H_out, W_out, 2] in (x, y) order
    sampling_grid = grid_transformed_normalized.view(1, H_out, W_out, 2)

    warped_tile = F.grid_sample(tile, sampling_grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    return warped_tile
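
# Illustrative check (a hedged sketch; `_sanity_check_identity_warp` is our
# addition): warping a constant tile with the identity homography should keep
# interior pixels unchanged. Border pixels can mix with the zero padding
# because grid_sample is called with align_corners=False.
def _sanity_check_identity_warp(h: int = 32, w: int = 32) -> None:
    tile = torch.full((1, 1, h, w), 0.5)
    warped = warp_tile_homography(tile, torch.eye(3).unsqueeze(0), (h, w))
    assert torch.allclose(warped[..., 2:-2, 2:-2], tile[..., 2:-2, 2:-2], atol=1e-4)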

def get_adjacency_from_layout(layout_rows: int, layout_cols: int, num_tiles: int, device: torch.device) -> List[Tuple[int, int]]:
    """
    Generates a list of adjacent tile index pairs based on a grid layout.
    Considers 4-connectivity (up, down, left, right).
    """
    adjacency_pairs = []
    for r in range(layout_rows):
        for c in range(layout_cols):
            current_idx = r * layout_cols + c
            if current_idx >= num_tiles:
                continue

            # Check right neighbor
            if c + 1 < layout_cols:
                right_idx = r * layout_cols + (c + 1)
                if right_idx < num_tiles:
                    adjacency_pairs.append(tuple(sorted((current_idx, right_idx))))

            # Check bottom neighbor
            if r + 1 < layout_rows:
                bottom_idx = (r + 1) * layout_cols + c
                if bottom_idx < num_tiles:
                    adjacency_pairs.append(tuple(sorted((current_idx, bottom_idx))))

    # Deduplicate (tuples are sorted, so (a, b) and (b, a) collapse) and
    # return in a deterministic order.
    return sorted(set(adjacency_pairs))
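
# Illustrative usage (hedged example): on a 2x3 layout holding only 5 tiles,
# pairs that would touch the missing sixth slot are skipped:
#
#   >>> get_adjacency_from_layout(2, 3, 5, torch.device("cpu"))
#   [(0, 1), (0, 3), (1, 2), (1, 4), (3, 4)]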

def optimize_pose_graph(
    pairwise_homographies: Dict[Tuple[int, int], torch.Tensor],
    num_tiles: int,
    device: torch.device,
    initial_global_transforms: Optional[List[torch.Tensor]] = None
) -> List[torch.Tensor]:
    """
    Placeholder for pose graph optimization.
    Takes pairwise homographies and refines global tile transforms.
    Returns a list of [3, 3] global homography matrices, one per tile.
    """
    print("Pose Graph Optimization: SKIPPED (using initial/placeholder transforms).")
    # TODO: Implement actual pose graph optimization (e.g., least squares on
    # log-transforms, or a spring model). For now, use the provided initial
    # transforms if they match num_tiles; otherwise fall back to identity.
    if initial_global_transforms and len(initial_global_transforms) == num_tiles:
        return initial_global_transforms

    # Fallback: identity transforms, with the first tile acting as the anchor.
    global_transforms = [torch.eye(3, device=device) for _ in range(num_tiles)]

    # Simplistic chaining over pairwise_homographies would look like this
    # (illustrative, not robust; replace with a real graph solver). With
    # H_{i-1,i} mapping tile i into tile i-1's frame:
    #   for i in range(1, num_tiles):
    #       if (i - 1, i) in pairwise_homographies:
    #           global_transforms[i] = global_transforms[i - 1] @ pairwise_homographies[(i - 1, i)]
    return global_transforms
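
# Hedged sketch of the TODO above, not the module's implementation: a
# translation-only least-squares pose graph solver. It treats each pairwise
# homography H_ij (mapping tile j into tile i's frame) as a pure translation
# t_ij = H_ij[:2, 2] and solves for global positions p_i minimizing
# sum ||p_j - p_i - t_ij||^2, with tile 0 anchored at the origin. The helper
# name `_translation_pose_graph_sketch` is ours.
def _translation_pose_graph_sketch(
    pairwise_homographies: Dict[Tuple[int, int], torch.Tensor],
    num_tiles: int,
    device: torch.device,
) -> List[torch.Tensor]:
    if num_tiles == 0:
        return []
    edges = list(pairwise_homographies.items())
    # One anchor row (p_0 = 0) plus one row per pairwise constraint.
    A = torch.zeros(len(edges) + 1, num_tiles, device=device)
    b = torch.zeros(len(edges) + 1, 2, device=device)
    A[0, 0] = 1.0  # anchor tile 0 at (0, 0)
    for row, ((i, j), H_ij) in enumerate(edges, start=1):
        # Tile j's origin lands at t_ij in tile i's frame: p_j - p_i ≈ t_ij.
        A[row, j] = 1.0
        A[row, i] = -1.0
        b[row] = H_ij[:2, 2]
    positions = torch.linalg.lstsq(A, b).solution  # [num_tiles, 2]
    transforms = []
    for i in range(num_tiles):
        T = torch.eye(3, device=device)
        T[0, 2], T[1, 2] = positions[i, 0], positions[i, 1]
        transforms.append(T)
    return transforms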

# --- Main Stitcher Function ---

@torch_backend_func
def self_supervised_stitcher_func(
    tile_stack: torch.Tensor,  # shape: [Z, Y, X]
    *,
    tile_shape_override: Optional[Tuple[int, int]] = None,  # (tile_height, tile_width)
    layout_shape_override: Optional[Tuple[int, int]] = None,  # (rows, cols)
    learn: bool = False,
    num_train_iterations: int = 100,  # Only used if learn=True
    overlap_percent: float = 0.1,  # For global transform normalization
    return_homographies: bool = False,
    # For pre-trained model paths (not yet loaded; see TODO in the learn=False branch)
    encoder_path: Optional[str] = None,
    homography_net_path: Optional[str] = None
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
    """
    Self-supervised image stitching module.

    Learns relative alignment of tiles using unsupervised geometry matching,
    infers pairwise transformations, and composes them into global tile
    offsets (x, y).

    Returns:
        - position tensor [1, Z, 2]
        - optionally also the global homographies [Z, 3, 3] when
          return_homographies=True
    """
    device = tile_stack.device
    Z, Y, X = tile_stack.shape

    if Z == 0:
        empty_positions = torch.empty((1, 0, 2), device=device, dtype=torch.float32)
        if return_homographies:
            empty_homographies = torch.empty((0, 3, 3), device=device, dtype=torch.float32)
            return empty_positions, empty_homographies
        return empty_positions

    # 1. Input Description & Defaults
    tile_shape: Tuple[int, int] = tile_shape_override if tile_shape_override else (Y, X)
    H_tile, W_tile = tile_shape

    if layout_shape_override:
        layout_rows, layout_cols = layout_shape_override
        if layout_rows * layout_cols < Z:
            raise ValueError(f"Provided layout_shape {layout_shape_override} is too small for {Z} tiles.")
    else:
        # Infer a layout that is as close to square as possible.
        layout_rows = math.ceil(math.sqrt(Z))
        layout_cols = math.ceil(Z / layout_rows)
        # Ensure the grid has at least Z slots.
        while layout_rows * layout_cols < Z:
            layout_cols += 1  # Widen rather than deepen; adjust for a different aspect-ratio preference.

    layout_shape: Tuple[int, int] = (layout_rows, layout_cols)

    print(f"Using tile_shape: {tile_shape}, layout_shape: {layout_shape} for {Z} tiles.")

    # 2. Tile Reshaping
    # Reshape [Z, Y, X] -> [Z, 1, H_tile, W_tile] for the CNN.
    tiles_for_cnn = tile_stack.unsqueeze(1).float()  # Add channel dim

    # Create a 2D grid representation of tiles for adjacency.
    # Pad tile_stack if Z < layout_rows * layout_cols.
    num_layout_slots = layout_rows * layout_cols
    if Z < num_layout_slots:
        padding_count = num_layout_slots - Z
        padding_tensor = torch.zeros(padding_count, Y, X, device=device, dtype=tile_stack.dtype)
        padded_tile_stack = torch.cat((tile_stack, padding_tensor), dim=0)
    else:
        padded_tile_stack = tile_stack

    tile_grid = padded_tile_stack.view(layout_rows, layout_cols, Y, X)

    # 3. Feature Encoder
    feature_encoder = FeatureEncoder().to(device)
    if learn:
        feature_encoder.train()
    else:
        feature_encoder.eval()

    # 4. Unsupervised Alignment (AltO-inspired)
    # This is a highly complex step; the logic below is placeholder-level.

    # Pairwise homographies, keyed (i, j) for the transform mapping tile j
    # into tile i's frame.
    pairwise_H_matrices: Dict[Tuple[int, int], torch.Tensor] = {}
    # global_transforms is populated by pose graph optimization or a fallback.
    global_transforms: List[torch.Tensor] = [torch.eye(3, device=device) for _ in range(Z)]

    if learn:
        print("Starting learning phase for pairwise alignments...")
        dummy_features = feature_encoder(tiles_for_cnn[:1])
        feature_dim_encoder = dummy_features.shape[1]
        homography_net = HomographyPredictionNet(feature_dim_encoder).to(device)
        optimizer = torch.optim.Adam(list(feature_encoder.parameters()) + list(homography_net.parameters()), lr=1e-4)

        # Get adjacent pairs based on the layout.
        adjacent_tile_pairs = get_adjacency_from_layout(layout_rows, layout_cols, Z, device)

        if Z > 1 and not adjacent_tile_pairs:
            print("Warning: No adjacent pairs from layout for training; using sequential pairs as fallback.")
            for i_pair_idx in range(Z - 1):  # Create a simple chain of pairs
                adjacent_tile_pairs.append(tuple(sorted((i_pair_idx, i_pair_idx + 1))))
            adjacent_tile_pairs = sorted(set(adjacent_tile_pairs))

        for iter_idx in range(num_train_iterations):
            optimizer.zero_grad()
            total_loss_iter = torch.tensor(0.0, device=device)

            if not adjacent_tile_pairs or Z < 2:
                print("Not enough tiles or pairs for training iteration.")
                break

            # Sample a small batch of adjacent pairs per iteration (all pairs if few).
            # randint keeps this memory-cheap even though the pair list is typically small.
            num_pairs_in_batch = min(len(adjacent_tile_pairs), 8)
            current_batch_pairs_indices = torch.randint(0, len(adjacent_tile_pairs), (num_pairs_in_batch,))

            batch_idx1_list = []
            batch_idx2_list = []
            for perm_idx in current_batch_pairs_indices:
                p1, p2 = adjacent_tile_pairs[perm_idx.item()]
                batch_idx1_list.append(p1)
                batch_idx2_list.append(p2)

            idx1 = torch.tensor(batch_idx1_list, device=device, dtype=torch.long)
            idx2 = torch.tensor(batch_idx2_list, device=device, dtype=torch.long)

            tiles1_batch = tiles_for_cnn[idx1]
            tiles2_batch = tiles_for_cnn[idx2]

            features1 = feature_encoder(tiles1_batch)
            features2 = feature_encoder(tiles2_batch)

            H_12 = homography_net(features1, features2)  # maps tile2 -> tile1
            H_21 = homography_net(features2, features1)  # maps tile1 -> tile2

            # Store these for later graph optimization.
            for i_pair in range(idx1.shape[0]):
                p_idx1, p_idx2 = idx1[i_pair].item(), idx2[i_pair].item()
                pairwise_H_matrices[(p_idx1, p_idx2)] = H_12[i_pair].detach()  # maps p_idx2 into p_idx1's frame
                pairwise_H_matrices[(p_idx2, p_idx1)] = H_21[i_pair].detach()  # maps p_idx1 into p_idx2's frame

            loss_bt = barlow_twins_loss(features1.mean(dim=[2, 3]), features2.mean(dim=[2, 3]))
            loss_geom = geometry_consistency_loss(H_12, H_21)

            # TODO: For the photometric loss, consider replacing grid_sample in
            # warp_tile_homography with GPU-efficient tensor tiling (block
            # aggregation) for a significant speedup, especially if overlap
            # regions are known or can be estimated.
            tiles2_warped_to_1 = torch.stack([
                warp_tile_homography(tiles2_batch[i].unsqueeze(0), H_12[i].unsqueeze(0), tile_shape)
                for i in range(tiles2_batch.shape[0])]).squeeze(1)
            loss_photo = photometric_loss(tiles2_warped_to_1, tiles1_batch)

            loss_total_batch = loss_bt + loss_geom + loss_photo
            loss_total_batch.backward()
            optimizer.step()
            total_loss_iter += loss_total_batch.item()

            if (iter_idx + 1) % max(1, num_train_iterations // 10) == 0:
                print(f"Iter {iter_idx+1}/{num_train_iterations}, Loss: {total_loss_iter / num_pairs_in_batch:.4f}")

        print("Learning phase finished.")

        # 5. Graph-Based Global Alignment (after learning all pairwise transforms)
        # Ensure all pairs needed by the graph are computed, even if the
        # training batches never sampled them.
        with torch.no_grad():
            feature_encoder.eval()
            homography_net.eval()
            all_tile_features_final = feature_encoder(tiles_for_cnn[:Z])

            all_graph_pairs = get_adjacency_from_layout(layout_rows, layout_cols, Z, device)
            # Optionally add more pairs (e.g., random, next-nearest) for graph robustness.

            for p1_g, p2_g in all_graph_pairs:
                # Only compute transforms not already cached during training.
                if (p1_g, p2_g) not in pairwise_H_matrices:
                    feat1_g = all_tile_features_final[p1_g].unsqueeze(0)
                    feat2_g = all_tile_features_final[p2_g].unsqueeze(0)
                    H_p1_p2 = homography_net(feat1_g, feat2_g).squeeze(0)  # maps p2 into p1's frame
                    pairwise_H_matrices[(p1_g, p2_g)] = H_p1_p2
                if (p2_g, p1_g) not in pairwise_H_matrices:  # and the reverse direction
                    feat1_g = all_tile_features_final[p1_g].unsqueeze(0)
                    feat2_g = all_tile_features_final[p2_g].unsqueeze(0)
                    H_p2_p1 = homography_net(feat2_g, feat1_g).squeeze(0)  # maps p1 into p2's frame
                    pairwise_H_matrices[(p2_g, p1_g)] = H_p2_p1

        # Initial global transforms for the optimizer: a grid-layout estimate
        # is a better starting point than pure identity for every tile.
        initial_transforms_for_opt = [torch.eye(3, device=device) for _ in range(Z)]
        for i_init in range(Z):
            row_idx_init, col_idx_init = i_init // layout_cols, i_init % layout_cols
            dx_init = col_idx_init * W_tile * (1.0 - overlap_percent)
            dy_init = row_idx_init * H_tile * (1.0 - overlap_percent)
            translate_matrix_init = torch.eye(3, device=device)
            translate_matrix_init[0, 2] = dx_init
            translate_matrix_init[1, 2] = dy_init
            initial_transforms_for_opt[i_init] = translate_matrix_init

        global_transforms = optimize_pose_graph(pairwise_H_matrices, Z, device, initial_transforms_for_opt)

    else:  # learn=False
        print("Inference mode: Using placeholder global transforms (grid layout).")
        # TODO: Load a pre-trained feature_encoder and homography_net from
        # encoder_path / homography_net_path.
        # For now, use a simple grid layout for the global transforms.
        for i in range(Z):
            row_idx = i // layout_cols
            col_idx = i % layout_cols
            dx = col_idx * W_tile * (1.0 - overlap_percent)  # Assume some overlap
            dy = row_idx * H_tile * (1.0 - overlap_percent)

            translate_matrix = torch.eye(3, device=device)
            translate_matrix[0, 2] = dx
            translate_matrix[1, 2] = dy
            global_transforms[i] = translate_matrix

        # Graph-based global alignment would refine these transforms; a proper
        # graph solver (e.g., spring model or least squares on the pairwise
        # homographies) is still a TODO here.
        print("Graph-Based Global Alignment: SKIPPED (using placeholder transforms).")

    # 6. Finalize Global Transforms and Extract Positions
    # global_transforms holds one [3, 3] homography per tile, mapping its
    # local coordinates into a common global frame. Shift the frame so its
    # origin (0, 0) sits at the top-leftmost point of the stitched layout.

    # Transform the corners of all tiles into the current global frame to find the bounds.
    all_corners_global_frame = []
    # Corners of a tile in its local coordinate system (homogeneous):
    # (0,0), (W-1,0), (0,H-1), (W-1,H-1)
    tile_local_corners_homog = torch.tensor([
        [0, W_tile - 1, 0,          W_tile - 1],  # x-coordinates
        [0, 0,          H_tile - 1, H_tile - 1],  # y-coordinates
        [1, 1,          1,          1]            # w-coordinates
    ], dtype=torch.float32, device=device)  # Shape: [3, 4]

    for i in range(Z):
        H_global_i = global_transforms[i]  # Shape [3, 3]
        # Transform local corners to the global frame: H_global @ local_corners
        corners_transformed_homog_i = H_global_i @ tile_local_corners_homog  # Shape [3, 4]

        # Perspective divide (x/w, y/w). For affine transforms (like the
        # placeholder translations) w is always 1; for full homographies it
        # can vary, so guard against w close to zero.
        w_coords_i = corners_transformed_homog_i[2, :]
        safe_w_coords_i = torch.where(torch.abs(w_coords_i) < 1e-6, torch.sign(w_coords_i) * 1e-6 + 1e-9, w_coords_i)

        corners_global_frame_i = corners_transformed_homog_i[:2, :] / safe_w_coords_i  # Shape [2, 4] (x, y)
        all_corners_global_frame.append(corners_global_frame_i)

    all_corners_stacked = torch.cat(all_corners_global_frame, dim=1)  # Shape [2, Z*4]

    # Min x and min y define the top-left of the bounding box of all tiles.
    min_global_coords = torch.min(all_corners_stacked, dim=1).values  # Shape [2] (min_x, min_y)

    # Offset matrix that shifts the layout so min_global_coords becomes (0, 0).
    normalization_offset_matrix = torch.eye(3, device=device)
    normalization_offset_matrix[0, 2] = -min_global_coords[0]
    normalization_offset_matrix[1, 2] = -min_global_coords[1]

    # Apply the normalization to all global transforms.
    final_global_transforms_list = [
        normalization_offset_matrix @ global_transforms[i] for i in range(Z)
    ]

    # Extract the (x, y) position (top-left corner) of each tile from the
    # final homographies: the translation part (H[0, 2], H[1, 2]) gives the
    # position of the tile's origin (0, 0) in the global (canvas) frame.
    tile_positions_xy = torch.stack(
        [H[0:2, 2] for H in final_global_transforms_list], dim=0
    )  # Shape [Z, 2]

    # Reshape to [1, Z, 2] per the output spec.
    output_positions = tile_positions_xy.unsqueeze(0)

    if return_homographies:
        # Stack the list of [3, 3] homographies into a single [Z, 3, 3] tensor.
        output_homographies = torch.stack(final_global_transforms_list, dim=0)
        return output_positions, output_homographies
    return output_positions


if __name__ == '__main__':
    # Example usage (for testing within this file)
    print("Running self_supervised_stitcher_func example...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create synthetic jittered tiles
    Z_tiles = 4
    tile_H, tile_W = 64, 64

    # Create a base image to extract tiles from (larger than the tiled area to allow jitter)
    base_img_H, base_img_W = tile_H * 2 + 20, tile_W * 2 + 20
    base_image = torch.zeros(base_img_H, base_img_W, device=device)
    # Add some features to the base image (a cross and a brighter center square)
    base_image[base_img_H // 2 - 10 : base_img_H // 2 + 10, :] = 0.7
    base_image[:, base_img_W // 2 - 10 : base_img_W // 2 + 10] = 0.7
    base_image[base_img_H // 4 : 3 * base_img_H // 4, base_img_W // 4 : 3 * base_img_W // 4] += 0.3
    base_image = torch.clamp(base_image, 0, 1)

    synthetic_tiles_list = []
    overlap = 10  # pixels

    # Expected layout: 2x2
    # Tile 0: top-left, Tile 1: top-right, Tile 2: bottom-left, Tile 3: bottom-right
    # Ideal top-left corners for each tile in the base image:
    ideal_starts = [
        (5, 5),
        (5, 5 + tile_W - overlap),
        (5 + tile_H - overlap, 5),
        (5 + tile_H - overlap, 5 + tile_W - overlap)
    ]

    for i in range(Z_tiles):
        start_y, start_x = ideal_starts[i]

        # Add some random jitter to the actual extraction
        jitter_y = torch.randint(-3, 4, (1,)).item()
        jitter_x = torch.randint(-3, 4, (1,)).item()

        current_start_y = start_y + jitter_y
        current_start_x = start_x + jitter_x

        tile = base_image[current_start_y : current_start_y + tile_H,
                          current_start_x : current_start_x + tile_W].clone()

        synthetic_tiles_list.append(tile)

    synthetic_tile_stack = torch.stack(synthetic_tiles_list).to(device)
    print(f"Synthetic tile stack shape: {synthetic_tile_stack.shape}")

    # Test with learn=True (slow, and unlikely to converge well with so few iterations)
    print("\nTesting with learn=True...")
    tile_positions_learn, homographies_learn = self_supervised_stitcher_func(
        synthetic_tile_stack.clone(),
        learn=True,
        num_train_iterations=10,  # Very few iterations for a quick test
        return_homographies=True,
        layout_shape_override=(2, 2)  # Explicit layout for the test
    )
    print(f"Tile positions (learn=True) shape: {tile_positions_learn.shape}")
    print(f"Tile positions (learn=True, tile 0): {tile_positions_learn[0, 0, :]}")
    print(f"Homographies (learn=True) shape: {homographies_learn.shape}")
    # print(f"Homography (learn=True, tile 0):\n{homographies_learn[0]}")

    # Test with learn=False (uses placeholder grid transforms)
    print("\nTesting with learn=False...")
    tile_positions_infer = self_supervised_stitcher_func(
        synthetic_tile_stack.clone(),
        learn=False,
        return_homographies=False,  # Test this path too
        layout_shape_override=(2, 2)
    )
    print(f"Tile positions (learn=False) shape: {tile_positions_infer.shape}")
    print(f"Tile positions (learn=False, tile 0): {tile_positions_infer[0, 0, :]}")

    # Try to visualize the layout if possible
    try:
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        def plot_layout(ax, tile_stack_cpu, positions_cpu, title, H_tile, W_tile, homographies_cpu=None):
            ax.clear()
            ax.set_title(title)
            ax.set_aspect('equal', 'box')

            all_x = []
            all_y = []

            for i in range(positions_cpu.shape[0]):
                x, y = positions_cpu[i, 0], positions_cpu[i, 1]
                all_x.extend([x, x + W_tile])  # Approximate bounding box
                all_y.extend([y, y + H_tile])

                # Display the tile at its position. This is approximate: the
                # tile's top-left is simply placed at (x, y); an accurate plot
                # would warp the tile outline with its homography.
                ax.imshow(tile_stack_cpu[i], cmap='gray', alpha=0.7,
                          extent=(x, x + W_tile, y + H_tile, y))  # extent is (left, right, bottom, top)

                # Draw a rectangle border for the tile
                rect = Rectangle((x, y), W_tile, H_tile, linewidth=1, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
                ax.text(x, y, f"{i}", color='cyan', fontsize=8)

            if all_x and all_y:
                ax.set_xlim(min(all_x) - W_tile * 0.1, max(all_x) + W_tile * 0.1)
                ax.set_ylim(max(all_y) + H_tile * 0.1, min(all_y) - H_tile * 0.1)  # Flipped for imshow
            else:  # Default limits if there are no positions
                ax.set_xlim(0, W_tile * 2)
                ax.set_ylim(H_tile * 2, 0)

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        plot_layout(axes[0],
                    synthetic_tile_stack.cpu().numpy(),
                    tile_positions_learn.squeeze(0).cpu().detach().numpy(),
                    "Layout (learn=True, 10 iter)",
                    tile_H, tile_W,
                    homographies_learn.cpu().detach().numpy() if homographies_learn is not None else None)

        plot_layout(axes[1],
                    synthetic_tile_stack.cpu().numpy(),
                    tile_positions_infer.squeeze(0).cpu().detach().numpy(),
                    "Layout (learn=False, Grid)",
                    tile_H, tile_W)

        plt.tight_layout()
        plt.savefig("self_supervised_stitcher_layout_test_output.png")
        print("\nSaved test layout plot to self_supervised_stitcher_layout_test_output.png")

    except ImportError:
        print("\nMatplotlib not available. Skipping layout visualization.")
    except Exception as e:
        print(f"\nError during layout visualization: {e}")

    print("self_supervised_stitcher example finished.")