from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Smallest magnitude a homogeneous scale factor may have before we divide by it.
_EPS = 1e-8


def _safe_denominator(denom: torch.Tensor, eps: float = _EPS) -> torch.Tensor:
    """
    Return ``denom`` with near-zero entries pushed away from zero.

    Entries whose magnitude is below ``eps`` are replaced by ``+eps`` or
    ``-eps`` according to their sign (zero maps to ``+eps``). All other entries
    are returned unchanged. Unlike ``denom + eps`` this also protects small
    negative values, which an additive epsilon would move *towards* zero.
    """
    floor = torch.full_like(denom, eps)
    signed_floor = torch.where(denom < 0, -floor, floor)
    return torch.where(denom.abs() < eps, signed_floor, denom)


def _normalize_homography(homography: torch.Tensor) -> torch.Tensor:
    """Scale each (3, 3) matrix in a (B, 3, 3) batch by its bottom-right element."""
    scale = _safe_denominator(homography[:, 2, 2]).view(-1, 1, 1)
    return homography / scale


def _dehomogenize(points: torch.Tensor) -> torch.Tensor:
    """Divide a (B, 3, N) batch of homogeneous points by its third row."""
    return points / _safe_denominator(points[:, 2:3, :])


class HomographyCNN(nn.Module):
    """
    CNN model for homography estimation between two images.

    This model takes two images (Google and Yandex maps) as input and
    outputs a 3x3 homography matrix that transforms one image to align
    with the other.
    """

    def __init__(
        self,
        input_channels: int = 3,
        hidden_channels: int = 64,
        num_blocks: int = 4,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
        output_size: int = 9,  # Flattened 3x3 homography matrix
    ):
        """
        Initialize the HomographyCNN model.

        Args:
            input_channels: Number of input channels per image (3 for RGB)
            hidden_channels: Base number of channels in the network
            num_blocks: Number of residual blocks in each encoder
            dropout_rate: Dropout rate for regularization
            use_batch_norm: Whether to use batch normalization
            output_size: Size of output vector (9 for flattened 3x3 matrix)
        """
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # Feature extraction for each image separately
        self.google_encoder = self._build_encoder()
        self.yandex_encoder = self._build_encoder()

        # Fusion layers to combine features from both images
        self.fusion_layers = self._build_fusion_layers()

        # Regression head for homography estimation
        self.regression_head = self._build_regression_head(output_size)

        # Initialize weights
        self._initialize_weights()

    def _encoder_output_channels(self) -> int:
        """
        Return the number of channels produced by one encoder.

        The channel count doubles in each of the first two residual blocks and
        stays constant afterwards, so the result depends on ``num_blocks``.
        """
        doublings = max(0, min(self.num_blocks, 2))
        return self.hidden_channels * (2**doublings)

    def _build_encoder(self) -> nn.Module:
        """Build the encoder network for a single image."""
        layers = []
        out_channels = self.hidden_channels

        # Stem: strided 7x7 convolution followed by max pooling
        layers.append(
            nn.Conv2d(
                self.input_channels, out_channels, kernel_size=7, stride=2, padding=3
            )
        )
        if self.use_batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Residual blocks: the first keeps the resolution, the rest downsample.
        # Channels double in the first two blocks only.
        for i in range(self.num_blocks):
            block_out_channels = out_channels * 2 if i < 2 else out_channels
            layers.append(
                ResidualBlock(
                    in_channels=out_channels,
                    out_channels=block_out_channels,
                    stride=1 if i == 0 else 2,
                    dropout_rate=self.dropout_rate,
                    use_batch_norm=self.use_batch_norm,
                )
            )
            out_channels = block_out_channels

        return nn.Sequential(*layers)

    def _build_fusion_layers(self) -> nn.Module:
        """Build layers to fuse features from both images."""
        # Both encoders are concatenated along the channel axis, so the fusion
        # input is twice the encoder width. Deriving it from the encoder
        # configuration keeps the model valid for num_blocks < 2 as well.
        fused_channels = 2 * self._encoder_output_channels()

        layers = [
            # Reduce dimensionality
            nn.Conv2d(
                fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
            ),
            nn.BatchNorm2d(self.hidden_channels * 4)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Further processing
            nn.Conv2d(
                self.hidden_channels * 4,
                self.hidden_channels * 2,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(self.hidden_channels * 2)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Global pooling
            nn.AdaptiveAvgPool2d((1, 1)),
        ]

        return nn.Sequential(*layers)

    def _build_regression_head(self, output_size: int) -> nn.Module:
        """Build the regression head for homography estimation."""
        # Input size after fusion and global pooling
        input_features = self.hidden_channels * 2

        layers = [
            nn.Flatten(),
            nn.Linear(input_features, 512),
            nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(128, output_size),
        ]

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialize model weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # With near-zero weights and a zero bias the initial prediction has
        # H[2, 2] ~ 0, which makes the normalisation in forward() blow up.
        # Starting the output layer at the identity homography avoids that.
        final_layer = self.regression_head[-1]
        if isinstance(final_layer, nn.Linear) and final_layer.out_features == 9:
            with torch.no_grad():
                final_layer.bias.copy_(torch.eye(3).flatten())

    def forward(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        return_matrix: bool = True,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            google_img: Google map image tensor of shape (B, C, H, W)
            yandex_img: Yandex map image tensor of shape (B, C, H, W)
            return_matrix: If True, return 3x3 matrix; if False, return
                the raw flattened vector

        Returns:
            Homography matrix tensor of shape (B, 3, 3), divided by its
            bottom-right element, or the unnormalised flattened vector of
            shape (B, output_size)
        """
        # Extract features from both images
        google_features = self.google_encoder(google_img)
        yandex_features = self.yandex_encoder(yandex_img)

        # Concatenate features along channel dimension
        combined_features = torch.cat([google_features, yandex_features], dim=1)

        # Fuse features
        fused_features = self.fusion_layers(combined_features)

        # Regression to get homography parameters
        homography_flat = self.regression_head(fused_features)

        if not return_matrix:
            return homography_flat

        # Reshape to 3x3 matrix and normalise so the last element is 1
        # (homogeneous coordinate normalisation). The division is guarded
        # against near-zero scale factors of either sign.
        batch_size = homography_flat.shape[0]
        homography_matrix = homography_flat.view(batch_size, 3, 3)
        return _normalize_homography(homography_matrix)

    def predict_homography(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix in evaluation mode without gradients.

        The module is switched to eval mode for the prediction and its
        previous training/eval mode is restored afterwards.

        Args:
            google_img: Google map image tensor
            yandex_img: Yandex map image tensor
            normalize: Whether to normalize the homography matrix. Note that
                forward() already normalises its matrix output, so this is
                an additional (idempotent) normalisation.

        Returns:
            Predicted homography matrix of shape (B, 3, 3)
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                homography = self.forward(google_img, yandex_img, return_matrix=True)

                if normalize:
                    # Normalize so that last element is 1
                    homography = _normalize_homography(homography)
        finally:
            # Do not leave a model that was training stuck in eval mode.
            self.train(was_training)

        return homography


class ResidualBlock(nn.Module):
    """Residual block with optional downsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        """
        Initialize the residual block.

        Args:
            in_channels: Number of channels of the input feature map
            out_channels: Number of channels of the output feature map
            stride: Stride of the first convolution (and of the shortcut)
            dropout_rate: Channel dropout rate; dropout is skipped when 0
            use_batch_norm: Whether to use batch normalization
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = (
            nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
        )

        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = (
            nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
        )

        # Shortcut connection: identity unless the shape changes, in which
        # case a 1x1 convolution projects the input to the output shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv-bn-relu-dropout twice with a residual connection."""
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Residual sum happens before the second activation.
        out += identity
        out = self.relu2(out)
        out = self.dropout2(out)

        return out


class HomographyLoss(nn.Module):
    """
    Custom loss function for homography estimation.

    Combines multiple loss terms:
    1. Matrix element-wise L2 loss
    2. Geometric consistency loss (warping error)
    3. Determinant regularization (to prevent degenerate matrices)
    """

    def __init__(
        self,
        matrix_weight: float = 1.0,
        geometric_weight: float = 0.5,
        reg_weight: float = 0.1,
        grid_size: int = 8,
    ):
        """
        Initialize the loss.

        Args:
            matrix_weight: Weight of the element-wise matrix loss
            geometric_weight: Weight of the grid warping loss
            reg_weight: Weight of the determinant regularization
            grid_size: Number of grid points per axis for the warping loss
        """
        super().__init__()

        self.matrix_weight = matrix_weight
        self.geometric_weight = geometric_weight
        self.reg_weight = reg_weight
        self.grid_size = grid_size

        # Create grid of points for geometric loss
        self.register_buffer(
            "grid_points",
            self._create_grid_points(grid_size),
            persistent=False,
        )

    def _create_grid_points(self, grid_size: int) -> torch.Tensor:
        """Create a grid of points for geometric consistency loss."""
        x = torch.linspace(-1, 1, grid_size)
        y = torch.linspace(-1, 1, grid_size)
        grid_y, grid_x = torch.meshgrid(y, x, indexing="ij")
        grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)

        # Add homogeneous coordinate
        ones = torch.ones(grid_points.shape[0], 1)
        grid_points = torch.cat([grid_points, ones], dim=1)

        return grid_points.T  # Shape: (3, grid_size*grid_size)

    def forward(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        google_img: Optional[torch.Tensor] = None,
        yandex_img: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute homography loss.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3)
            google_img: Google images (optional). The tensor is not read; the
                geometric term is only enabled when both images are given.
            yandex_img: Yandex images (optional), see ``google_img``

        Returns:
            Combined loss value
        """
        batch_size = pred_homography.shape[0]

        # 1. Matrix element-wise L2 loss
        matrix_loss = F.mse_loss(pred_homography, target_homography)

        # 2. Geometric consistency loss (if images provided)
        geometric_loss = pred_homography.new_zeros(())
        if google_img is not None and yandex_img is not None:
            # Match the prediction's dtype/device so bmm does not fail when
            # the loss module and the prediction were set up differently.
            grid_points = self.grid_points.to(
                device=pred_homography.device, dtype=pred_homography.dtype
            )
            grid_points = grid_points.unsqueeze(0).expand(batch_size, -1, -1)

            # Warp the grid with both homographies and normalise the
            # homogeneous coordinate (guarded against near-zero w).
            warped_points = _dehomogenize(torch.bmm(pred_homography, grid_points))
            target_warped_points = _dehomogenize(
                torch.bmm(target_homography, grid_points)
            )

            # Compute point-wise distance
            geometric_loss = F.mse_loss(
                warped_points[:, :2, :], target_warped_points[:, :2, :]
            )

        # 3. Regularization loss (prevent degenerate matrices)
        # Encourage determinant to be close to 1
        pred_det = torch.det(pred_homography)
        reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))

        # Combine losses
        total_loss = (
            self.matrix_weight * matrix_loss
            + self.geometric_weight * geometric_loss
            + self.reg_weight * reg_loss
        )

        return total_loss

    def compute_metrics(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        image_size: int = 256,
    ) -> dict:
        """
        Compute evaluation metrics for homography estimation.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3)
            image_size: Side length in pixels used to convert the corner
                error from normalised [-1, 1] coordinates to pixels

        Returns:
            Dictionary with ``matrix_mse``, ``corner_error`` (normalised
            coordinates) and ``corner_error_px``
        """
        with torch.no_grad():
            # Normalize matrices
            pred_norm = _normalize_homography(pred_homography)
            target_norm = _normalize_homography(target_homography)

            # Matrix L2 error
            matrix_error = F.mse_loss(pred_norm, target_norm, reduction="none").mean(
                dim=(1, 2)
            )

            # Corner error (warp 4 corners of the image)
            corners = torch.tensor(
                [[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],
                dtype=pred_norm.dtype,
                device=pred_homography.device,
            ).T  # Shape: (3, 4)
            corners = corners.unsqueeze(0).expand(pred_homography.shape[0], -1, -1)

            pred_corners = _dehomogenize(torch.bmm(pred_norm, corners))
            target_corners = _dehomogenize(torch.bmm(target_norm, corners))

            corner_error = torch.mean(
                torch.norm(pred_corners[:, :2, :] - target_corners[:, :2, :], dim=1),
                dim=1,
            )

            # Average corner error in normalised coordinates
            avg_corner_error = corner_error.mean().item()

            return {
                "matrix_mse": matrix_error.mean().item(),
                "corner_error": avg_corner_error,
                # [-1, 1] spans image_size pixels, so one unit is half of it.
                "corner_error_px": avg_corner_error * image_size / 2,
            }


def create_homography_model(
    model_type: str = "cnn",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Factory function to create homography estimation model.

    Args:
        model_type: Type of model to create; only 'cnn' is supported
        input_size: Input image size (height, width). Currently unused;
            accepted for interface compatibility.
        **kwargs: Additional arguments passed to model constructor

    Returns:
        Homography estimation model

    Raises:
        ValueError: If ``model_type`` is not a supported model type
    """
    if model_type == "cnn":
        return HomographyCNN(**kwargs)
    raise ValueError(f"Unknown model type: {model_type}")


