46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torchvision import models
|
|
|
|
|
|
class HomographyCNN6(nn.Module):
    """Siamese CNN that regresses 6 values from a pair of images.

    Both images are passed through one shared torchvision backbone whose
    ``fc`` layer is replaced by ``nn.Identity``, so each image yields a
    ``feature_dim``-sized embedding (``feature_dim`` is the original
    ``fc.in_features``). The two embeddings are fused as
    ``[f1, f2, |f1 - f2|, f1 * f2]`` and fed to an MLP head with 6 outputs.

    NOTE(review): the class name and the 6 outputs suggest a 6-parameter
    (affine-style) warp parameterisation — confirm against the target
    encoding used in training.

    Args:
        input_channels: channels of each input image. When it is not 3, the
            backbone's first convolution ``conv1`` is replaced so it accepts
            this many channels (requires a backbone whose ``conv1`` is a plain
            ``nn.Conv2d``, as in the ResNet family).
        backbone_name: name of a torchvision model constructor exposing an
            ``fc`` classifier layer (e.g. ``"resnet18"``, ``"resnet50"``).
        pretrained: if True, load the architecture's default pretrained
            weights; otherwise initialise the backbone randomly.
        dropout_rate: dropout probability used between the head's layers.

    Raises:
        ValueError: if ``input_channels != 3`` and the backbone has no
            adaptable ``conv1``.
    """

    def __init__(self, input_channels=3, backbone_name="resnet18", pretrained=True, dropout_rate=0.3):
        super().__init__()
        # "DEFAULT" is resolved by torchvision against the weights enum of the
        # selected architecture. The previous hard-coded ResNet18_Weights made
        # every other backbone fail; for resnet18, DEFAULT is IMAGENET1K_V1.
        backbone = getattr(models, backbone_name)(weights="DEFAULT" if pretrained else None)

        if input_channels != 3:
            first_conv = getattr(backbone, "conv1", None)
            if not isinstance(first_conv, nn.Conv2d) or first_conv.groups != 1:
                raise ValueError(
                    f"input_channels={input_channels} requires a backbone with a plain "
                    f"nn.Conv2d 'conv1'; '{backbone_name}' does not provide one"
                )
            backbone.conv1 = self._adapt_first_conv(first_conv, input_channels, pretrained)

        self.feature_dim = backbone.fc.in_features
        # Drop the classifier so the backbone returns its pooled feature vector.
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # Input width is 4 * feature_dim: the four fused views built in forward().
        self.head = nn.Sequential(
            nn.Linear(self.feature_dim * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 6),
        )

    @staticmethod
    def _adapt_first_conv(conv, in_channels, pretrained):
        """Return a copy of ``conv`` that accepts ``in_channels`` input channels.

        When ``pretrained`` is True, the original filters are averaged over
        their input channels and spread evenly across the new channels. This
        keeps the response to a constant-intensity input unchanged. The bias
        (if any) is copied.
        """
        new_conv = nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
        )
        if pretrained:
            with torch.no_grad():
                mean_weight = conv.weight.mean(dim=1, keepdim=True)
                scale = conv.in_channels / in_channels
                new_conv.weight.copy_(mean_weight.repeat(1, in_channels, 1, 1) * scale)
                if conv.bias is not None:
                    new_conv.bias.copy_(conv.bias)
        return new_conv

    def forward(self, img1, img2):
        """Predict 6 values for the image pair ``(img1, img2)``.

        Both inputs go through the same backbone (shared weights). The inputs
        are presumably (batch, input_channels, H, W) image tensors — confirm
        with the data pipeline.

        Returns:
            Tensor of shape (batch, 6) from the regression head.
        """
        f1 = self.backbone(img1)
        f2 = self.backbone(img2)
        # Symmetric difference and product features expose how the two
        # embeddings relate, in addition to the raw embeddings themselves.
        combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)
        return self.head(combined)
|
|
|
|
|
|
class HomographyLoss6(nn.Module):
    """Mean-squared-error loss between predicted and target 6-value vectors.

    A thin module wrapper around ``nn.MSELoss`` constructed with default
    settings, so the result is a single scalar tensor averaged over all
    elements.
    """

    def __init__(self):
        super(HomographyLoss6, self).__init__()
        mse = nn.MSELoss()
        self.criterion = mse

    def forward(self, pred, target):
        """Return the MSE between ``pred`` and ``target``.

        Args:
            pred: model output; must be shape-compatible with ``target``.
            target: ground-truth values.
        """
        loss = self.criterion(pred, target)
        return loss
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model and check the output shape on random input.
    # NOTE: pretrained=True (the default) may fetch weights via torchvision.
    model = HomographyCNN6()
    # eval() disables dropout so the printed statistics are reproducible for
    # a given input; no_grad() avoids building an unused autograd graph.
    model.eval()
    img1 = torch.randn(2, 3, 256, 256)
    img2 = torch.randn(2, 3, 256, 256)
    with torch.no_grad():
        out = model(img1, img2)
    print(f"Output shape: {out.shape}, mean: {out.mean():.3f}")
|