from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models class HomographyCNN(nn.Module): """ Model for estimating homography matrix (3x3) between two images. """ def __init__( self, input_channels: int = 3, backbone_name: str = "resnet18", pretrained: bool = True, dropout_rate: float = 0.3, use_batch_norm: bool = True, ): super().__init__() self.input_channels = input_channels self.backbone_name = backbone_name self.pretrained = pretrained self.dropout_rate = dropout_rate self.use_batch_norm = use_batch_norm backbone = self._create_backbone(backbone_name, pretrained) self.feature_dim = backbone.fc.in_features backbone.fc = nn.Identity() self.backbone = backbone compare_input_dim = self.feature_dim * 4 layers = [ nn.Linear(compare_input_dim, 512), nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), nn.Linear(512, 256), nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), nn.Linear(256, 9), ] self.head = nn.Sequential(*layers) def _create_backbone(self, name: str, pretrained: bool) -> nn.Module: name = name.lower() if name == "resnet18": model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None) elif name == "resnet34": model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None) else: raise ValueError(f"Unsupported backbone: {name}") if self.input_channels != 3: old_conv = model.conv1 model.conv1 = nn.Conv2d( self.input_channels, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding, bias=old_conv.bias is not None, ) return model def _extract_features(self, x: torch.Tensor) -> torch.Tensor: return self.backbone(x) def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: f1 = self._extract_features(img1) f2 = self._extract_features(img2) diff = torch.abs(f1 - f2) prod = f1 * f2 combined = torch.cat([f1, f2, diff, prod], dim=1) h = self.head(combined) h = h.view(-1, 3, 3) return h def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: was_training = self.training self.eval() with torch.no_grad(): h = self.forward(img1, img2) if was_training: self.train() return h class HomographyLoss(nn.Module): def __init__(self): super().__init__() self.criterion = nn.MSELoss() def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor: return self.criterion(pred_homography, target_homography) def create_homography_model( model_type: str = "backbone", input_size: Tuple[int, int] = (256, 256), **kwargs, ) -> nn.Module: if model_type == "backbone": return HomographyCNN(**kwargs) else: raise ValueError(f"Unknown model type: {model_type}") if __name__ == "__main__": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") model = HomographyCNN( input_channels=3, backbone_name="resnet18", pretrained=True, dropout_rate=0.3, use_batch_norm=True, ).to(device) print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters") batch_size = 4 height, width = 256, 256 img1 = torch.randn(batch_size, 3, height, width).to(device) img2 = torch.randn(batch_size, 3, height, width).to(device) print("\nTesting forward pass...") output = model(img1, img2) print(f"Output shape: {output.shape}") print("\nTesting prediction...") pred = model.predict_homography(img1, img2) print(f"Prediction shape: {pred.shape}") print("\nTesting loss function...") target = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device) loss_fn = HomographyLoss().to(device) loss = loss_fn(output, target) print(f"Loss value: {loss.item():.6f}") print("\nAll tests completed successfully!")