def main() -> None:
    """Run a smoke test of the model, the loss, the metrics and the factory."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create model
    model = HomographyCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)

    print(
        f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    # Create dummy input
    batch_size = 4
    height, width = 256, 256
    google_img = torch.randn(batch_size, 3, height, width).to(device)
    yandex_img = torch.randn(batch_size, 3, height, width).to(device)

    # Test forward pass
    print("\nTesting forward pass...")
    output = model(google_img, yandex_img, return_matrix=True)
    print(f"Output shape: {output.shape}")  # Should be (4, 3, 3)
    print(f"Sample output:\n{output[0]}")

    # Test prediction
    print("\nTesting prediction...")
    pred = model.predict_homography(google_img, yandex_img)
    print(f"Prediction shape: {pred.shape}")
    print(f"Last element (should be ~1): {pred[0, 2, 2]:.6f}")

    # Test loss function
    print("\nTesting loss function...")
    target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss(
        matrix_weight=1.0,
        geometric_weight=0.5,
        reg_weight=0.1,
        grid_size=8,
    ).to(device)

    loss = loss_fn(output, target_homography, google_img, yandex_img)
    print(f"Loss value: {loss.item():.6f}")

    # Test metrics
    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target_homography)
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")

    # Test model factory
    print("\nTesting model factory...")
    model2 = create_homography_model(
        model_type="cnn",
        input_size=(256, 256),
        input_channels=3,
        hidden_channels=32,
        num_blocks=3,
    ).to(device)
    print(
        f"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters"
    )

    print("\nAll tests completed successfully!")


if __name__ == "__main__":
    main()