feat: add SiaN model
This commit is contained in:
611
models/SiaN/train_homography.py
Normal file
611
models/SiaN/train_homography.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
Training script for homography estimation between Google and Yandex map images.
|
||||
|
||||
This script trains a CNN model to estimate homography matrices that align
|
||||
Google map images with Yandex map images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from homography import HomographyDataset, create_data_loaders
|
||||
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class HomographyTrainer:
    """Trainer for a homography estimation model.

    Wires together the model, the loss, an optimizer, an optional
    learning-rate scheduler, TensorBoard logging and checkpointing, and
    exposes the train / validate / evaluate loops.

    Every batch yielded by the data loaders is read as a mapping with the
    keys ``"google_img"``, ``"yandex_img"`` and ``"homography"``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict,
    ):
        """
        Initialize the trainer.

        Besides storing its arguments this creates the output directory,
        opens a TensorBoard writer in ``<output_dir>/tensorboard`` and writes
        ``config`` to ``<output_dir>/config.json``.

        Args:
            model: Homography estimation model
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary. Keys read here (all
                optional): ``matrix_weight``, ``geometric_weight``,
                ``reg_weight``, ``grid_size``, ``optimizer``,
                ``learning_rate``, ``weight_decay``, ``momentum``,
                ``scheduler``, ``scheduler_factor``, ``scheduler_patience``,
                ``epochs``, ``min_lr``, ``step_size``, ``gamma``,
                ``output_dir``. ``grad_clip`` and
                ``early_stopping_patience`` are read during training.

        Raises:
            ValueError: If ``config["optimizer"]`` is not one of
                ``adam``, ``adamw`` or ``sgd``.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Loss function
        self.criterion = HomographyLoss(
            matrix_weight=config.get("matrix_weight", 1.0),
            geometric_weight=config.get("geometric_weight", 0.5),
            reg_weight=config.get("reg_weight", 0.1),
            grid_size=config.get("grid_size", 8),
        ).to(device)

        # Optimizer and learning rate scheduler
        self.optimizer = self._build_optimizer()
        self.scheduler = self._build_scheduler()

        # Training state
        self.current_epoch = 0
        # First epoch index train() will run; advanced by load_checkpoint()
        # so that a resumed run continues instead of restarting at epoch 0.
        self.start_epoch = 0
        self.best_val_loss = float("inf")
        self.train_losses: List[float] = []
        self.val_losses: List[float] = []
        self.val_metrics: List[Dict] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/homography"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=str(self.output_dir / "tensorboard"))

        # Save configuration
        config_path = self.output_dir / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        print(f"Training configuration saved to {config_path}")
        print(
            f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
        )

    def _build_optimizer(self) -> optim.Optimizer:
        """Create the optimizer selected by ``config["optimizer"]``."""
        optimizer_name = self.config.get("optimizer", "adam").lower()
        lr = self.config.get("learning_rate", 1e-3)
        weight_decay = self.config.get("weight_decay", 1e-4)
        params = self.model.parameters()

        if optimizer_name == "adam":
            return optim.Adam(params, lr=lr, weight_decay=weight_decay)
        if optimizer_name == "adamw":
            return optim.AdamW(params, lr=lr, weight_decay=weight_decay)
        if optimizer_name == "sgd":
            return optim.SGD(
                params,
                lr=lr,
                momentum=self.config.get("momentum", 0.9),
                weight_decay=weight_decay,
            )
        raise ValueError(f"Unknown optimizer: {optimizer_name}")

    def _build_scheduler(self):
        """Create the scheduler selected by ``config["scheduler"]``.

        Returns ``None`` for any name other than ``plateau``, ``cosine`` or
        ``step`` (this is how ``"none"`` disables scheduling).
        """
        scheduler_name = self.config.get("scheduler", "plateau").lower()

        if scheduler_name == "plateau":
            # ``verbose`` is intentionally not passed: it is deprecated in
            # PyTorch and not accepted by newer releases. Learning-rate
            # changes are reported by _step_scheduler() instead.
            return optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=self.config.get("scheduler_factor", 0.5),
                patience=self.config.get("scheduler_patience", 5),
            )
        if scheduler_name == "cosine":
            return optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.config.get("epochs", 100),
                eta_min=self.config.get("min_lr", 1e-6),
            )
        if scheduler_name == "step":
            return optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=self.config.get("step_size", 30),
                gamma=self.config.get("gamma", 0.1),
            )
        return None

    def _batch_to_device(self, batch) -> Tuple:
        """Move the two images and the target homography of a batch to the device."""
        return (
            batch["google_img"].to(self.device),
            batch["yandex_img"].to(self.device),
            batch["homography"].to(self.device),
        )

    @staticmethod
    def _mean_metrics(all_metrics: List[Dict]) -> Dict[str, float]:
        """Average per-batch metric dicts key by key.

        The keys of the first dict define the result. Values are converted to
        plain Python floats so they can be written with ``json.dump`` and
        stored in checkpoints without NumPy scalar objects. An empty input
        yields an empty dict.
        """
        if not all_metrics:
            return {}
        return {
            key: float(np.mean([m[key] for m in all_metrics]))
            for key in all_metrics[0]
        }

    def _step_scheduler(self, val_loss: float) -> None:
        """Advance the scheduler (if any) after an epoch.

        ``ReduceLROnPlateau`` is stepped with the validation loss; other
        schedulers are stepped without arguments. A plateau-triggered change
        of the first parameter group's learning rate is printed.
        """
        if self.scheduler is None:
            return

        if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            previous_lr = self.optimizer.param_groups[0]["lr"]
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]["lr"]
            if current_lr != previous_lr:
                print(
                    f"Reducing learning rate from {previous_lr:.4e} to {current_lr:.4e}"
                )
        else:
            self.scheduler.step()

    def _should_stop_early(self) -> bool:
        """Return True when validation loss has not improved for ``patience`` epochs.

        Patience comes from ``config["early_stopping_patience"]``; values
        that are missing or not positive disable early stopping. The distance
        to the best epoch is measured inside ``val_losses`` itself, so it
        stays correct when that history was restored from a checkpoint.
        """
        patience = self.config.get("early_stopping_patience", 0)
        if not patience or patience <= 0 or not self.val_losses:
            return False
        epochs_since_best = len(self.val_losses) - 1 - int(np.argmin(self.val_losses))
        return epochs_since_best >= patience

    def train_epoch(self) -> float:
        """
        Train for one epoch.

        Returns:
            Average training loss for the epoch

        Raises:
            ValueError: If the training loader has no batches.
        """
        num_batches = len(self.train_loader)
        if num_batches == 0:
            raise ValueError("train_loader is empty; cannot train for an epoch")

        self.model.train()
        total_loss = 0.0
        grad_clip = self.config.get("grad_clip", 1.0)

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in enumerate(progress_bar):
            # Move data to device
            google_img, yandex_img, target_homography = self._batch_to_device(batch)

            # Forward pass
            self.optimizer.zero_grad()
            pred_homography = self.model(google_img, yandex_img, return_matrix=True)

            # Compute loss
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )

            # Backward pass
            loss.backward()

            # Gradient clipping (disabled when grad_clip is None or <= 0)
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            # Optimizer step
            self.optimizer.step()

            # Read the scalar once; each .item() call copies from the device.
            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({"loss": loss_value})

            # Log batch loss to TensorBoard
            global_step = self.current_epoch * num_batches + batch_idx
            self.writer.add_scalar("train/batch_loss", loss_value, global_step)

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)

        return avg_loss

    @torch.no_grad()
    def validate(self) -> Tuple[float, Dict]:
        """
        Validate the model.

        Returns:
            Tuple of (average validation loss, validation metrics)

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        all_metrics = []

        progress_bar = tqdm(self.val_loader, desc="Validation")
        for batch in progress_bar:
            # Move data to device
            google_img, yandex_img, target_homography = self._batch_to_device(batch)

            # Forward pass
            pred_homography = self.model(google_img, yandex_img, return_matrix=True)

            # Compute loss
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )

            # Compute metrics
            metrics = self.criterion.compute_metrics(pred_homography, target_homography)

            # Update statistics
            loss_value = loss.item()
            total_loss += loss_value
            all_metrics.append(metrics)

            # Update progress bar
            progress_bar.set_postfix({"loss": loss_value})

        if not all_metrics:
            raise ValueError("val_loader produced no batches; cannot validate")

        # One metrics entry is appended per batch, so this is the batch count.
        avg_loss = total_loss / len(all_metrics)
        self.val_losses.append(avg_loss)

        # Aggregate metrics
        avg_metrics = self._mean_metrics(all_metrics)
        self.val_metrics.append(avg_metrics)

        return avg_loss, avg_metrics

    def save_checkpoint(self, is_best: bool = False):
        """Save model checkpoint.

        Always writes ``checkpoint_latest.pth`` in the output directory and,
        when ``is_best`` is true, additionally ``checkpoint_best.pth``.
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "val_metrics": self.val_metrics,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }

        if self.scheduler is not None:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()

        # Save latest checkpoint
        checkpoint_path = self.output_dir / "checkpoint_latest.pth"
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.output_dir / "checkpoint_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint.

        Restores model, optimizer, scheduler (when both the trainer and the
        checkpoint have one) and the loss / metric history, and arranges for
        the next ``train()`` call to continue with the epoch after the one
        stored in the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        # The stored epoch was completed before the checkpoint was written.
        self.start_epoch = self.current_epoch + 1
        self.train_losses = checkpoint["train_losses"]
        self.val_losses = checkpoint["val_losses"]
        self.val_metrics = checkpoint["val_metrics"]
        self.best_val_loss = checkpoint["best_val_loss"]

        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(self, num_epochs: int):
        """
        Train the model for specified number of epochs.

        ``num_epochs`` is the total epoch count of the run: a fresh trainer
        runs epochs ``0 .. num_epochs - 1``; after ``load_checkpoint`` the
        loop continues from the epoch following the checkpoint. The
        TensorBoard writer is closed when this method exits, including on
        error.

        Args:
            num_epochs: Number of epochs to train
        """
        print(f"Starting training for {num_epochs} epochs...")
        if self.start_epoch > 0:
            print(f"Resuming at epoch {self.start_epoch + 1}")
        start_time = time.time()

        try:
            for epoch in range(self.start_epoch, num_epochs):
                self.current_epoch = epoch

                # Train for one epoch
                train_loss = self.train_epoch()

                # Validate
                val_loss, val_metrics = self.validate()

                # Update learning rate scheduler
                self._step_scheduler(val_loss)

                # Log to TensorBoard
                self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
                self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
                self.writer.add_scalar(
                    "train/learning_rate", self.optimizer.param_groups[0]["lr"], epoch
                )
                for metric_name, metric_value in val_metrics.items():
                    self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)

                # Print epoch summary
                print(f"\nEpoch {epoch + 1}/{num_epochs}:")
                print(f"  Train Loss: {train_loss:.6f}")
                print(f"  Val Loss: {val_loss:.6f}")
                print("  Val Metrics:")
                for metric_name, metric_value in val_metrics.items():
                    print(f"    {metric_name}: {metric_value:.6f}")

                # Save checkpoint
                is_best = val_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_loss

                self.save_checkpoint(is_best=is_best)

                # Early stopping
                if self._should_stop_early():
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

            # Training completed
            training_time = time.time() - start_time
            print(f"\nTraining completed in {training_time:.2f} seconds")
            print(f"Best validation loss: {self.best_val_loss:.6f}")

            # Save final model
            final_model_path = self.output_dir / "model_final.pth"
            torch.save(self.model.state_dict(), final_model_path)
            print(f"Final model saved to {final_model_path}")
        finally:
            # Close TensorBoard writer so buffered events are flushed.
            self.writer.close()

    def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
        """
        Evaluate the model on test data.

        The averaged metrics are printed and written to
        ``<output_dir>/evaluation_results.json``.

        Args:
            test_loader: Test data loader (uses validation loader if None)

        Returns:
            Dictionary of evaluation metrics (empty if the loader yields no
            batches)
        """
        if test_loader is None:
            test_loader = self.val_loader

        self.model.eval()
        all_metrics = []

        print("Evaluating model...")
        with torch.no_grad():
            for batch in tqdm(test_loader):
                # Move data to device
                google_img, yandex_img, target_homography = self._batch_to_device(
                    batch
                )

                # Forward pass
                pred_homography = self.model(google_img, yandex_img, return_matrix=True)

                # Compute metrics
                all_metrics.append(
                    self.criterion.compute_metrics(pred_homography, target_homography)
                )

        # Aggregate metrics
        avg_metrics = self._mean_metrics(all_metrics)

        # Print evaluation results
        print("\nEvaluation Results:")
        for metric_name, metric_value in avg_metrics.items():
            print(f"  {metric_name}: {metric_value:.6f}")

        # Save evaluation results
        eval_path = self.output_dir / "evaluation_results.json"
        with open(eval_path, "w", encoding="utf-8") as f:
            json.dump(avg_metrics, f, indent=2)
        print(f"Evaluation results saved to {eval_path}")

        return avg_metrics
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv``, which is
            the behaviour of the command-line entry point.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train homography estimation model")

    # Data arguments
    # NOTE(review): the default below is a machine-specific Windows path;
    # pass --data_dir explicitly on any other machine.
    parser.add_argument(
        "--data_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        help="Directory containing image pairs",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        nargs=2,
        default=[256, 256],
        help="Image size (height width)",
    )
    parser.add_argument(
        "--train_split", type=float, default=0.8, help="Train/validation split ratio"
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loader workers"
    )

    # Model arguments
    parser.add_argument(
        "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
    )
    parser.add_argument(
        "--hidden_channels", type=int, default=64, help="Number of hidden channels"
    )
    parser.add_argument(
        "--num_blocks", type=int, default=4, help="Number of convolutional blocks"
    )
    parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
    parser.add_argument(
        "--use_batch_norm", action="store_true", help="Use batch normalization"
    )

    # Training arguments
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["adam", "adamw", "sgd"],
        help="Optimizer to use",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        default="plateau",
        choices=["plateau", "cosine", "step", "none"],
        help="Learning rate scheduler ('none' disables scheduling)",
    )
    parser.add_argument(
        "--grad_clip", type=float, default=1.0, help="Gradient clipping value"
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=0,
        help="Stop after this many epochs without validation improvement "
        "(0 disables early stopping)",
    )

    # Loss arguments
    parser.add_argument(
        "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
    )
    parser.add_argument(
        "--geometric_weight",
        type=float,
        default=0.5,
        help="Weight for geometric loss",
    )
    parser.add_argument(
        "--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
    )

    # Other arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="runs/homography",
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Path to checkpoint to resume training from",
    )
    parser.add_argument(
        "--eval_only",
        action="store_true",
        help="Only evaluate the model (no training)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main():
    """Main training function.

    Parses the command line, seeds the random number generators, builds the
    data loaders, the model and the trainer, optionally restores a
    checkpoint, and then either evaluates only (``--eval_only``) or trains
    and runs a final evaluation.
    """
    # Local import: only needed here, for seeding Python's own RNG.
    import random

    args = parse_args()

    # Set random seeds for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create data loaders
    print("Creating data loaders...")
    image_size = tuple(args.image_size)
    train_loader, val_loader = create_data_loaders(
        root_dir=args.data_dir,
        batch_size=args.batch_size,
        train_split=args.train_split,
        num_workers=args.num_workers,
        image_size=image_size,
        augment_train=True,
        augment_val=False,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Create model
    print("Creating model...")
    model = create_homography_model(
        model_type=args.model_type,
        input_size=image_size,
        input_channels=3,
        hidden_channels=args.hidden_channels,
        num_blocks=args.num_blocks,
        dropout_rate=args.dropout_rate,
        use_batch_norm=args.use_batch_norm,
    )

    # Create trainer configuration
    config = {
        # Model config
        "model_type": args.model_type,
        "hidden_channels": args.hidden_channels,
        "num_blocks": args.num_blocks,
        "dropout_rate": args.dropout_rate,
        "use_batch_norm": args.use_batch_norm,
        "image_size": args.image_size,
        # Training config
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "optimizer": args.optimizer,
        "scheduler": args.scheduler,
        "grad_clip": args.grad_clip,
        # The trainer reads this key for early stopping; 0 keeps it disabled.
        # getattr keeps this working if the parser does not define the flag.
        "early_stopping_patience": getattr(args, "early_stopping_patience", 0),
        # Loss config
        "matrix_weight": args.matrix_weight,
        "geometric_weight": args.geometric_weight,
        "reg_weight": args.reg_weight,
        "grid_size": 8,
        # Data config
        "data_dir": args.data_dir,
        "train_split": args.train_split,
        "num_workers": args.num_workers,
        # Output config
        "output_dir": args.output_dir,
        "seed": args.seed,
    }

    # Create trainer
    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    # Resume from checkpoint if specified
    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        trainer.load_checkpoint(args.resume)

    # Evaluate only mode
    if args.eval_only:
        if not args.resume:
            # Without a checkpoint the model still has its initial weights.
            print(
                "Warning: --eval_only was given without --resume; "
                "evaluating a model with freshly initialised weights."
            )
        trainer.evaluate()
        # train() is what normally closes the writer; close it here instead.
        trainer.writer.close()
        return

    # Train the model
    trainer.train(num_epochs=args.epochs)

    # Final evaluation
    print("\nPerforming final evaluation...")
    trainer.evaluate()

    print("\nTraining completed successfully!")
    print(f"Results saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user