Files
autopilot/models/SiaN/model.py
2026-04-04 19:41:16 +03:00

153 lines
4.6 KiB
Python

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class HomographyCNN(nn.Module):
    """
    Model for estimating homography matrix (3x3) between two images.

    Both images go through the same (shared-weight) torchvision ResNet
    backbone whose classification ``fc`` layer is replaced by
    ``nn.Identity``, giving one ``feature_dim``-sized vector per image. The
    two vectors are fused as ``[f1, f2, |f1 - f2|, f1 * f2]`` and an MLP head
    regresses 9 numbers that are reshaped to ``(batch, 3, 3)``.

    The predicted matrix is unconstrained: no normalisation (such as fixing
    ``H[2, 2] = 1``) is applied to the output.
    """

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        """
        Args:
            input_channels: Channels per input image. When this is not 3 the
                backbone's first convolution is replaced to accept it; with
                ``pretrained=True`` its kernels are derived from the
                pretrained RGB kernels.
            backbone_name: ``"resnet18"`` or ``"resnet34"`` (case-insensitive).
            pretrained: Load torchvision ImageNet weights into the backbone.
            dropout_rate: Dropout probability after each hidden head layer.
            use_batch_norm: Insert ``BatchNorm1d`` after each hidden linear
                layer of the head (otherwise ``nn.Identity``).

        Raises:
            ValueError: If ``backbone_name`` is not a supported backbone.
        """
        super().__init__()
        # input_channels must be stored before _create_backbone, which reads it.
        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        backbone = self._create_backbone(backbone_name, pretrained)
        self.feature_dim = backbone.fc.in_features
        # Drop the ImageNet classifier so the backbone returns pooled features.
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # forward() concatenates f1, f2, |f1 - f2| and f1 * f2.
        compare_input_dim = self.feature_dim * 4
        layers = [
            nn.Linear(compare_input_dim, 512),
            nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 9),
        ]
        self.head = nn.Sequential(*layers)

    @staticmethod
    def _adapt_conv_weight(weight: torch.Tensor, in_channels: int) -> torch.Tensor:
        """Map a conv kernel ``(out, src, kh, kw)`` to ``(out, in_channels, kh, kw)``.

        For a single output channel the source kernels are summed over the
        channel axis. Otherwise they are tiled along the channel axis,
        truncated to ``in_channels`` and rescaled by ``src / in_channels`` so
        the overall response magnitude stays comparable.
        """
        if in_channels == 1:
            return weight.sum(dim=1, keepdim=True)
        src = weight.shape[1]
        repeats = -(-in_channels // src)  # ceil(in_channels / src)
        tiled = weight.repeat(1, repeats, 1, 1)[:, :in_channels]
        return tiled * (src / in_channels)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """Build a torchvision ResNet and adapt its stem to ``self.input_channels``.

        Raises:
            ValueError: If ``name`` (lower-cased) is not a supported backbone.
        """
        name = name.lower()
        if name == "resnet18":
            model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        elif name == "resnet34":
            model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)
        else:
            raise ValueError(f"Unsupported backbone: {name}")

        if self.input_channels != 3:
            old_conv = model.conv1
            new_conv = nn.Conv2d(
                self.input_channels,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None,
            )
            if pretrained:
                # A randomly initialised stem would throw away the pretrained
                # low-level filters; derive the new kernels from the RGB ones.
                with torch.no_grad():
                    new_conv.weight.copy_(self._adapt_conv_weight(old_conv.weight, self.input_channels))
                    if old_conv.bias is not None:
                        new_conv.bias.copy_(old_conv.bias)
            model.conv1 = new_conv
        return model

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's pooled feature vector for each image in ``x``."""
        return self.backbone(x)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Predict one 3x3 matrix per image pair.

        Args:
            img1: First image batch, in the layout the backbone expects
                (for the ResNet backbones, ``(batch, input_channels, H, W)``).
            img2: Second image batch with the same batch size as ``img1``.

        Returns:
            Tensor of shape ``(batch, 3, 3)``.

        Note:
            With ``use_batch_norm=True`` the head contains ``BatchNorm1d``,
            which in training mode needs more than one sample per batch.
        """
        f1 = self._extract_features(img1)
        f2 = self._extract_features(img2)
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        combined = torch.cat([f1, f2, diff, prod], dim=1)
        h = self.head(combined)
        return h.view(-1, 3, 3)

    def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Run ``forward`` in eval mode without gradient tracking.

        The train/eval flag of this module and of every submodule is restored
        to exactly what it was before the call, even if ``forward`` raises.
        This preserves submodules the caller deliberately kept in eval mode.

        Returns:
            Tensor of shape ``(batch, 3, 3)`` that does not require grad.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            for module, mode in previous_modes:
                module.training = mode
class HomographyLoss(nn.Module):
    """Mean-squared error between predicted and target homography matrices.

    A thin wrapper around ``nn.MSELoss`` with its default ``"mean"``
    reduction. The result is therefore the average squared difference over
    every matrix entry in the batch. Matrices are compared entry-wise exactly
    as given; no scale normalisation is applied.
    """

    def __init__(self):
        super(HomographyLoss, self).__init__()
        # Default reduction="mean": averages over all elements.
        self.criterion = nn.MSELoss()

    def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor:
        """Return the scalar MSE between ``pred_homography`` and ``target_homography``."""
        loss = self.criterion(pred_homography, target_homography)
        return loss
def create_homography_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """Factory for homography models.

    Args:
        model_type: Only ``"backbone"`` is supported; it builds a
            ``HomographyCNN``.
        input_size: Accepted for interface compatibility but currently not
            used by any model type.
        **kwargs: Forwarded unchanged to the ``HomographyCNN`` constructor.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    if model_type != "backbone":
        raise ValueError(f"Unknown model type: {model_type}")
    return HomographyCNN(**kwargs)
if __name__ == "__main__":
    # Smoke test: build a model, run a training-mode forward pass and an
    # inference pass on random inputs, then score the forward output against
    # identity targets.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_config = dict(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    model = HomographyCNN(**model_config).to(device)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model created with {n_params:,} parameters")

    batch_size = 4
    height, width = 256, 256
    image_shape = (batch_size, 3, height, width)
    img1 = torch.randn(*image_shape).to(device)
    img2 = torch.randn(*image_shape).to(device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")

    print("\nTesting prediction...")
    pred = model.predict_homography(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    identity = torch.eye(3).unsqueeze(0)
    target = identity.expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")
    print("\nAll tests completed successfully!")