""" Training script for homography estimation between Google and Yandex map images. This script trains a CNN model to estimate homography matrices that align Google map images with Yandex map images. """ import argparse import json import os import time from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np import torch import torch.nn as nn import torch.optim as optim from homography import HomographyDataset, create_data_loaders from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm class HomographyTrainer: """Trainer class for homography estimation model.""" def __init__( self, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, config: Dict, ): """ Initialize the trainer. Args: model: Homography estimation model train_loader: Training data loader val_loader: Validation data loader device: Device to run training on config: Training configuration dictionary """ self.model = model.to(device) self.train_loader = train_loader self.val_loader = val_loader self.device = device self.config = config # Loss function self.criterion = HomographyLoss( matrix_weight=config.get("matrix_weight", 1.0), geometric_weight=config.get("geometric_weight", 0.5), reg_weight=config.get("reg_weight", 0.1), grid_size=config.get("grid_size", 8), ).to(device) # Optimizer optimizer_name = config.get("optimizer", "adam").lower() lr = config.get("learning_rate", 1e-3) weight_decay = config.get("weight_decay", 1e-4) if optimizer_name == "adam": self.optimizer = optim.Adam( self.model.parameters(), lr=lr, weight_decay=weight_decay ) elif optimizer_name == "adamw": self.optimizer = optim.AdamW( self.model.parameters(), lr=lr, weight_decay=weight_decay ) elif optimizer_name == "sgd": self.optimizer = optim.SGD( self.model.parameters(), lr=lr, momentum=config.get("momentum", 0.9), weight_decay=weight_decay, ) else: raise ValueError(f"Unknown optimizer: {optimizer_name}") # Learning rate scheduler scheduler_name = config.get("scheduler", "plateau").lower() if scheduler_name == "plateau": self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode="min", factor=config.get("scheduler_factor", 0.5), patience=config.get("scheduler_patience", 5), verbose=True, ) elif scheduler_name == "cosine": self.scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=config.get("epochs", 100), eta_min=config.get("min_lr", 1e-6), ) elif scheduler_name == "step": self.scheduler = optim.lr_scheduler.StepLR( self.optimizer, step_size=config.get("step_size", 30), gamma=config.get("gamma", 0.1), ) else: self.scheduler = None # Training state self.current_epoch = 0 self.best_val_loss = float("inf") self.train_losses: List[float] = [] self.val_losses: List[float] = [] self.val_metrics: List[Dict] = [] # Create output directory self.output_dir = Path(config.get("output_dir", "runs/homography")) self.output_dir.mkdir(parents=True, exist_ok=True) # TensorBoard writer self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard") # Save configuration config_path = self.output_dir / "config.json" with open(config_path, "w") as f: json.dump(config, f, indent=2) print(f"Training configuration saved to {config_path}") print( f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters" ) def train_epoch(self) -> float: """ Train for one epoch. Returns: Average training loss for the epoch """ self.model.train() total_loss = 0.0 num_batches = len(self.train_loader) progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") for batch_idx, batch in enumerate(progress_bar): # Move data to device google_img = batch["google_img"].to(self.device) yandex_img = batch["yandex_img"].to(self.device) target_homography = batch["homography"].to(self.device) # Forward pass self.optimizer.zero_grad() pred_homography = self.model(google_img, yandex_img, return_matrix=True) # Compute loss loss = self.criterion( pred_homography, target_homography, google_img, yandex_img, ) # Backward pass loss.backward() # Gradient clipping if self.config.get("grad_clip", 1.0) > 0: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.config.get("grad_clip", 1.0), ) # Optimizer step self.optimizer.step() # Update statistics total_loss += loss.item() # Update progress bar progress_bar.set_postfix({"loss": loss.item()}) # Log batch loss to TensorBoard global_step = self.current_epoch * num_batches + batch_idx self.writer.add_scalar("train/batch_loss", loss.item(), global_step) avg_loss = total_loss / num_batches self.train_losses.append(avg_loss) return avg_loss @torch.no_grad() def validate(self) -> Tuple[float, Dict]: """ Validate the model. Returns: Tuple of (average validation loss, validation metrics) """ self.model.eval() total_loss = 0.0 all_metrics = [] progress_bar = tqdm(self.val_loader, desc="Validation") for batch in progress_bar: # Move data to device google_img = batch["google_img"].to(self.device) yandex_img = batch["yandex_img"].to(self.device) target_homography = batch["homography"].to(self.device) # Forward pass pred_homography = self.model(google_img, yandex_img, return_matrix=True) # Compute loss loss = self.criterion( pred_homography, target_homography, google_img, yandex_img, ) # Compute metrics metrics = self.criterion.compute_metrics(pred_homography, target_homography) # Update statistics total_loss += loss.item() all_metrics.append(metrics) # Update progress bar progress_bar.set_postfix({"loss": loss.item()}) avg_loss = total_loss / len(self.val_loader) self.val_losses.append(avg_loss) # Aggregate metrics avg_metrics = {} for key in all_metrics[0].keys(): avg_metrics[key] = np.mean([m[key] for m in all_metrics]) self.val_metrics.append(avg_metrics) return avg_loss, avg_metrics def save_checkpoint(self, is_best: bool = False): """Save model checkpoint.""" checkpoint = { "epoch": self.current_epoch, "model_state_dict": self.model.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), "train_losses": self.train_losses, "val_losses": self.val_losses, "val_metrics": self.val_metrics, "best_val_loss": self.best_val_loss, "config": self.config, } if self.scheduler is not None: checkpoint["scheduler_state_dict"] = self.scheduler.state_dict() # Save latest checkpoint checkpoint_path = self.output_dir / "checkpoint_latest.pth" torch.save(checkpoint, checkpoint_path) # Save best checkpoint if is_best: best_path = self.output_dir / "checkpoint_best.pth" torch.save(checkpoint, best_path) print(f"Best model saved to {best_path}") def load_checkpoint(self, checkpoint_path: str): """Load model checkpoint.""" checkpoint = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(checkpoint["model_state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) if self.scheduler is not None and "scheduler_state_dict" in checkpoint: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) self.current_epoch = checkpoint["epoch"] self.train_losses = checkpoint["train_losses"] self.val_losses = checkpoint["val_losses"] self.val_metrics = checkpoint["val_metrics"] self.best_val_loss = checkpoint["best_val_loss"] print(f"Loaded checkpoint from epoch {self.current_epoch}") def train(self, num_epochs: int): """ Train the model for specified number of epochs. Args: num_epochs: Number of epochs to train """ print(f"Starting training for {num_epochs} epochs...") start_time = time.time() for epoch in range(num_epochs): self.current_epoch = epoch # Train for one epoch train_loss = self.train_epoch() # Validate val_loss, val_metrics = self.validate() # Update learning rate scheduler if self.scheduler is not None: if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau): self.scheduler.step(val_loss) else: self.scheduler.step() # Log to TensorBoard self.writer.add_scalar("train/epoch_loss", train_loss, epoch) self.writer.add_scalar("val/epoch_loss", val_loss, epoch) for metric_name, metric_value in val_metrics.items(): self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch) # Print epoch summary print(f"\nEpoch {epoch + 1}/{num_epochs}:") print(f" Train Loss: {train_loss:.6f}") print(f" Val Loss: {val_loss:.6f}") print(" Val Metrics:") for metric_name, metric_value in val_metrics.items(): print(f" {metric_name}: {metric_value:.6f}") # Save checkpoint is_best = val_loss < self.best_val_loss if is_best: self.best_val_loss = val_loss self.save_checkpoint(is_best=is_best) # Early stopping if self.config.get("early_stopping_patience", 0) > 0: if ( epoch - np.argmin(self.val_losses) >= self.config["early_stopping_patience"] ): print(f"Early stopping at epoch {epoch + 1}") break # Training completed training_time = time.time() - start_time print(f"\nTraining completed in {training_time:.2f} seconds") print(f"Best validation loss: {self.best_val_loss:.6f}") # Save final model final_model_path = self.output_dir / "model_final.pth" torch.save(self.model.state_dict(), final_model_path) print(f"Final model saved to {final_model_path}") # Close TensorBoard writer self.writer.close() def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict: """ Evaluate the model on test data. Args: test_loader: Test data loader (uses validation loader if None) Returns: Dictionary of evaluation metrics """ if test_loader is None: test_loader = self.val_loader self.model.eval() all_metrics = [] print("Evaluating model...") with torch.no_grad(): for batch in tqdm(test_loader): # Move data to device google_img = batch["google_img"].to(self.device) yandex_img = batch["yandex_img"].to(self.device) target_homography = batch["homography"].to(self.device) # Forward pass pred_homography = self.model(google_img, yandex_img, return_matrix=True) # Compute metrics metrics = self.criterion.compute_metrics( pred_homography, target_homography ) all_metrics.append(metrics) # Aggregate metrics avg_metrics = {} for key in all_metrics[0].keys(): avg_metrics[key] = np.mean([m[key] for m in all_metrics]) # Print evaluation results print("\nEvaluation Results:") for metric_name, metric_value in avg_metrics.items(): print(f" {metric_name}: {metric_value:.6f}") # Save evaluation results eval_path = self.output_dir / "evaluation_results.json" with open(eval_path, "w") as f: json.dump(avg_metrics, f, indent=2) print(f"Evaluation results saved to {eval_path}") return avg_metrics def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Train homography estimation model") # Data arguments parser.add_argument( "--data_dir", type=str, default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", help="Directory containing image pairs", ) parser.add_argument( "--batch_size", type=int, default=32, help="Batch size for training" ) parser.add_argument( "--image_size", type=int, nargs=2, default=[256, 256], help="Image size (height width)", ) parser.add_argument( "--train_split", type=float, default=0.8, help="Train/validation split ratio" ) parser.add_argument( "--num_workers", type=int, default=4, help="Number of data loader workers" ) # Model arguments parser.add_argument( "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type" ) parser.add_argument( "--hidden_channels", type=int, default=64, help="Number of hidden channels" ) parser.add_argument( "--num_blocks", type=int, default=4, help="Number of convolutional blocks" ) parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate") parser.add_argument( "--use_batch_norm", action="store_true", help="Use batch normalization" ) # Training arguments parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay") parser.add_argument( "--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"] ) parser.add_argument( "--scheduler", type=str, default="plateau", choices=["plateau", "cosine", "step", "none"], ) parser.add_argument( "--grad_clip", type=float, default=1.0, help="Gradient clipping value" ) # Loss arguments parser.add_argument( "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss" ) parser.add_argument( "--geometric_weight", type=float, default=0.5, help="Weight for geometric loss", ) parser.add_argument( "--reg_weight", type=float, default=0.1, help="Weight for regularization loss" ) # Other arguments parser.add_argument( "--output_dir", type=str, default="runs/homography", help="Output directory for checkpoints and logs", ) parser.add_argument( "--resume", type=str, help="Path to checkpoint to resume training from", ) parser.add_argument( "--eval_only", action="store_true", help="Only evaluate the model (no training)", ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for reproducibility" ) return parser.parse_args() def main(): """Main training function.""" args = parse_args() # Set random seeds for reproducibility torch.manual_seed(args.seed) np.random.seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) # Set device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Create data loaders print("Creating data loaders...") train_loader, val_loader = create_data_loaders( root_dir=args.data_dir, batch_size=args.batch_size, train_split=args.train_split, num_workers=args.num_workers, image_size=tuple(args.image_size), augment_train=True, augment_val=False, ) print(f"Train batches: {len(train_loader)}") print(f"Val batches: {len(val_loader)}") # Create model print("Creating model...") model = create_homography_model( model_type=args.model_type, input_size=tuple(args.image_size), input_channels=3, hidden_channels=args.hidden_channels, num_blocks=args.num_blocks, dropout_rate=args.dropout_rate, use_batch_norm=args.use_batch_norm, ) # Create trainer configuration config = { # Model config "model_type": args.model_type, "hidden_channels": args.hidden_channels, "num_blocks": args.num_blocks, "dropout_rate": args.dropout_rate, "use_batch_norm": args.use_batch_norm, "image_size": args.image_size, # Training config "epochs": args.epochs, "batch_size": args.batch_size, "learning_rate": args.lr, "weight_decay": args.weight_decay, "optimizer": args.optimizer, "scheduler": args.scheduler, "grad_clip": args.grad_clip, # Loss config "matrix_weight": args.matrix_weight, "geometric_weight": args.geometric_weight, "reg_weight": args.reg_weight, "grid_size": 8, # Data config "data_dir": args.data_dir, "train_split": args.train_split, "num_workers": args.num_workers, # Output config "output_dir": args.output_dir, "seed": args.seed, } # Create trainer trainer = HomographyTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) # Resume from checkpoint if specified if args.resume: print(f"Resuming from checkpoint: {args.resume}") trainer.load_checkpoint(args.resume) # Evaluate only mode if args.eval_only: trainer.evaluate() return # Train the model trainer.train(num_epochs=args.epochs) # Final evaluation print("\nPerforming final evaluation...") trainer.evaluate() print("\nTraining completed successfully!") print(f"Results saved to: {args.output_dir}") if __name__ == "__main__": main()