# File: autopilot/models/SiaN/train_homography.py
# Repository snapshot: 2026-02-16 19:07:31 +03:00 (612 lines, 20 KiB, Python)
"""
Training script for homography estimation between Google and Yandex map images.
This script trains a CNN model to estimate homography matrices that align
Google map images with Yandex map images.
"""
import argparse
import json
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from homography import HomographyDataset, create_data_loaders
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class HomographyTrainer:
    """Trainer class for homography estimation model.

    Bundles the model, optimizer, optional learning-rate scheduler,
    ``HomographyLoss`` criterion, TensorBoard logging and checkpointing.
    Batches yielded by the loaders are expected to be dicts with
    ``"google_img"``, ``"yandex_img"`` and ``"homography"`` tensors.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict,
    ):
        """
        Initialize the trainer.

        Args:
            model: Homography estimation model
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary. All keys are optional:
                loss ("matrix_weight", "geometric_weight", "reg_weight",
                "grid_size"), optimizer ("optimizer", "learning_rate",
                "weight_decay", "momentum"), scheduler ("scheduler",
                "scheduler_factor", "scheduler_patience", "epochs", "min_lr",
                "step_size", "gamma"), "grad_clip",
                "early_stopping_patience" and "output_dir".

        Raises:
            ValueError: If ``config["optimizer"]`` is not a known optimizer.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Loss function
        self.criterion = HomographyLoss(
            matrix_weight=config.get("matrix_weight", 1.0),
            geometric_weight=config.get("geometric_weight", 0.5),
            reg_weight=config.get("reg_weight", 0.1),
            grid_size=config.get("grid_size", 8),
        ).to(device)

        self.optimizer = self._build_optimizer()
        self.scheduler = self._build_scheduler()

        # Training state
        self.current_epoch = 0
        # First epoch index train() will run. load_checkpoint() advances it so
        # a resumed run continues after the checkpointed epoch instead of
        # restarting from epoch 0.
        self.start_epoch = 0
        self.best_val_loss = float("inf")
        self.train_losses: List[float] = []
        self.val_losses: List[float] = []
        self.val_metrics: List[Dict] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/homography"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=str(self.output_dir / "tensorboard"))

        # Save configuration
        config_path = self.output_dir / "config.json"
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)
        print(f"Training configuration saved to {config_path}")
        print(
            f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
        )

    def _build_optimizer(self) -> optim.Optimizer:
        """Create the optimizer named by ``config["optimizer"]`` (default "adam").

        Raises:
            ValueError: If the name is not "adam", "adamw" or "sgd".
        """
        config = self.config
        optimizer_name = config.get("optimizer", "adam").lower()
        lr = config.get("learning_rate", 1e-3)
        weight_decay = config.get("weight_decay", 1e-4)
        params = self.model.parameters()
        if optimizer_name == "adam":
            return optim.Adam(params, lr=lr, weight_decay=weight_decay)
        if optimizer_name == "adamw":
            return optim.AdamW(params, lr=lr, weight_decay=weight_decay)
        if optimizer_name == "sgd":
            return optim.SGD(
                params,
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=weight_decay,
            )
        raise ValueError(f"Unknown optimizer: {optimizer_name}")

    def _build_scheduler(self):
        """Create the LR scheduler named by ``config["scheduler"]``.

        Returns ``None`` for any name other than "plateau", "cosine" or "step"
        (e.g. "none"), which disables scheduling.
        """
        config = self.config
        scheduler_name = config.get("scheduler", "plateau").lower()
        if scheduler_name == "plateau":
            # ``verbose`` is deprecated for LR schedulers in recent PyTorch, so
            # it is not passed. train() prints and logs the current LR instead.
            return optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=config.get("scheduler_factor", 0.5),
                patience=config.get("scheduler_patience", 5),
            )
        if scheduler_name == "cosine":
            return optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config.get("epochs", 100),
                eta_min=config.get("min_lr", 1e-6),
            )
        if scheduler_name == "step":
            return optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=config.get("step_size", 30),
                gamma=config.get("gamma", 0.1),
            )
        return None

    def _forward(
        self, batch: Dict
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Move one batch to ``self.device`` and run the model on it.

        Returns:
            Tuple of (pred_homography, target_homography, google_img, yandex_img).
        """
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target_homography = batch["homography"].to(self.device)
        pred_homography = self.model(google_img, yandex_img, return_matrix=True)
        return pred_homography, target_homography, google_img, yandex_img

    @staticmethod
    def _average_metrics(all_metrics: List[Dict]) -> Dict[str, float]:
        """Average a list of per-batch metric dicts key by key.

        Keys are taken from the first dict. Every value is converted with
        ``float()``, so NumPy scalars and single-element tensors are accepted.
        The result holds plain Python floats, which are JSON-serialisable and
        can be reloaded by ``torch.load`` with ``weights_only=True``. An empty
        input returns an empty dict.
        """
        if not all_metrics:
            return {}
        return {
            key: float(np.mean([float(m[key]) for m in all_metrics]))
            for key in all_metrics[0]
        }

    def train_epoch(self) -> float:
        """
        Train for one epoch.

        Returns:
            Average training loss for the epoch (mean of per-batch losses)

        Raises:
            RuntimeError: If the training loader has no batches.
        """
        self.model.train()
        num_batches = len(self.train_loader)
        if num_batches == 0:
            raise RuntimeError(
                "Training loader yielded no batches; check the dataset and train_split"
            )
        total_loss = 0.0
        grad_clip = self.config.get("grad_clip", 1.0)
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in enumerate(progress_bar):
            self.optimizer.zero_grad()
            pred_homography, target_homography, google_img, yandex_img = (
                self._forward(batch)
            )
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )
            loss.backward()
            # Gradient-norm clipping; a value <= 0 disables it.
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
            self.optimizer.step()

            # .item() synchronises with the device, so read the value once.
            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({"loss": loss_value})
            global_step = self.current_epoch * num_batches + batch_idx
            self.writer.add_scalar("train/batch_loss", loss_value, global_step)

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)
        return avg_loss

    @torch.no_grad()
    def validate(self) -> Tuple[float, Dict]:
        """
        Validate the model.

        Returns:
            Tuple of (average validation loss, validation metrics averaged over
            batches)

        Raises:
            RuntimeError: If the validation loader has no batches.
        """
        self.model.eval()
        num_batches = len(self.val_loader)
        if num_batches == 0:
            raise RuntimeError(
                "Validation loader yielded no batches; check the dataset and train_split"
            )
        total_loss = 0.0
        all_metrics = []
        progress_bar = tqdm(self.val_loader, desc="Validation")
        for batch in progress_bar:
            pred_homography, target_homography, google_img, yandex_img = (
                self._forward(batch)
            )
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )
            all_metrics.append(
                self.criterion.compute_metrics(pred_homography, target_homography)
            )
            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({"loss": loss_value})

        avg_loss = total_loss / num_batches
        self.val_losses.append(avg_loss)
        avg_metrics = self._average_metrics(all_metrics)
        self.val_metrics.append(avg_metrics)
        return avg_loss, avg_metrics

    def save_checkpoint(self, is_best: bool = False):
        """Save model checkpoint.

        Always writes ``checkpoint_latest.pth``. When ``is_best`` is true, it
        also writes ``checkpoint_best.pth`` with the same contents.
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "val_metrics": self.val_metrics,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }
        if self.scheduler is not None:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()

        # Save latest checkpoint
        checkpoint_path = self.output_dir / "checkpoint_latest.pth"
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.output_dir / "checkpoint_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint and restore training state.

        Restores the model, the optimizer, the scheduler state (when present),
        the loss and metric history, and the best validation loss. It also sets
        ``start_epoch`` so the next ``train()`` call continues after the
        checkpointed epoch.

        NOTE: ``torch.load`` defaults to ``weights_only=True`` from PyTorch 2.6.
        That mode only unpickles tensors and plain Python containers/scalars,
        which is why metrics are stored as Python floats. Checkpoints written
        by older versions of this script may hold NumPy scalars in
        ``val_metrics`` and may fail to load on recent PyTorch.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.current_epoch = checkpoint["epoch"]
        # The stored epoch has already been completed; resume with the next one.
        self.start_epoch = self.current_epoch + 1
        self.train_losses = checkpoint["train_losses"]
        self.val_losses = checkpoint["val_losses"]
        self.val_metrics = checkpoint["val_metrics"]
        self.best_val_loss = checkpoint["best_val_loss"]
        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(self, num_epochs: int):
        """
        Train the model up to ``num_epochs`` total epochs.

        A fresh trainer runs epochs ``0 .. num_epochs - 1``. After
        ``load_checkpoint()``, training starts at the epoch after the
        checkpointed one, so ``num_epochs`` is the total epoch budget rather
        than the number of additional epochs.

        Each epoch trains, validates, steps the scheduler, logs to TensorBoard
        and writes a checkpoint. Early stopping is applied when
        ``config["early_stopping_patience"]`` > 0. The final weights go to
        ``model_final.pth``, and the TensorBoard writer is closed even if
        training raises.

        Args:
            num_epochs: Total number of epochs to train
        """
        print(f"Starting training for {num_epochs} epochs...")
        if self.start_epoch > 0:
            print(f"Resuming at epoch {self.start_epoch + 1}")
        start_time = time.time()
        patience = self.config.get("early_stopping_patience", 0)

        try:
            for epoch in range(self.start_epoch, num_epochs):
                self.current_epoch = epoch
                train_loss = self.train_epoch()
                val_loss, val_metrics = self.validate()

                # Update learning rate scheduler
                if self.scheduler is not None:
                    if isinstance(
                        self.scheduler, optim.lr_scheduler.ReduceLROnPlateau
                    ):
                        self.scheduler.step(val_loss)
                    else:
                        self.scheduler.step()
                current_lr = self.optimizer.param_groups[0]["lr"]

                # Log to TensorBoard
                self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
                self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
                self.writer.add_scalar("train/lr", current_lr, epoch)
                for metric_name, metric_value in val_metrics.items():
                    self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)

                # Print epoch summary
                print(f"\nEpoch {epoch + 1}/{num_epochs}:")
                print(f" Train Loss: {train_loss:.6f}")
                print(f" Val Loss: {val_loss:.6f}")
                print(f" Learning Rate: {current_lr:.3e}")
                print(" Val Metrics:")
                for metric_name, metric_value in val_metrics.items():
                    print(f" {metric_name}: {metric_value:.6f}")

                # Save checkpoint
                is_best = val_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_loss
                self.save_checkpoint(is_best=is_best)
                self.start_epoch = epoch + 1

                # Early stopping: val_losses[i] belongs to epoch i, including
                # after a resume, so the difference counts epochs without
                # improvement.
                if patience > 0:
                    best_epoch = int(np.argmin(self.val_losses))
                    if epoch - best_epoch >= patience:
                        print(f"Early stopping at epoch {epoch + 1}")
                        break

            # Training completed
            training_time = time.time() - start_time
            print(f"\nTraining completed in {training_time:.2f} seconds")
            print(f"Best validation loss: {self.best_val_loss:.6f}")

            # Save final model (last-epoch weights; best ones are in
            # checkpoint_best.pth)
            final_model_path = self.output_dir / "model_final.pth"
            torch.save(self.model.state_dict(), final_model_path)
            print(f"Final model saved to {final_model_path}")
        finally:
            # Close TensorBoard writer even if training raised
            self.writer.close()

    def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
        """
        Evaluate the model on test data.

        Args:
            test_loader: Test data loader (uses validation loader if None)

        Returns:
            Dictionary of evaluation metrics averaged over batches. The same
            dict is written to ``evaluation_results.json`` in the output
            directory.

        Raises:
            RuntimeError: If the loader yields no batches.
        """
        if test_loader is None:
            test_loader = self.val_loader
        self.model.eval()
        all_metrics = []
        print("Evaluating model...")
        with torch.no_grad():
            for batch in tqdm(test_loader):
                pred_homography, target_homography, _, _ = self._forward(batch)
                all_metrics.append(
                    self.criterion.compute_metrics(pred_homography, target_homography)
                )
        if not all_metrics:
            raise RuntimeError("Evaluation loader yielded no batches")

        avg_metrics = self._average_metrics(all_metrics)

        # Print evaluation results
        print("\nEvaluation Results:")
        for metric_name, metric_value in avg_metrics.items():
            print(f" {metric_name}: {metric_value:.6f}")

        # Save evaluation results
        eval_path = self.output_dir / "evaluation_results.json"
        with open(eval_path, "w") as f:
            json.dump(avg_metrics, f, indent=2)
        print(f"Evaluation results saved to {eval_path}")
        return avg_metrics
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]`` as before. Passing an explicit
            list allows programmatic use and testing.

    Returns:
        Namespace with data, model, training, loss and output options.
    """
    parser = argparse.ArgumentParser(description="Train homography estimation model")

    # Data arguments
    parser.add_argument(
        "--data_dir",
        type=str,
        # NOTE(review): machine-specific default path; pass --data_dir elsewhere.
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        help="Directory containing image pairs",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        nargs=2,
        default=[256, 256],
        help="Image size (height width)",
    )
    parser.add_argument(
        "--train_split", type=float, default=0.8, help="Train/validation split ratio"
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loader workers"
    )

    # Model arguments
    parser.add_argument(
        "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
    )
    parser.add_argument(
        "--hidden_channels", type=int, default=64, help="Number of hidden channels"
    )
    parser.add_argument(
        "--num_blocks", type=int, default=4, help="Number of convolutional blocks"
    )
    parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
    parser.add_argument(
        "--use_batch_norm", action="store_true", help="Use batch normalization"
    )

    # Training arguments
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["adam", "adamw", "sgd"],
        help="Optimizer",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        default="plateau",
        choices=["plateau", "cosine", "step", "none"],
        help="Learning rate scheduler ('none' disables scheduling)",
    )
    parser.add_argument(
        "--grad_clip",
        type=float,
        default=1.0,
        help="Max gradient norm for clipping (<= 0 disables clipping)",
    )

    # Loss arguments
    parser.add_argument(
        "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
    )
    parser.add_argument(
        "--geometric_weight",
        type=float,
        default=0.5,
        help="Weight for geometric loss",
    )
    parser.add_argument(
        "--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
    )

    # Other arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="runs/homography",
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Path to checkpoint to resume training from",
    )
    parser.add_argument(
        "--eval_only",
        action="store_true",
        help="Only evaluate the model (no training)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    return parser.parse_args(argv)
def main():
    """Main training function.

    Steps:
    1. Parse CLI arguments and seed the RNGs.
    2. Build the data loaders, the model and a ``HomographyTrainer``.
    3. Resume from ``--resume`` if given.
    4. Either evaluate only (``--eval_only``), or train for ``--epochs``
       epochs and then evaluate the last-epoch weights.
    """
    import random  # local import: only needed for seeding below

    args = parse_args()

    # Set random seeds for reproducibility. torch.manual_seed seeds every
    # device, CUDA included. Python's ``random`` is seeded too in case data
    # loading/augmentation relies on it (not visible from this file).
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create data loaders
    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        root_dir=args.data_dir,
        batch_size=args.batch_size,
        train_split=args.train_split,
        num_workers=args.num_workers,
        image_size=tuple(args.image_size),
        augment_train=True,
        augment_val=False,
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Create model
    print("Creating model...")
    model = create_homography_model(
        model_type=args.model_type,
        input_size=tuple(args.image_size),
        input_channels=3,
        hidden_channels=args.hidden_channels,
        num_blocks=args.num_blocks,
        dropout_rate=args.dropout_rate,
        use_batch_norm=args.use_batch_norm,
    )

    # Create trainer configuration (also written to <output_dir>/config.json)
    config = {
        # Model config
        "model_type": args.model_type,
        "hidden_channels": args.hidden_channels,
        "num_blocks": args.num_blocks,
        "dropout_rate": args.dropout_rate,
        "use_batch_norm": args.use_batch_norm,
        "image_size": args.image_size,
        # Training config
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "optimizer": args.optimizer,
        "scheduler": args.scheduler,
        "grad_clip": args.grad_clip,
        # Loss config
        "matrix_weight": args.matrix_weight,
        "geometric_weight": args.geometric_weight,
        "reg_weight": args.reg_weight,
        "grid_size": 8,
        # Data config
        "data_dir": args.data_dir,
        "train_split": args.train_split,
        "num_workers": args.num_workers,
        # Output config
        "output_dir": args.output_dir,
        "seed": args.seed,
    }

    # Create trainer
    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    # Resume from checkpoint if specified
    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        trainer.load_checkpoint(args.resume)

    # Evaluate only mode
    if args.eval_only:
        if not args.resume:
            # Without a checkpoint the model still has its initial weights.
            print(
                "Warning: --eval_only without --resume evaluates an untrained "
                "(randomly initialized) model"
            )
        trainer.evaluate()
        return

    # Train the model
    trainer.train(num_epochs=args.epochs)

    # Final evaluation (uses the in-memory last-epoch weights, not
    # checkpoint_best.pth)
    print("\nPerforming final evaluation...")
    trainer.evaluate()
    print("\nTraining completed successfully!")
    print(f"Results saved to: {args.output_dir}")
# Run the training CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()