write SiaN
This commit is contained in:
152
models/SiaN/model.py
Normal file
152
models/SiaN/model.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import models
|
||||
|
||||
|
||||
class HomographyCNN(nn.Module):
|
||||
"""
|
||||
Model for estimating homography matrix (3x3) between two images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 3,
|
||||
backbone_name: str = "resnet18",
|
||||
pretrained: bool = True,
|
||||
dropout_rate: float = 0.3,
|
||||
use_batch_norm: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_channels = input_channels
|
||||
self.backbone_name = backbone_name
|
||||
self.pretrained = pretrained
|
||||
self.dropout_rate = dropout_rate
|
||||
self.use_batch_norm = use_batch_norm
|
||||
|
||||
backbone = self._create_backbone(backbone_name, pretrained)
|
||||
|
||||
self.feature_dim = backbone.fc.in_features
|
||||
backbone.fc = nn.Identity()
|
||||
self.backbone = backbone
|
||||
|
||||
compare_input_dim = self.feature_dim * 4
|
||||
|
||||
layers = [
|
||||
nn.Linear(compare_input_dim, 512),
|
||||
nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(dropout_rate),
|
||||
|
||||
nn.Linear(512, 256),
|
||||
nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(dropout_rate),
|
||||
|
||||
nn.Linear(256, 9),
|
||||
]
|
||||
self.head = nn.Sequential(*layers)
|
||||
|
||||
def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
|
||||
name = name.lower()
|
||||
if name == "resnet18":
|
||||
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
|
||||
elif name == "resnet34":
|
||||
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)
|
||||
else:
|
||||
raise ValueError(f"Unsupported backbone: {name}")
|
||||
if self.input_channels != 3:
|
||||
old_conv = model.conv1
|
||||
model.conv1 = nn.Conv2d(
|
||||
self.input_channels,
|
||||
old_conv.out_channels,
|
||||
kernel_size=old_conv.kernel_size,
|
||||
stride=old_conv.stride,
|
||||
padding=old_conv.padding,
|
||||
bias=old_conv.bias is not None,
|
||||
)
|
||||
return model
|
||||
|
||||
def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone(x)
|
||||
|
||||
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
|
||||
f1 = self._extract_features(img1)
|
||||
f2 = self._extract_features(img2)
|
||||
|
||||
diff = torch.abs(f1 - f2)
|
||||
prod = f1 * f2
|
||||
combined = torch.cat([f1, f2, diff, prod], dim=1)
|
||||
|
||||
h = self.head(combined)
|
||||
h = h.view(-1, 3, 3)
|
||||
return h
|
||||
|
||||
def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
|
||||
was_training = self.training
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
h = self.forward(img1, img2)
|
||||
if was_training:
|
||||
self.train()
|
||||
return h
|
||||
|
||||
|
||||
class HomographyLoss(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor:
|
||||
return self.criterion(pred_homography, target_homography)
|
||||
|
||||
|
||||
def create_homography_model(
|
||||
model_type: str = "backbone",
|
||||
input_size: Tuple[int, int] = (256, 256),
|
||||
**kwargs,
|
||||
) -> nn.Module:
|
||||
if model_type == "backbone":
|
||||
return HomographyCNN(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
model = HomographyCNN(
|
||||
input_channels=3,
|
||||
backbone_name="resnet18",
|
||||
pretrained=True,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
).to(device)
|
||||
|
||||
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
|
||||
|
||||
batch_size = 4
|
||||
height, width = 256, 256
|
||||
|
||||
img1 = torch.randn(batch_size, 3, height, width).to(device)
|
||||
img2 = torch.randn(batch_size, 3, height, width).to(device)
|
||||
|
||||
print("\nTesting forward pass...")
|
||||
output = model(img1, img2)
|
||||
print(f"Output shape: {output.shape}")
|
||||
|
||||
print("\nTesting prediction...")
|
||||
pred = model.predict_homography(img1, img2)
|
||||
print(f"Prediction shape: {pred.shape}")
|
||||
|
||||
print("\nTesting loss function...")
|
||||
target = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
|
||||
loss_fn = HomographyLoss().to(device)
|
||||
loss = loss_fn(output, target)
|
||||
print(f"Loss value: {loss.item():.6f}")
|
||||
|
||||
print("\nAll tests completed successfully!")
|
||||
Reference in New Issue
Block a user