552 lines
19 KiB
Python
552 lines
19 KiB
Python
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class HomographyCNN(nn.Module):
    """
    CNN model for homography estimation between two images.

    This model takes two images (Google and Yandex maps) as input and
    outputs a 3x3 homography matrix that transforms one image to align with the other.

    Architecture: two separate (non-weight-sharing) encoders, one per image;
    channel-wise concatenation of their feature maps; a fusion stack ending in
    global average pooling; and an MLP regression head producing
    ``output_size`` values (9 for a flattened 3x3 matrix).
    """

    # Minimum magnitude allowed for the h33 divisor when normalising a matrix.
    _EPS = 1e-8

    def __init__(
        self,
        input_channels: int = 3,
        hidden_channels: int = 64,
        num_blocks: int = 4,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
        output_size: int = 9,  # Flattened 3x3 homography matrix
    ):
        """
        Initialize the HomographyCNN model.

        Args:
            input_channels: Number of input channels per image (3 for RGB)
            hidden_channels: Base number of channels in the network
            num_blocks: Number of residual blocks in each encoder (any value >= 0)
            dropout_rate: Dropout rate for regularization
            use_batch_norm: Whether to use batch normalization
            output_size: Size of output vector (9 for flattened 3x3 matrix)
        """
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm
        self.output_size = output_size

        # Channel count of each encoder's output map. _build_encoder overwrites
        # it with the value its block schedule actually produces, so the fusion
        # layers stay consistent for any num_blocks (including 0 and 1).
        self.encoder_out_channels = hidden_channels

        # Feature extraction for each image separately
        self.google_encoder = self._build_encoder()
        self.yandex_encoder = self._build_encoder()

        # Fusion layers to combine features from both images
        self.fusion_layers = self._build_fusion_layers()

        # Regression head for homography estimation
        self.regression_head = self._build_regression_head(output_size)

        # Initialize weights
        self._initialize_weights()

    def _build_encoder(self) -> nn.Module:
        """
        Build the encoder network for a single image.

        Stem: 7x7 stride-2 conv (+BN) + ReLU + 3x3 stride-2 max-pool.
        Then ``num_blocks`` residual blocks: the first two double the channel
        count and later ones keep it; every block except the first uses stride 2.
        The resulting channel count is stored in ``self.encoder_out_channels``.
        """
        layers = []
        in_channels = self.input_channels
        out_channels = self.hidden_channels

        # First convolutional block
        layers.append(
            nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)
        )
        if self.use_batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Additional residual blocks
        for i in range(self.num_blocks):
            block_in_channels = out_channels
            block_out_channels = out_channels * 2 if i < 2 else out_channels

            layers.append(
                ResidualBlock(
                    in_channels=block_in_channels,
                    out_channels=block_out_channels,
                    stride=1 if i == 0 else 2,
                    dropout_rate=self.dropout_rate,
                    use_batch_norm=self.use_batch_norm,
                )
            )
            # For i >= 2 this is a no-op (block_out == out), matching the schedule.
            out_channels = block_out_channels

        self.encoder_out_channels = out_channels
        return nn.Sequential(*layers)

    def _build_fusion_layers(self) -> nn.Module:
        """Build layers to fuse the concatenated features of both images."""
        # Both encoders emit encoder_out_channels maps; they are concatenated.
        fused_channels = 2 * self.encoder_out_channels

        layers = [
            # Reduce dimensionality
            nn.Conv2d(
                fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
            ),
            nn.BatchNorm2d(self.hidden_channels * 4)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Further processing
            nn.Conv2d(
                self.hidden_channels * 4,
                self.hidden_channels * 2,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(self.hidden_channels * 2)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Global pooling
            nn.AdaptiveAvgPool2d((1, 1)),
        ]

        return nn.Sequential(*layers)

    def _build_regression_head(self, output_size: int) -> nn.Module:
        """Build the MLP regression head for homography estimation."""
        # Input size after fusion and global pooling
        input_features = self.hidden_channels * 2

        layers = [
            nn.Flatten(),
            nn.Linear(input_features, 512),
            nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(128, output_size),
        ]

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """
        Initialize model weights.

        Conv layers use Kaiming-normal init; norm layers use weight 1 and bias 0;
        linear layers use N(0, 0.01) weights and zero bias. When the head
        outputs 9 values, the final bias is set to the flattened identity
        matrix. The raw prediction then starts near identity (h33 ~ 1) instead
        of near zero, which keeps the h33 normalisation in ``forward`` away
        from division by ~0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        final_linear = self.regression_head[-1]
        if isinstance(final_linear, nn.Linear) and final_linear.out_features == 9:
            with torch.no_grad():
                final_linear.bias.copy_(
                    torch.eye(3, dtype=final_linear.bias.dtype).flatten()
                )

    @classmethod
    def _normalize_by_h33(cls, matrices: torch.Tensor) -> torch.Tensor:
        """
        Divide each matrix of a (B, 3, 3) batch by its [2, 2] entry.

        The divisor keeps its sign. When its magnitude is below ``_EPS`` it is
        replaced by +/-``_EPS``, so near-zero h33 (including negative values)
        cannot produce inf/NaN. Otherwise the resulting [2, 2] entry is 1.
        """
        h33 = matrices[:, 2:3, 2:3]
        floor = torch.full_like(h33, cls._EPS)
        floor = torch.where(h33 < 0, -floor, floor)
        safe_h33 = torch.where(h33.abs() < cls._EPS, floor, h33)
        return matrices / safe_h33

    def forward(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        return_matrix: bool = True,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            google_img: Google map image tensor of shape (B, C, H, W)
            yandex_img: Yandex map image tensor of shape (B, C, H, W)
            return_matrix: If True, return the h33-normalised 3x3 matrix;
                if False, return the raw (un-normalised) flattened head output

        Returns:
            Homography matrix tensor of shape (B, 3, 3) or flattened vector of shape (B, 9)
        """
        # Extract features from both images
        google_features = self.google_encoder(google_img)
        yandex_features = self.yandex_encoder(yandex_img)

        # Concatenate features along channel dimension
        combined_features = torch.cat([google_features, yandex_features], dim=1)

        # Fuse features and regress the homography parameters
        fused_features = self.fusion_layers(combined_features)
        homography_flat = self.regression_head(fused_features)

        if not return_matrix:
            return homography_flat

        batch_size = homography_flat.shape[0]
        homography_matrix = homography_flat.view(batch_size, 3, 3)

        # Homogeneous normalisation: scale so the last element is 1.
        return self._normalize_by_h33(homography_matrix)

    def predict_homography(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix with optional normalization.

        Runs in eval mode without gradient tracking. The module's previous
        train/eval mode is restored afterwards, so calling this mid-training
        does not silently disable dropout and batch-norm updates.

        Args:
            google_img: Google map image tensor
            yandex_img: Yandex map image tensor
            normalize: Whether to (re-)normalize so the last element is 1

        Returns:
            Predicted homography matrix of shape (B, 3, 3)
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                homography = self.forward(google_img, yandex_img, return_matrix=True)
                if normalize:
                    homography = self._normalize_by_h33(homography)
        finally:
            self.train(was_training)

        return homography
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """
    Residual block with optional downsampling.

    Main branch: conv3x3(stride) -> norm -> ReLU -> dropout -> conv3x3 -> norm.
    The shortcut is added to this output, and ReLU and dropout are applied to
    the sum. The shortcut is the identity unless the stride or the channel
    count changes; in that case it is a 1x1 strided conv followed by an
    optional BatchNorm.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        def make_norm(channels: int) -> nn.Module:
            # BatchNorm when enabled, otherwise a pass-through.
            if use_batch_norm:
                return nn.BatchNorm2d(channels)
            return nn.Identity()

        def make_dropout() -> nn.Module:
            # Channel-wise dropout only when a positive rate is requested.
            if dropout_rate > 0:
                return nn.Dropout2d(dropout_rate)
            return nn.Identity()

        # First conv stage (carries the stride, so it may downsample).
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = make_norm(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = make_dropout()

        # Second conv stage keeps resolution and channel count.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = make_norm(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = make_dropout()

        # Project the skip path only when its shape would not match.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                make_norm(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = self.shortcut(x)

        y = self.dropout1(self.relu1(self.bn1(self.conv1(x))))
        y = self.bn2(self.conv2(y))
        y += skip

        return self.dropout2(self.relu2(y))
|
|
|
|
|
|
class HomographyLoss(nn.Module):
    """
    Custom loss function for homography estimation.

    Combines multiple loss terms:
    1. Matrix element-wise L2 loss
    2. Geometric consistency loss: MSE between a fixed grid of points in
       [-1, 1]^2 warped by the predicted vs. the target homography
    3. Determinant regularization (pulls det(H) towards 1 to discourage
       degenerate matrices)
    """

    # Minimum magnitude allowed for a homogeneous divisor (w or h33).
    _EPS = 1e-8

    def __init__(
        self,
        matrix_weight: float = 1.0,
        geometric_weight: float = 0.5,
        reg_weight: float = 0.1,
        grid_size: int = 8,
    ):
        super().__init__()
        self.matrix_weight = matrix_weight
        self.geometric_weight = geometric_weight
        self.reg_weight = reg_weight
        self.grid_size = grid_size

        # Grid of homogeneous points for the geometric loss (not saved in state_dict).
        self.register_buffer(
            "grid_points",
            self._create_grid_points(grid_size),
            persistent=False,
        )

    def _create_grid_points(self, grid_size: int) -> torch.Tensor:
        """Create a (3, grid_size**2) grid of homogeneous points spanning [-1, 1]^2."""
        x = torch.linspace(-1, 1, grid_size)
        y = torch.linspace(-1, 1, grid_size)
        grid_y, grid_x = torch.meshgrid(y, x, indexing="ij")
        grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)
        # Add homogeneous coordinate
        ones = torch.ones(grid_points.shape[0], 1)
        grid_points = torch.cat([grid_points, ones], dim=1)
        return grid_points.T  # Shape: (3, grid_size*grid_size)

    @classmethod
    def _safe_divisor(cls, d: torch.Tensor) -> torch.Tensor:
        """Return *d*, except that entries with |d| < _EPS become +/-_EPS (sign kept; 0 -> +_EPS)."""
        floor = torch.full_like(d, cls._EPS)
        floor = torch.where(d < 0, -floor, floor)
        return torch.where(d.abs() < cls._EPS, floor, d)

    @classmethod
    def _dehomogenize(cls, points: torch.Tensor) -> torch.Tensor:
        """Divide (B, 3, N) homogeneous points by their w row, sign-safely."""
        return points / cls._safe_divisor(points[:, 2:3, :])

    @classmethod
    def _normalize_h33(cls, matrices: torch.Tensor) -> torch.Tensor:
        """Scale each matrix of a (B, 3, 3) batch by its [2, 2] entry, sign-safely."""
        return matrices / cls._safe_divisor(matrices[:, 2:3, 2:3])

    def forward(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        google_img: Optional[torch.Tensor] = None,
        yandex_img: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute homography loss.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3); cast to
                the prediction's dtype
            google_img: Google images. Their values are not used; the
                geometric term is computed only when both images are given.
            yandex_img: Yandex images (see ``google_img``)

        Returns:
            Combined scalar loss value
        """
        batch_size = pred_homography.shape[0]
        # Work in the prediction's dtype so e.g. float64 inputs do not collide
        # with the float32 grid buffer inside torch.bmm.
        target_homography = target_homography.to(dtype=pred_homography.dtype)

        # 1. Matrix element-wise L2 loss
        matrix_loss = F.mse_loss(pred_homography, target_homography)

        # 2. Geometric consistency loss (only when images are provided)
        geometric_loss = pred_homography.new_zeros(())
        if google_img is not None and yandex_img is not None:
            grid_points = self.grid_points.to(
                device=pred_homography.device, dtype=pred_homography.dtype
            )
            grid_points = grid_points.unsqueeze(0).expand(batch_size, -1, -1)

            warped_points = self._dehomogenize(torch.bmm(pred_homography, grid_points))
            target_warped_points = self._dehomogenize(
                torch.bmm(target_homography, grid_points)
            )

            # Point-wise distance on the Cartesian (x, y) rows only
            geometric_loss = F.mse_loss(
                warped_points[:, :2, :], target_warped_points[:, :2, :]
            )

        # 3. Regularization loss: encourage determinant to be close to 1
        pred_det = torch.det(pred_homography)
        reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))

        return (
            self.matrix_weight * matrix_loss
            + self.geometric_weight * geometric_loss
            + self.reg_weight * reg_loss
        )

    def compute_metrics(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        image_size: int = 256,
    ) -> dict:
        """
        Compute evaluation metrics for homography estimation.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3)
            image_size: Image side length in pixels, used to convert the
                corner error from [-1, 1] coordinates (half-width 1) into
                pixels. Defaults to 256.

        Returns:
            Dictionary with ``matrix_mse`` (MSE of h33-normalised matrices),
            ``corner_error`` (mean L2 distance of the four warped corners of
            [-1, 1]^2) and ``corner_error_px`` (the same in pixels).
        """
        with torch.no_grad():
            target_homography = target_homography.to(dtype=pred_homography.dtype)

            # Normalize matrices so h33 == 1
            pred_norm = self._normalize_h33(pred_homography)
            target_norm = self._normalize_h33(target_homography)

            # Matrix L2 error, per sample
            matrix_error = F.mse_loss(pred_norm, target_norm, reduction="none").mean(
                dim=(1, 2)
            )

            # Corner error (warp the 4 corners of the normalised image square)
            corners = torch.tensor(
                [[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],
                dtype=pred_homography.dtype,
                device=pred_homography.device,
            ).T  # Shape: (3, 4)
            corners = corners.unsqueeze(0).expand(pred_homography.shape[0], -1, -1)

            pred_corners = self._dehomogenize(torch.bmm(pred_norm, corners))
            target_corners = self._dehomogenize(torch.bmm(target_norm, corners))

            corner_error = torch.norm(
                pred_corners[:, :2, :] - target_corners[:, :2, :], dim=1
            ).mean(dim=1)
            avg_corner_error = corner_error.mean().item()

        return {
            "matrix_mse": matrix_error.mean().item(),
            "corner_error": avg_corner_error,
            # Coordinates span [-1, 1], i.e. a half-width of image_size / 2 pixels.
            "corner_error_px": avg_corner_error * (image_size / 2),
        }
|
|
|
|
|
|
def create_homography_model(
    model_type: str = "cnn",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Factory function to create a homography estimation model.

    Args:
        model_type: Type of model to create. Only ``"cnn"`` (``HomographyCNN``)
            is currently implemented; any other value raises ``ValueError``.
        input_size: Input image size (height, width). Currently ignored by
            this factory and not forwarded to the model; kept for API
            compatibility.
        **kwargs: Additional keyword arguments forwarded to the model constructor.

    Returns:
        Homography estimation model.

    Raises:
        ValueError: If ``model_type`` is not a supported model type.
    """
    if model_type == "cnn":
        return HomographyCNN(**kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model, run forward/predict on random inputs,
    # evaluate the loss and metrics, then exercise the factory function.

    def _count_params(module: nn.Module) -> int:
        """Return the total number of parameter elements in *module*."""
        return sum(param.numel() for param in module.parameters())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- model construction ---------------------------------------------
    model = HomographyCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    print(f"Model created with {_count_params(model):,} parameters")

    # --- random inputs --------------------------------------------------
    batch_size = 4
    height, width = 256, 256
    google_img = torch.randn(batch_size, 3, height, width).to(device)
    yandex_img = torch.randn(batch_size, 3, height, width).to(device)

    # --- forward pass ---------------------------------------------------
    print("\nTesting forward pass...")
    output = model(google_img, yandex_img, return_matrix=True)
    print(f"Output shape: {output.shape}")  # expected (4, 3, 3)
    print(f"Sample output:\n{output[0]}")

    # --- prediction -----------------------------------------------------
    print("\nTesting prediction...")
    pred = model.predict_homography(google_img, yandex_img)
    print(f"Prediction shape: {pred.shape}")
    print(f"Last element (should be ~1): {pred[0, 2, 2]:.6f}")

    # --- loss -----------------------------------------------------------
    print("\nTesting loss function...")
    identity = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1)
    target_homography = identity.to(device)
    loss_fn = HomographyLoss(
        matrix_weight=1.0,
        geometric_weight=0.5,
        reg_weight=0.1,
        grid_size=8,
    ).to(device)
    loss = loss_fn(output, target_homography, google_img, yandex_img)
    print(f"Loss value: {loss.item():.6f}")

    # --- metrics --------------------------------------------------------
    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target_homography)
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value:.6f}")

    # --- factory --------------------------------------------------------
    print("\nTesting model factory...")
    model2 = create_homography_model(
        model_type="cnn",
        input_size=(256, 256),
        input_channels=3,
        hidden_channels=32,
        num_blocks=3,
    ).to(device)
    print(f"Model2 created with {_count_params(model2):,} parameters")

    print("\nAll tests completed successfully!")
|