Files
autopilot/models/SiaN/homography_cnn.py
2026-02-16 19:07:31 +03:00

552 lines
19 KiB
Python

from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class HomographyCNN(nn.Module):
    """
    CNN model for homography estimation between two images.

    Two images (Google and Yandex maps) are encoded by two separate
    convolutional encoders, the resulting feature maps are concatenated along
    the channel dimension, fused by a small convolutional stack, globally
    pooled and regressed to ``output_size`` values. With the default
    ``output_size=9`` the values are interpreted as a flattened 3x3 homography
    matrix that transforms one image to align with the other.
    """

    # Denominators with a magnitude below this value are replaced by
    # +/- this value (sign preserved) before normalizing by H[2, 2].
    _NORMALIZATION_EPS = 1e-8

    def __init__(
        self,
        input_channels: int = 3,
        hidden_channels: int = 64,
        num_blocks: int = 4,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
        output_size: int = 9,  # Flattened 3x3 homography matrix
    ):
        """
        Initialize the HomographyCNN model.

        Args:
            input_channels: Number of input channels per image (3 for RGB)
            hidden_channels: Base number of channels in the network
            num_blocks: Number of residual blocks in each encoder
            dropout_rate: Dropout rate for regularization
            use_batch_norm: Whether to use batch normalization
            output_size: Size of output vector (9 for flattened 3x3 matrix)
        """
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # Feature extraction for each image separately
        self.google_encoder = self._build_encoder()
        self.yandex_encoder = self._build_encoder()
        # Fusion layers to combine features from both images
        self.fusion_layers = self._build_fusion_layers()
        # Regression head for homography estimation
        self.regression_head = self._build_regression_head(output_size)
        # Initialize weights
        self._initialize_weights()

    def _encoder_output_channels(self) -> int:
        """Return the channel count produced by one encoder.

        Only the first two residual blocks double the width, so the result is
        ``hidden_channels * 2 ** min(num_blocks, 2)``.
        """
        doubling_blocks = min(max(self.num_blocks, 0), 2)
        return self.hidden_channels * 2**doubling_blocks

    def _build_encoder(self) -> nn.Module:
        """Build the encoder network for a single image."""
        channels = self.hidden_channels

        # Stem: strided 7x7 convolution followed by max pooling.
        layers = [
            nn.Conv2d(
                self.input_channels, channels, kernel_size=7, stride=2, padding=3
            )
        ]
        if self.use_batch_norm:
            layers.append(nn.BatchNorm2d(channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Residual blocks: the first keeps the resolution, later ones
        # downsample by 2; the first two also double the channel count.
        for i in range(self.num_blocks):
            next_channels = channels * 2 if i < 2 else channels
            layers.append(
                ResidualBlock(
                    in_channels=channels,
                    out_channels=next_channels,
                    stride=1 if i == 0 else 2,
                    dropout_rate=self.dropout_rate,
                    use_batch_norm=self.use_batch_norm,
                )
            )
            channels = next_channels
        return nn.Sequential(*layers)

    def _build_fusion_layers(self) -> nn.Module:
        """Build layers to fuse features from both images."""
        # The two encoder outputs are concatenated along the channel axis.
        # Deriving the width from the encoder (instead of assuming
        # hidden_channels * 8) keeps the model valid for num_blocks < 2.
        fused_channels = 2 * self._encoder_output_channels()
        layers = [
            # Reduce dimensionality
            nn.Conv2d(
                fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
            ),
            nn.BatchNorm2d(self.hidden_channels * 4)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Further processing
            nn.Conv2d(
                self.hidden_channels * 4,
                self.hidden_channels * 2,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(self.hidden_channels * 2)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Global pooling
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        return nn.Sequential(*layers)

    def _build_regression_head(self, output_size: int) -> nn.Module:
        """Build the regression head for homography estimation."""
        # Input size after fusion and global pooling
        input_features = self.hidden_channels * 2
        layers = [
            nn.Flatten(),
            nn.Linear(input_features, 512),
            nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(128, output_size),
        ]
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialize model weights.

        Convolutions use Kaiming-normal init, batch-norm layers start as the
        identity transform and linear layers use a small normal init with zero
        bias. When the head outputs 9 values, the bias of the last linear
        layer is set to the flattened 3x3 identity so that a freshly created
        model predicts (approximately) the identity homography instead of a
        near-zero matrix whose H[2, 2] normalization would blow up.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        head_output = self.regression_head[-1]
        if head_output.out_features == 9:
            with torch.no_grad():
                head_output.bias.copy_(torch.eye(3).flatten())

    @staticmethod
    def _normalize_homography(
        homography: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Scale each (B, 3, 3) matrix so that its [2, 2] element becomes 1.

        A denominator whose magnitude is below ``eps`` is replaced by ``eps``
        carrying the denominator's sign. Unlike ``h22 + eps`` this cannot turn
        a small negative denominator into an exact zero.
        """
        scale = homography[:, 2, 2].view(-1, 1, 1)
        eps_like = torch.full_like(scale, eps)
        signed_eps = torch.where(scale < 0, -eps_like, eps_like)
        safe_scale = torch.where(scale.abs() < eps, signed_eps, scale)
        return homography / safe_scale

    def forward(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        return_matrix: bool = True,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            google_img: Google map image tensor of shape (B, C, H, W)
            yandex_img: Yandex map image tensor of shape (B, C, H, W)
            return_matrix: If True, return 3x3 matrix; if False, return flattened vector

        Returns:
            Homography matrix tensor of shape (B, 3, 3), normalized so that
            the [2, 2] element is 1, or the raw flattened vector of shape
            (B, output_size) when ``return_matrix`` is False.
        """
        # Extract features from both images
        google_features = self.google_encoder(google_img)
        yandex_features = self.yandex_encoder(yandex_img)
        # Concatenate features along channel dimension
        combined_features = torch.cat([google_features, yandex_features], dim=1)
        # Fuse features
        fused_features = self.fusion_layers(combined_features)
        # Regression to get homography parameters
        homography_flat = self.regression_head(fused_features)
        if not return_matrix:
            return homography_flat

        # Reshape to 3x3 matrix (requires output_size == 9)
        batch_size = homography_flat.shape[0]
        homography_matrix = homography_flat.view(batch_size, 3, 3)
        # Homogeneous coordinate normalization with a sign-safe denominator
        return self._normalize_homography(homography_matrix, self._NORMALIZATION_EPS)

    def predict_homography(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix with optional normalization.

        Runs the model in evaluation mode without gradient tracking and then
        restores the training/eval mode the module had before the call.

        Args:
            google_img: Google map image tensor
            yandex_img: Yandex map image tensor
            normalize: Whether to normalize the homography matrix. Note that
                ``forward`` already normalizes by H[2, 2], so this is an
                idempotent re-normalization.

        Returns:
            Predicted homography matrix of shape (B, 3, 3)
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                homography = self.forward(google_img, yandex_img, return_matrix=True)
                if normalize:
                    # Normalize so that last element is 1
                    homography = self._normalize_homography(
                        homography, self._NORMALIZATION_EPS
                    )
        finally:
            # Do not leave a model that was training stuck in eval mode.
            self.train(was_training)
        return homography
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection and optional downsampling.

    The first convolution applies ``stride``; when the stride is not 1 or the
    channel count changes, the skip path uses a 1x1 strided convolution
    (plus batch norm if enabled) so both paths have matching shapes.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        def make_norm() -> nn.Module:
            # Batch norm over the block's output channels, or a no-op.
            if use_batch_norm:
                return nn.BatchNorm2d(out_channels)
            return nn.Identity()

        def make_dropout() -> nn.Module:
            # Channel-wise dropout, skipped entirely for a non-positive rate.
            if dropout_rate > 0:
                return nn.Dropout2d(dropout_rate)
            return nn.Identity()

        # First stage: carries the (optional) spatial downsampling.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = make_norm()
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = make_dropout()

        # Second stage: keeps resolution and width.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = make_norm()
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = make_dropout()

        # Skip path: pass-through unless the shape changes.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                make_norm(),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block: conv-bn-relu-dropout, conv-bn, add skip, relu-dropout."""
        residual = self.shortcut(x)

        main = self.dropout1(self.relu1(self.bn1(self.conv1(x))))
        main = self.bn2(self.conv2(main))

        main += residual
        return self.dropout2(self.relu2(main))
class HomographyLoss(nn.Module):
    """
    Custom loss function for homography estimation.

    Combines multiple loss terms:
    1. Matrix element-wise L2 loss
    2. Geometric consistency loss (warping error of a fixed point grid);
       computed only when both images are passed to ``forward`` - the image
       contents themselves are not read
    3. Determinant regularization (to prevent degenerate matrices)
    """

    # Homogeneous denominators with a magnitude below this value are replaced
    # by +/- this value (sign preserved) before dividing.
    _EPS = 1e-8

    def __init__(
        self,
        matrix_weight: float = 1.0,
        geometric_weight: float = 0.5,
        reg_weight: float = 0.1,
        grid_size: int = 8,
    ):
        """
        Args:
            matrix_weight: Weight of the element-wise matrix MSE term
            geometric_weight: Weight of the grid warping error term
            reg_weight: Weight of the determinant regularization term
            grid_size: Number of grid points per axis used for the warping term
        """
        super().__init__()
        self.matrix_weight = matrix_weight
        self.geometric_weight = geometric_weight
        self.reg_weight = reg_weight
        self.grid_size = grid_size
        # Create grid of points for geometric loss (not saved in state_dict)
        self.register_buffer(
            "grid_points",
            self._create_grid_points(grid_size),
            persistent=False,
        )

    def _create_grid_points(self, grid_size: int) -> torch.Tensor:
        """Create a grid of points for geometric consistency loss.

        Returns:
            Homogeneous points covering [-1, 1] x [-1, 1], shape
            (3, grid_size * grid_size) with rows (x, y, 1).
        """
        x = torch.linspace(-1, 1, grid_size)
        y = torch.linspace(-1, 1, grid_size)
        grid_y, grid_x = torch.meshgrid(y, x, indexing="ij")
        grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)
        # Add homogeneous coordinate
        ones = torch.ones(grid_points.shape[0], 1)
        grid_points = torch.cat([grid_points, ones], dim=1)
        return grid_points.T  # Shape: (3, grid_size*grid_size)

    @staticmethod
    def _safe_divide(
        numerator: torch.Tensor, denominator: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Divide while keeping near-zero denominators away from zero.

        A denominator with magnitude below ``eps`` is replaced by ``eps`` with
        the denominator's sign, so a small negative value can never cancel to
        an exact zero the way ``denominator + eps`` can.
        """
        eps_like = torch.full_like(denominator, eps)
        signed_eps = torch.where(denominator < 0, -eps_like, eps_like)
        safe = torch.where(denominator.abs() < eps, signed_eps, denominator)
        return numerator / safe

    def _warp_points(
        self, homography: torch.Tensor, points: torch.Tensor
    ) -> torch.Tensor:
        """Apply (B, 3, 3) homographies to (3, N) points; return (B, 2, N)."""
        # Match the homography's dtype/device so float64 inputs work too.
        points = points.to(dtype=homography.dtype, device=homography.device)
        batched = points.unsqueeze(0).expand(homography.shape[0], -1, -1)
        warped = torch.bmm(homography, batched)
        # Normalize homogeneous coordinates
        warped = self._safe_divide(warped, warped[:, 2:3, :], self._EPS)
        return warped[:, :2, :]

    def forward(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        google_img: Optional[torch.Tensor] = None,
        yandex_img: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute homography loss.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3)
            google_img: Google images (optional, enables the geometric loss)
            yandex_img: Yandex images (optional, enables the geometric loss)

        Returns:
            Combined loss value (scalar tensor)
        """
        # 1. Matrix element-wise L2 loss
        matrix_loss = F.mse_loss(pred_homography, target_homography)

        # 2. Geometric consistency loss (if images provided)
        geometric_loss = pred_homography.new_zeros(())
        if google_img is not None and yandex_img is not None:
            # Warp the same grid with predicted and target homographies and
            # compare the resulting point positions.
            warped_points = self._warp_points(pred_homography, self.grid_points)
            target_warped_points = self._warp_points(
                target_homography, self.grid_points
            )
            geometric_loss = F.mse_loss(warped_points, target_warped_points)

        # 3. Regularization loss (prevent degenerate matrices)
        # Encourage determinant to be close to 1
        pred_det = torch.det(pred_homography)
        reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))

        # Combine losses
        total_loss = (
            self.matrix_weight * matrix_loss
            + self.geometric_weight * geometric_loss
            + self.reg_weight * reg_loss
        )
        return total_loss

    def compute_metrics(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        image_size: Tuple[int, int] = (256, 256),
    ) -> dict:
        """
        Compute evaluation metrics for homography estimation.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3)
            image_size: Image size as (height, width) used to convert the
                corner error from normalized [-1, 1] coordinates to pixels.
                The default reproduces the previous fixed 256x256 assumption.

        Returns:
            Dictionary with ``matrix_mse`` (MSE between H[2, 2]-normalized
            matrices), ``corner_error`` (mean distance between the warped
            image corners in normalized coordinates) and ``corner_error_px``
            (the same distance in pixels), each averaged over the batch.
        """
        with torch.no_grad():
            # Normalize matrices so that H[2, 2] == 1
            pred_norm = self._safe_divide(
                pred_homography, pred_homography[:, 2, 2].view(-1, 1, 1), self._EPS
            )
            target_norm = self._safe_divide(
                target_homography,
                target_homography[:, 2, 2].view(-1, 1, 1),
                self._EPS,
            )

            # Matrix L2 error
            matrix_error = F.mse_loss(pred_norm, target_norm, reduction="none").mean(
                dim=(1, 2)
            )

            # Corner error (warp 4 corners of the image)
            corners = torch.tensor(
                [[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],
                dtype=pred_homography.dtype,
                device=pred_homography.device,
            ).T  # Shape: (3, 4)
            pred_corners = self._warp_points(pred_norm, corners)
            target_corners = self._warp_points(target_norm, corners)
            corner_delta = pred_corners - target_corners  # (B, 2, 4): rows are x, y
            corner_error = torch.mean(torch.norm(corner_delta, dim=1), dim=1)

            # Normalized coordinates span [-1, 1], i.e. one unit is half the
            # image extent along that axis.
            height, width = image_size
            half_extent = corner_delta.new_tensor([width / 2.0, height / 2.0]).view(
                1, 2, 1
            )
            corner_error_px = torch.mean(
                torch.norm(corner_delta * half_extent, dim=1), dim=1
            )

            return {
                "matrix_mse": matrix_error.mean().item(),
                "corner_error": corner_error.mean().item(),
                "corner_error_px": corner_error_px.mean().item(),
            }
def create_homography_model(
    model_type: str = "cnn",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Build a homography estimation model by name.

    Args:
        model_type: Model identifier. Only 'cnn' is handled here; any other
            value raises.
        input_size: Input image size (height, width). Accepted for API
            symmetry; this function does not read it.
        **kwargs: Forwarded unchanged to the model constructor.

    Returns:
        A ``HomographyCNN`` instance when ``model_type`` is 'cnn'.

    Raises:
        ValueError: If ``model_type`` is not a supported identifier.
    """
    # Reject unsupported identifiers up front.
    if model_type != "cnn":
        raise ValueError(f"Unknown model type: {model_type}")
    return HomographyCNN(**kwargs)
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass, one prediction,
    # one loss evaluation, the metrics and the factory function.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # --- Model construction -------------------------------------------------
    model_config = dict(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    model = HomographyCNN(**model_config).to(device)
    print(
        f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    # --- Dummy inputs -------------------------------------------------------
    batch_size = 4
    height, width = 256, 256
    image_shape = (batch_size, 3, height, width)
    google_img = torch.randn(*image_shape).to(device)
    yandex_img = torch.randn(*image_shape).to(device)

    # --- Forward pass -------------------------------------------------------
    print("\nTesting forward pass...")
    output = model(google_img, yandex_img, return_matrix=True)
    print(f"Output shape: {output.shape}")  # Should be (4, 3, 3)
    print(f"Sample output:\n{output[0]}")

    # --- Inference helper ---------------------------------------------------
    print("\nTesting prediction...")
    pred = model.predict_homography(google_img, yandex_img)
    print(f"Prediction shape: {pred.shape}")
    print(f"Last element (should be ~1): {pred[0, 2, 2]:.6f}")

    # --- Loss ---------------------------------------------------------------
    print("\nTesting loss function...")
    # Identity homography repeated for every sample in the batch.
    identity = torch.eye(3).unsqueeze(0)
    target_homography = identity.expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss(
        matrix_weight=1.0,
        geometric_weight=0.5,
        reg_weight=0.1,
        grid_size=8,
    )
    loss_fn = loss_fn.to(device)
    loss = loss_fn(output, target_homography, google_img, yandex_img)
    print(f"Loss value: {loss.item():.6f}")

    # --- Metrics ------------------------------------------------------------
    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target_homography)
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")

    # --- Factory ------------------------------------------------------------
    print("\nTesting model factory...")
    factory_kwargs = dict(input_channels=3, hidden_channels=32, num_blocks=3)
    model2 = create_homography_model(
        model_type="cnn",
        input_size=(256, 256),
        **factory_kwargs,
    ).to(device)
    print(
        f"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters"
    )

    print("\nAll tests completed successfully!")