153 lines
4.6 KiB
Python
153 lines
4.6 KiB
Python
from typing import Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision import models
|
|
|
|
|
|
class HomographyCNN(nn.Module):
    """
    Model for estimating a homography matrix (3x3) between two images.

    Both images are passed through the same ResNet backbone (its ``fc``
    classifier replaced by ``nn.Identity``). The two feature vectors are fused
    as ``[f1, f2, |f1 - f2|, f1 * f2]`` and regressed by an MLP head to 9
    values, which are reshaped to ``(batch, 3, 3)``.

    Note: the 9 outputs are unconstrained; no normalization such as
    ``H[2, 2] == 1`` is applied.
    """

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        """
        Args:
            input_channels: Number of channels of each input image. When this is
                not 3, the backbone's first convolution is replaced by one with
                the requested number of input channels.
            backbone_name: Backbone identifier (case-insensitive): ``"resnet18"``,
                ``"resnet34"`` or ``"resnet50"``.
            pretrained: Load torchvision ImageNet weights for the backbone.
            dropout_rate: Dropout probability used in the regression head.
            use_batch_norm: Insert ``BatchNorm1d`` after each hidden linear layer
                of the head (in training mode this needs batch size > 1).

        Raises:
            ValueError: If ``backbone_name`` is not supported.
        """
        super().__init__()

        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        backbone = self._create_backbone(backbone_name, pretrained)

        # Drop the classifier so the backbone returns its feature vector.
        self.feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # forward() concatenates four feature_dim-sized vectors.
        compare_input_dim = self.feature_dim * 4

        layers = [
            nn.Linear(compare_input_dim, 512),
            nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 9),
        ]
        self.head = nn.Sequential(*layers)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """Build a torchvision ResNet, optionally with ImageNet weights.

        Reads ``self.input_channels``; when it differs from 3, ``conv1`` is
        swapped for a convolution accepting that many channels.

        Raises:
            ValueError: If ``name`` is not a supported backbone.
        """
        name = name.lower()
        if name == "resnet18":
            model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        elif name == "resnet34":
            model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)
        elif name == "resnet50":
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
        else:
            raise ValueError(f"Unsupported backbone: {name}")

        if self.input_channels != 3:
            # Reuse the pretrained first-layer filters instead of discarding them.
            model.conv1 = self._adapt_input_conv(
                model.conv1, self.input_channels, reuse_weights=pretrained
            )
        return model

    @staticmethod
    def _adapt_input_conv(
        old_conv: nn.Conv2d, in_channels: int, reuse_weights: bool
    ) -> nn.Conv2d:
        """Return a copy of ``old_conv``'s geometry accepting ``in_channels`` inputs.

        When ``reuse_weights`` is true, the new filters are the old filters
        averaged over their input channels, replicated ``in_channels`` times and
        scaled by ``old_in / in_channels``. An input whose channels are all
        identical then produces the same response as before. The bias, if any,
        is copied. Otherwise the new convolution keeps PyTorch's default init.
        """
        new_conv = nn.Conv2d(
            in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            bias=old_conv.bias is not None,
        )
        if reuse_weights:
            with torch.no_grad():
                mean_weight = old_conv.weight.mean(dim=1, keepdim=True)
                scale = old_conv.in_channels / in_channels
                new_conv.weight.copy_(mean_weight.repeat(1, in_channels, 1, 1) * scale)
                if old_conv.bias is not None:
                    new_conv.bias.copy_(old_conv.bias)
        return new_conv

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the shared backbone to one batch of images."""
        return self.backbone(x)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Predict one homography per image pair.

        Returns:
            Tensor of shape ``(batch, 3, 3)`` holding the raw head outputs.
        """
        f1 = self._extract_features(img1)
        f2 = self._extract_features(img2)

        # Symmetric-comparison features alongside the raw embeddings.
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        combined = torch.cat([f1, f2, diff, prod], dim=1)

        h = self.head(combined)
        h = h.view(-1, 3, 3)
        return h

    def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Run :meth:`forward` in eval mode without gradient tracking.

        The training flag of every submodule is restored afterwards to exactly
        what it was, even if the forward pass raises. This preserves, e.g., a
        backbone deliberately kept in eval mode.
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            # Set flags per module; Module.train() would recurse and overwrite children.
            for module, mode in saved_modes:
                module.training = mode
|
|
|
|
|
|
class HomographyLoss(nn.Module):
    """Element-wise mean squared error between two homography tensors.

    Every entry of the predicted and target matrices is weighted equally; the
    result is averaged over all elements.
    """

    def __init__(self):
        super().__init__()
        # Default MSELoss reduction is "mean"; spelled out for clarity.
        self.criterion = nn.MSELoss(reduction="mean")

    def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor:
        """Return the scalar MSE between ``pred_homography`` and ``target_homography``."""
        mse = self.criterion
        return mse(pred_homography, target_homography)
|
|
|
|
|
|
def create_homography_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """Factory for homography models.

    Args:
        model_type: Only ``"backbone"`` (a :class:`HomographyCNN`) is supported.
        input_size: Accepted for API compatibility; not used by any current
            model type.
        **kwargs: Forwarded unchanged to the model constructor.

    Raises:
        ValueError: If ``model_type`` is unknown.
    """
    if model_type != "backbone":
        raise ValueError(f"Unknown model type: {model_type}")
    return HomographyCNN(**kwargs)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default model and push one random batch through
    # forward(), predict_homography() and the loss.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = HomographyCNN(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model created with {n_params:,} parameters")

    batch_size = 4
    height, width = 256, 256
    sample_shape = (batch_size, 3, height, width)
    # Drawn in order: img1 first, then img2 (same RNG sequence as two separate calls).
    img1, img2 = (torch.randn(*sample_shape).to(device) for _ in range(2))

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")

    print("\nTesting prediction...")
    pred = model.predict_homography(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    identity = torch.eye(3)
    target = identity.unsqueeze(0).expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nAll tests completed successfully!")
|