# Source file: autopilot/models/SiaN-similarity/train.py
# File-browser header as captured: 276 lines, 9.4 KiB, Python

"""
Training script for image similarity estimation.
"""
import argparse
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class SimilarityTrainer:
    """Run training, validation, logging and checkpointing for a pair model.

    The model is called as ``model(google_img, yandex_img)`` and its output is
    scored against the batch's ``"same_domain"`` label with ``SimilarityLoss``.
    Batches are expected to be mappings with the keys ``"google_img"``,
    ``"yandex_img"`` and ``"same_domain"``.

    Config keys read by this class (all optional):
        learning_rate (default 2e-4), beta1 (0.5), beta2 (0.999),
        log_interval (10), output_dir ("runs/similarity"),
        early_stopping_patience (20).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move *model* to *device* and build the loss and Adam optimizer.

        Args:
            model: Network taking two image batches and returning a prediction.
            train_loader: Loader iterated once per call to ``train_epoch``.
            val_loader: Loader iterated once per call to ``validate``.
            device: Device that the model and every batch are moved to.
            config: Hyperparameter/settings dict (see the class docstring).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config
        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )
        # The TensorBoard writer is created lazily by train(); while it is
        # None, train_epoch() skips scalar logging.
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _unpack_batch(self, batch):
        """Return ``(google_img, yandex_img, target)`` moved to ``self.device``.

        The label is cast to float and given a trailing axis via
        ``unsqueeze(1)`` before being handed to the criterion.
        """
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target = batch["same_domain"].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    @staticmethod
    def _weighted_mean_metrics(batch_metrics, batch_sizes):
        """Average per-batch metric dicts, weighting each batch by its size.

        Weighting by sample count keeps a short final batch from counting as
        much as a full one, matching how the loss is averaged.

        Args:
            batch_metrics: One metrics dict per batch, all with the same keys.
            batch_sizes: Number of samples in each corresponding batch.

        Returns:
            Dict with the same keys holding the sample-weighted means, or an
            empty dict when there is nothing to average.
        """
        total = sum(batch_sizes)
        if not batch_metrics or total == 0:
            return {}
        return {
            key: sum(m[key] * n for m, n in zip(batch_metrics, batch_sizes)) / total
            for key in batch_metrics[0]
        }

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Epoch number, used for the progress-bar label and for the
                TensorBoard step.

        Returns:
            ``{"loss": <sample-weighted mean training loss>}``.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.train()
        # Looked up once per epoch; a falsy interval (0/None) falls back to
        # logging every batch instead of breaking the modulo below.
        log_interval = self.config.get("log_interval", 10) or 1
        total_loss = 0.0
        total_samples = 0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._unpack_batch(batch)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # Read the scalar once: every .item() call forces a host sync.
            loss_value = loss.item()
            batch_size = google_img.size(0)
            # Assumes the criterion returns a per-batch mean, so multiplying by
            # the batch size recovers a per-sample sum -- TODO confirm.
            total_loss += loss_value * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                # Metrics are for display only; no autograd graph is needed.
                with torch.no_grad():
                    metrics = self.criterion.compute_metrics(output, target)
                pbar.set_postfix({"loss": loss_value, "acc": metrics["accuracy"]})
                if self.writer:
                    # train() passes 1-based epochs, so the first logged step
                    # is len(train_loader) rather than 0.
                    global_step = epoch * len(self.train_loader) + batch_idx
                    self.writer.add_scalar("train/loss", loss_value, global_step)
                    self.writer.add_scalar(
                        "train/accuracy", metrics["accuracy"], global_step
                    )

        if total_samples == 0:
            raise ValueError("train_loader yielded no samples; cannot train")
        return {"loss": total_loss / total_samples}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            Dict with ``"loss"`` (sample-weighted mean) plus every key returned
            by ``SimilarityLoss.compute_metrics``, each averaged over batches
            with weights equal to the batch sizes.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        batch_metrics = []
        batch_sizes = []
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._unpack_batch(batch)
                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size
                batch_metrics.append(self.criterion.compute_metrics(output, target))
                batch_sizes.append(batch_size)

        if total_samples == 0:
            raise ValueError(
                "val_loader yielded no samples; cannot compute validation metrics"
            )
        return {
            "loss": total_loss / total_samples,
            **self._weighted_mean_metrics(batch_metrics, batch_sizes),
        }

    def train(self, num_epochs: int):
        """Train for up to *num_epochs* epochs with early stopping.

        After every epoch the model is validated, epoch-level scalars are
        written to TensorBoard and a checkpoint is saved. Training stops early
        once the validation loss has failed to improve for
        ``early_stopping_patience`` consecutive epochs.

        Args:
            num_epochs: Maximum number of epochs to run (epochs are 1-based).
        """
        log_dir = self.config.get("output_dir", "runs/similarity")
        os.makedirs(log_dir, exist_ok=True)
        patience = self.config.get("early_stopping_patience", 20)
        self.writer = SummaryWriter(log_dir)
        # try/finally so the event file is flushed and closed even when an
        # epoch raises (e.g. out-of-memory or KeyboardInterrupt).
        try:
            print(f"Starting training for {num_epochs} epochs")
            print(f"Logging to: {log_dir}")
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")
                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()
                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")
                print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
                print(f"Val F1: {val_metrics['f1']:.4f}")

                self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
                self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
                self.writer.add_scalar(
                    "epoch/val_accuracy", val_metrics["accuracy"], epoch
                )
                # F1 is printed above, so log it alongside the other scalars.
                self.writer.add_scalar("epoch/val_f1", val_metrics["f1"], epoch)

                # NOTE(review): a checkpoint file is written every epoch;
                # config["save_interval"] is not consulted anywhere in this class.
                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
                    print(
                        f"New best model saved with val loss: {val_metrics['loss']:.4f}"
                    )
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(
                        f"Early stopping triggered after {patience} epochs without improvement"
                    )
                    break
        finally:
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state to ``<output_dir>/checkpoints``.

        Always writes ``checkpoint_epoch_<epoch>.pt``; when *is_best* is true
        the same payload is also written to ``best_model.pt``.

        Args:
            epoch: Epoch number stored in the payload and used in the filename.
            val_loss: Validation loss stored in the payload.
            is_best: Whether to additionally overwrite ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/similarity"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }
        torch.save(
            checkpoint, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        )
        if is_best:
            torch.save(checkpoint, os.path.join(checkpoint_dir, "best_model.pt"))

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a checkpoint file.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            Tuple ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        # SECURITY: torch.load deserialises with pickle; only load checkpoint
        # files that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # NOTE(review): best_val_loss / epochs_without_improvement are not
        # restored here, and train() always starts counting from epoch 1.
        return checkpoint["epoch"], checkpoint["val_loss"]
def main():
    """Parse command-line options, build loaders and model, and run training.

    Options cover the dataset location, batch size, epoch count, learning
    rate, square image size, train/validation split, output directory, loader
    worker count and device. The remaining settings (logging interval, save
    interval, early-stopping patience, Adam betas) are fixed in the config.
    """
    # (flag, type, default) triples, listed in the order they appear in --help.
    option_table = (
        (
            "--data_dir",
            str,
            r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        ),
        ("--batch_size", int, 32),
        ("--epochs", int, 100),
        ("--learning_rate", float, 2e-4),
        ("--image_size", int, 256),
        ("--train_split", float, 0.8),
        ("--output_dir", str, "runs/similarity"),
        ("--num_workers", int, 0),
        ("--device", str, "cuda" if torch.cuda.is_available() else "cpu"),
    )
    cli = argparse.ArgumentParser(description="Train similarity estimation model")
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    opts = cli.parse_args()

    # Settings taken from the command line; images are square, so the single
    # --image_size value is used for both dimensions.
    config = {
        "data_dir": opts.data_dir,
        "batch_size": opts.batch_size,
        "epochs": opts.epochs,
        "learning_rate": opts.learning_rate,
        "image_size": (opts.image_size, opts.image_size),
        "train_split": opts.train_split,
        "output_dir": opts.output_dir,
        "num_workers": opts.num_workers,
    }
    # Settings that are not exposed as command-line options.
    config.update(
        {
            "log_interval": 10,
            "save_interval": 5,
            "early_stopping_patience": 20,
            "beta1": 0.5,
            "beta2": 0.999,
        }
    )

    device = torch.device(opts.device)
    print(f"Using device: {device}")

    print("Creating data loaders...")
    loaders = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=config["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )
    train_loader, val_loader = loaders
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    model = create_similarity_model(
        model_type="cnn",
        input_size=config["image_size"],
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    trainer = SimilarityTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )
    print("Starting training...")
    trainer.train(config["epochs"])
    print("Training completed!")
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()