612 lines
20 KiB
Python
612 lines
20 KiB
Python
"""
Training script for homography estimation between Google and Yandex map images.

This script trains a CNN model to estimate the homography matrices that align
Google map images with the corresponding Yandex map images.
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from homography import HomographyDataset, create_data_loaders
|
|
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
|
|
class HomographyTrainer:
    """Trainer class for homography estimation model.

    Bundles the model, the ``HomographyLoss`` criterion, the optimizer, an
    optional learning-rate scheduler, TensorBoard logging and checkpointing.
    Batches from the loaders are dicts with ``"google_img"``, ``"yandex_img"``
    and ``"homography"`` entries that support ``.to(device)``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict,
    ):
        """
        Initialize the trainer.

        Args:
            model: Homography estimation model
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary. It is written to
                ``<output_dir>/config.json``, so it must be JSON serializable.

        Raises:
            ValueError: If ``config["optimizer"]`` is not "adam", "adamw"
                or "sgd".
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Loss function
        self.criterion = HomographyLoss(
            matrix_weight=config.get("matrix_weight", 1.0),
            geometric_weight=config.get("geometric_weight", 0.5),
            reg_weight=config.get("reg_weight", 0.1),
            grid_size=config.get("grid_size", 8),
        ).to(device)

        # Optimizer
        optimizer_name = config.get("optimizer", "adam").lower()
        lr = config.get("learning_rate", 1e-3)
        weight_decay = config.get("weight_decay", 1e-4)

        if optimizer_name == "adam":
            self.optimizer = optim.Adam(
                self.model.parameters(), lr=lr, weight_decay=weight_decay
            )
        elif optimizer_name == "adamw":
            self.optimizer = optim.AdamW(
                self.model.parameters(), lr=lr, weight_decay=weight_decay
            )
        elif optimizer_name == "sgd":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=weight_decay,
            )
        else:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        # Learning rate scheduler. Any other name (e.g. "none") disables it.
        scheduler_name = config.get("scheduler", "plateau").lower()
        if scheduler_name == "plateau":
            # ``verbose`` is deprecated for PyTorch LR schedulers; train()
            # prints and logs the current learning rate every epoch instead.
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=config.get("scheduler_factor", 0.5),
                patience=config.get("scheduler_patience", 5),
            )
        elif scheduler_name == "cosine":
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config.get("epochs", 100),
                eta_min=config.get("min_lr", 1e-6),
            )
        elif scheduler_name == "step":
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=config.get("step_size", 30),
                gamma=config.get("gamma", 0.1),
            )
        else:
            self.scheduler = None

        # Training state
        self.current_epoch = 0
        # First epoch index that train() runs. load_checkpoint() advances it
        # so a resumed run continues after the checkpointed epoch instead of
        # restarting at epoch 0.
        self.start_epoch = 0
        self.best_val_loss = float("inf")
        self.train_losses: List[float] = []
        self.val_losses: List[float] = []
        self.val_metrics: List[Dict] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/homography"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=str(self.output_dir / "tensorboard"))

        # Save configuration
        config_path = self.output_dir / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        print(f"Training configuration saved to {config_path}")
        print(
            f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
        )

    def _unpack_batch(
        self, batch: Dict
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (google_img, yandex_img, target_homography) moved to the device."""
        return (
            batch["google_img"].to(self.device),
            batch["yandex_img"].to(self.device),
            batch["homography"].to(self.device),
        )

    @staticmethod
    def _aggregate_metrics(all_metrics: List[Dict]) -> Dict[str, float]:
        """Average per-batch metric dicts key by key.

        Keys are taken from the first dict. Each mean is converted to a plain
        Python ``float`` so the result is JSON-serializable and does not put
        numpy scalar types into checkpoints. Returns ``{}`` for empty input.
        """
        if not all_metrics:
            return {}
        return {
            key: float(np.mean([m[key] for m in all_metrics]))
            for key in all_metrics[0]
        }

    def _epochs_since_best(self) -> int:
        """Number of recorded validation rounds since the lowest validation loss."""
        if not self.val_losses:
            return 0
        return len(self.val_losses) - 1 - int(np.argmin(self.val_losses))

    def _should_stop_early(self) -> bool:
        """True when ``early_stopping_patience`` (> 0) rounds passed without improvement."""
        patience = self.config.get("early_stopping_patience", 0) or 0
        return patience > 0 and self._epochs_since_best() >= patience

    def _step_scheduler(self, val_loss: float) -> None:
        """Advance the LR scheduler; ReduceLROnPlateau needs the monitored loss."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def train_epoch(self) -> float:
        """
        Train for one epoch.

        Returns:
            Average training loss for the epoch (mean of per-batch losses)

        Raises:
            RuntimeError: If the training loader has no batches.
        """
        self.model.train()
        num_batches = len(self.train_loader)
        if num_batches == 0:
            raise RuntimeError("Training loader is empty; nothing to train on")

        # A falsy or non-positive value disables gradient clipping.
        grad_clip = self.config.get("grad_clip", 1.0)
        total_loss = 0.0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in enumerate(progress_bar):
            google_img, yandex_img, target_homography = self._unpack_batch(batch)

            # Forward pass
            self.optimizer.zero_grad()
            pred_homography = self.model(google_img, yandex_img, return_matrix=True)

            # Compute loss
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )

            # Backward pass
            loss.backward()

            # Gradient clipping
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            # Optimizer step
            self.optimizer.step()

            # ``.item()`` synchronizes with the device, so read it only once.
            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({"loss": loss_value})

            # Log batch loss to TensorBoard
            global_step = self.current_epoch * num_batches + batch_idx
            self.writer.add_scalar("train/batch_loss", loss_value, global_step)

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)
        return avg_loss

    @torch.no_grad()
    def validate(self) -> Tuple[float, Dict]:
        """
        Validate the model.

        Returns:
            Tuple of (average validation loss, validation metrics). Metrics
            are per-key means over batches of ``criterion.compute_metrics``,
            as plain floats.

        Raises:
            RuntimeError: If the validation loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_metrics = []

        progress_bar = tqdm(self.val_loader, desc="Validation")
        for batch in progress_bar:
            google_img, yandex_img, target_homography = self._unpack_batch(batch)

            # Forward pass
            pred_homography = self.model(google_img, yandex_img, return_matrix=True)

            # Compute loss and metrics
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )
            metrics = self.criterion.compute_metrics(pred_homography, target_homography)

            # Update statistics
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            all_metrics.append(metrics)
            progress_bar.set_postfix({"loss": loss_value})

        if num_batches == 0:
            raise RuntimeError("Validation loader is empty; cannot validate")

        avg_loss = total_loss / num_batches
        self.val_losses.append(avg_loss)

        avg_metrics = self._aggregate_metrics(all_metrics)
        self.val_metrics.append(avg_metrics)

        return avg_loss, avg_metrics

    def save_checkpoint(self, is_best: bool = False):
        """Save ``checkpoint_latest.pth`` and, if ``is_best``, ``checkpoint_best.pth``.

        The checkpoint stores model/optimizer (and scheduler, if any) state,
        the loss/metric history, the best validation loss, the config and
        ``epoch`` = 0-based index of the last completed epoch.
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "val_metrics": self.val_metrics,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }

        if self.scheduler is not None:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()

        # Save latest checkpoint
        checkpoint_path = self.output_dir / "checkpoint_latest.pth"
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.output_dir / "checkpoint_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Restore trainer state from a checkpoint written by ``save_checkpoint``.

        After loading, ``train()`` resumes with the epoch following the
        checkpointed one.
        """
        # NOTE: newer PyTorch versions default torch.load to weights_only=True,
        # which rejects numpy scalar objects. Checkpoints written by older
        # versions of this script may contain them in "val_metrics".
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        # "epoch" is the last *completed* epoch, so continue with the next one.
        self.start_epoch = self.current_epoch + 1
        self.train_losses = checkpoint["train_losses"]
        self.val_losses = checkpoint["val_losses"]
        self.val_metrics = checkpoint["val_metrics"]
        self.best_val_loss = checkpoint["best_val_loss"]

        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(self, num_epochs: int):
        """
        Train the model until ``num_epochs`` total epochs have run.

        After ``load_checkpoint`` training continues from the next epoch, so
        ``num_epochs`` is the total epoch count, not the number of extra
        epochs. The TensorBoard writer is closed when this returns or raises.

        Args:
            num_epochs: Total number of epochs to train
        """
        print(f"Starting training for {num_epochs} epochs...")
        if self.start_epoch > 0:
            print(f"Resuming from epoch {self.start_epoch + 1}")
        start_time = time.time()

        try:
            for epoch in range(self.start_epoch, num_epochs):
                self.current_epoch = epoch

                # Train for one epoch
                train_loss = self.train_epoch()

                # Validate
                val_loss, val_metrics = self.validate()

                # Learning rate used during this epoch (read before stepping).
                current_lr = self.optimizer.param_groups[0]["lr"]

                # Update learning rate scheduler
                self._step_scheduler(val_loss)

                # Log to TensorBoard
                self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
                self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
                self.writer.add_scalar("train/lr", current_lr, epoch)
                for metric_name, metric_value in val_metrics.items():
                    self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)

                # Print epoch summary
                print(f"\nEpoch {epoch + 1}/{num_epochs}:")
                print(f"  Learning Rate: {current_lr:.2e}")
                print(f"  Train Loss: {train_loss:.6f}")
                print(f"  Val Loss: {val_loss:.6f}")
                print("  Val Metrics:")
                for metric_name, metric_value in val_metrics.items():
                    print(f"    {metric_name}: {metric_value:.6f}")

                # Save checkpoint
                is_best = val_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_loss
                self.save_checkpoint(is_best=is_best)

                # Early stopping
                if self._should_stop_early():
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

            # Training completed
            training_time = time.time() - start_time
            print(f"\nTraining completed in {training_time:.2f} seconds")
            print(f"Best validation loss: {self.best_val_loss:.6f}")

            # Save final model
            final_model_path = self.output_dir / "model_final.pth"
            torch.save(self.model.state_dict(), final_model_path)
            print(f"Final model saved to {final_model_path}")
        finally:
            # Flush and close TensorBoard logs even if training fails.
            self.writer.close()

    def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
        """
        Evaluate the model on test data.

        Args:
            test_loader: Test data loader (uses validation loader if None)

        Returns:
            Dictionary of per-key mean metrics as plain floats; empty if the
            loader yields no batches. Also written to
            ``<output_dir>/evaluation_results.json``.
        """
        if test_loader is None:
            test_loader = self.val_loader

        self.model.eval()
        all_metrics = []

        print("Evaluating model...")
        with torch.no_grad():
            for batch in tqdm(test_loader):
                google_img, yandex_img, target_homography = self._unpack_batch(batch)

                # Forward pass
                pred_homography = self.model(google_img, yandex_img, return_matrix=True)

                # Compute metrics
                all_metrics.append(
                    self.criterion.compute_metrics(pred_homography, target_homography)
                )

        if not all_metrics:
            print("Warning: evaluation loader is empty; no metrics computed")
        avg_metrics = self._aggregate_metrics(all_metrics)

        # Print evaluation results
        print("\nEvaluation Results:")
        for metric_name, metric_value in avg_metrics.items():
            print(f"  {metric_name}: {metric_value:.6f}")

        # Save evaluation results
        eval_path = self.output_dir / "evaluation_results.json"
        with open(eval_path, "w", encoding="utf-8") as f:
            json.dump(avg_metrics, f, indent=2)
        print(f"Evaluation results saved to {eval_path}")

        return avg_metrics
