# Source view export (Files)
# Path: autopilot/models/SiaN/model.py
# Timestamp: 2026-04-04 20:26:56 +03:00
#
# Size: 46 lines, 1.4 KiB
# Language: Python

import torch
import torch.nn as nn
from torchvision import models
class HomographyCNN6(nn.Module):
    """Siamese CNN that regresses 6 values from a pair of images.

    Both images go through one shared torchvision backbone whose ``fc`` layer
    is replaced by an identity, giving one feature vector per image. The two
    vectors are fused as ``[f1, f2, |f1 - f2|, f1 * f2]`` and fed to an MLP
    head with 6 outputs.

    NOTE(review): judging by the class name, the 6 outputs are presumably the
    parameters of an affine (6-DoF) transform — confirm against the training
    code / dataset.

    Args:
        input_channels: Channels of each input image. When it is not 3, the
            backbone's first convolution (``conv1``) is replaced so the
            network accepts that many channels.
        backbone_name: Name of a ``torchvision.models`` builder whose model
            exposes an ``fc`` layer (e.g. ``"resnet18"``, ``"resnet50"``).
        pretrained: If True, load the builder's default pretrained weights.
        dropout_rate: Dropout probability used between the head's layers.

    Raises:
        ValueError: If the backbone has no ``fc`` layer, or if
            ``input_channels != 3`` and the backbone has no ``nn.Conv2d``
            named ``conv1`` to adapt.
    """

    def __init__(self, input_channels=3, backbone_name="resnet18", pretrained=True, dropout_rate=0.3):
        super().__init__()
        builder = getattr(models, backbone_name)
        # "DEFAULT" lets each builder resolve its own recommended weights enum
        # (for resnet18 this is IMAGENET1K_V1, the same as before). Passing
        # ResNet18_Weights to any other builder would be rejected.
        backbone = builder(weights="DEFAULT" if pretrained else None)
        if not hasattr(backbone, "fc"):
            raise ValueError(
                f"backbone {backbone_name!r} has no 'fc' layer; "
                "HomographyCNN6 expects a ResNet-style backbone"
            )
        if input_channels != 3:
            backbone.conv1 = self._adapt_first_conv(backbone, input_channels, backbone_name)
        self.feature_dim = backbone.fc.in_features
        # Strip the classifier so the backbone returns pooled features.
        backbone.fc = nn.Identity()
        self.backbone = backbone
        # Input width is 4 * feature_dim because forward() concatenates
        # four feature-sized tensors.
        self.head = nn.Sequential(
            nn.Linear(self.feature_dim * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 6),
        )

    @staticmethod
    def _adapt_first_conv(backbone, input_channels, backbone_name):
        """Return a replacement for ``backbone.conv1`` taking ``input_channels`` inputs."""
        old = getattr(backbone, "conv1", None)
        if not isinstance(old, nn.Conv2d):
            raise ValueError(
                f"cannot adapt backbone {backbone_name!r} to {input_channels} "
                "input channels: no nn.Conv2d named 'conv1'"
            )
        new = nn.Conv2d(
            input_channels,
            old.out_channels,
            kernel_size=old.kernel_size,
            stride=old.stride,
            padding=old.padding,
            dilation=old.dilation,
            bias=old.bias is not None,
        )
        with torch.no_grad():
            # Reuse the existing (possibly pretrained) filters: average over
            # the original input channels, then spread that average across the
            # new channels. The sum over input channels stays the same, so
            # the response to a uniform input is preserved.
            mean_w = old.weight.mean(dim=1, keepdim=True)
            scale = old.weight.shape[1] / input_channels
            new.weight.copy_(mean_w.repeat(1, input_channels, 1, 1) * scale)
            if old.bias is not None:
                new.bias.copy_(old.bias)
        return new

    def forward(self, img1, img2):
        """Predict 6 values for each image pair.

        Args:
            img1: First image batch; presumably (B, input_channels, H, W) —
                the accepted spatial size depends on the backbone.
            img2: Second image batch, with the same batch size as ``img1``.

        Returns:
            Tensor of shape (B, 6), assuming the backbone returns
            (B, feature_dim) features once ``fc`` is an identity.
        """
        # The same backbone (shared weights) encodes both images.
        f1 = self.backbone(img1)
        f2 = self.backbone(img2)
        # The raw features keep per-image content. The absolute difference
        # and the elementwise product are symmetric interaction terms that
        # describe how the two images relate.
        combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
        return self.head(combined)
class HomographyLoss6(nn.Module):
    """Mean-squared-error loss between predicted and target 6-value vectors.

    A thin wrapper around ``nn.MSELoss`` with its default ``"mean"``
    reduction, so the loss is a scalar averaged over every element.
    """

    def __init__(self):
        super(HomographyLoss6, self).__init__()
        # Default reduction="mean": averages the squared error over all elements.
        self.criterion = nn.MSELoss()

    def forward(self, pred, target):
        """Return the MSE between ``pred`` and ``target``.

        ``pred`` and ``target`` should have matching shapes.
        """
        loss = self.criterion(pred, target)
        return loss
if __name__ == "__main__":
    # Smoke test: one forward pass on random inputs, then report the output.
    # NOTE: pretrained=True (the default) loads ImageNet weights, which
    # torchvision may need to download on first use.
    model = HomographyCNN6()
    # eval() turns off dropout (and makes BatchNorm use running statistics)
    # so the pass reflects inference behaviour. no_grad() skips building an
    # autograd graph that is never used.
    model.eval()
    img1 = torch.randn(2, 3, 256, 256)
    img2 = torch.randn(2, 3, 256, 256)
    with torch.no_grad():
        out = model(img1, img2)
    print(f"Output shape: {out.shape}, mean: {out.mean():.3f}")