import torch import torch.nn as nn from torchvision import models class HomographyCNN6(nn.Module): def __init__(self, input_channels=3, backbone_name="resnet18", pretrained=True, dropout_rate=0.3): super().__init__() backbone = getattr(models, backbone_name)(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None) self.feature_dim = backbone.fc.in_features backbone.fc = nn.Identity() self.backbone = backbone self.head = nn.Sequential( nn.Linear(self.feature_dim * 4, 512), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), nn.Linear(512, 256), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), nn.Linear(256, 6), ) def forward(self, img1, img2): f1 = self.backbone(img1) f2 = self.backbone(img2) combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1) return self.head(combined) class HomographyLoss6(nn.Module): def __init__(self): super().__init__() self.criterion = nn.MSELoss() def forward(self, pred, target): return self.criterion(pred, target) if __name__ == "__main__": model = HomographyCNN6() img1 = torch.randn(2, 3, 256, 256) img2 = torch.randn(2, 3, 256, 256) out = model(img1, img2) print(f"Output shape: {out.shape}, mean: {out.mean():.3f}")