|
|
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` means
            argparse's default of ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train homography estimation model")

    # Data arguments
    parser.add_argument(
        "--data_dir",
        type=str,
        # HOMOGRAPHY_DATA_DIR overrides the machine-specific fallback path.
        default=os.environ.get(
            "HOMOGRAPHY_DATA_DIR",
            r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        ),
        help="Directory containing image pairs (default: $HOMOGRAPHY_DATA_DIR if set)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        nargs=2,
        default=[256, 256],
        help="Image size (height width)",
    )
    parser.add_argument(
        "--train_split", type=float, default=0.8, help="Train/validation split ratio"
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loader workers"
    )

    # Model arguments
    parser.add_argument(
        "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
    )
    parser.add_argument(
        "--hidden_channels", type=int, default=64, help="Number of hidden channels"
    )
    parser.add_argument(
        "--num_blocks", type=int, default=4, help="Number of convolutional blocks"
    )
    parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
    parser.add_argument(
        "--use_batch_norm", action="store_true", help="Use batch normalization"
    )

    # Training arguments
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument(
        "--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"]
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        default="plateau",
        choices=["plateau", "cosine", "step", "none"],
    )
    parser.add_argument(
        "--grad_clip", type=float, default=1.0, help="Gradient clipping value"
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=0,
        help="Stop after this many epochs without validation improvement (0 disables)",
    )

    # Loss arguments
    parser.add_argument(
        "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
    )
    parser.add_argument(
        "--geometric_weight",
        type=float,
        default=0.5,
        help="Weight for geometric loss",
    )
    parser.add_argument(
        "--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
    )

    # Other arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="runs/homography",
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Path to checkpoint to resume training from",
    )
    parser.add_argument(
        "--eval_only",
        action="store_true",
        help="Only evaluate the model (no training)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Main training function.

    Parses CLI arguments, seeds the random number generators, builds the data
    loaders, model and ``HomographyTrainer``, optionally resumes from a
    checkpoint, and then either only evaluates (``--eval_only``) or trains
    followed by a final evaluation.
    """
    import random  # function-scope: only needed for seeding below

    args = parse_args()

    # Set random seeds for reproducibility. Python's ``random`` is seeded too,
    # in case data loading/augmentation code relies on it.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, including the current one.
        torch.cuda.manual_seed_all(args.seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create data loaders
    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        root_dir=args.data_dir,
        batch_size=args.batch_size,
        train_split=args.train_split,
        num_workers=args.num_workers,
        image_size=tuple(args.image_size),
        augment_train=True,
        augment_val=False,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Create model
    print("Creating model...")
    model = create_homography_model(
        model_type=args.model_type,
        input_size=tuple(args.image_size),
        input_channels=3,
        hidden_channels=args.hidden_channels,
        num_blocks=args.num_blocks,
        dropout_rate=args.dropout_rate,
        use_batch_norm=args.use_batch_norm,
    )

    # Create trainer configuration
    config = {
        # Model config
        "model_type": args.model_type,
        "hidden_channels": args.hidden_channels,
        "num_blocks": args.num_blocks,
        "dropout_rate": args.dropout_rate,
        "use_batch_norm": args.use_batch_norm,
        "image_size": args.image_size,
        # Training config
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "optimizer": args.optimizer,
        "scheduler": args.scheduler,
        "grad_clip": args.grad_clip,
        # The trainer reads this key; getattr keeps this working with parsers
        # that do not define --early_stopping_patience (0 disables it).
        "early_stopping_patience": getattr(args, "early_stopping_patience", 0),
        # Loss config
        "matrix_weight": args.matrix_weight,
        "geometric_weight": args.geometric_weight,
        "reg_weight": args.reg_weight,
        "grid_size": 8,
        # Data config
        "data_dir": args.data_dir,
        "train_split": args.train_split,
        "num_workers": args.num_workers,
        # Output config
        "output_dir": args.output_dir,
        "seed": args.seed,
    }

    # Create trainer
    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    # Resume from checkpoint if specified
    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        trainer.load_checkpoint(args.resume)

    # Evaluate only mode
    if args.eval_only:
        if not args.resume:
            print(
                "Warning: --eval_only without --resume evaluates an untrained model"
            )
        trainer.evaluate()
        return

    # Train the model
    trainer.train(num_epochs=args.epochs)

    # Final evaluation
    print("\nPerforming final evaluation...")
    trainer.evaluate()

    print("\nTraining completed successfully!")
    print(f"Results saved to: {args.output_dir}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0 unless an exception propagates out of main().
    raise SystemExit(main())
|