"""Training script for image similarity estimation.

Builds the data loaders and the similarity model, then runs a train/validate
loop with TensorBoard logging, per-epoch checkpointing and early stopping.
"""

import argparse
import os
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class SimilarityTrainer:
    """Run training, validation and checkpointing for a similarity model.

    The model is called as ``model(google_img, yandex_img)`` and its output is
    scored against the batch's ``"same_domain"`` field with ``SimilarityLoss``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move the model to *device* and set up the loss and Adam optimizer.

        Args:
            model: Network taking two image batches and returning a prediction.
            train_loader: Loader yielding dict batches for training.
            val_loader: Loader yielding dict batches for validation.
            device: Device the model and batches are moved to.
            config: Settings read with ``.get``: ``learning_rate`` (2e-4),
                ``beta1`` (0.5), ``beta2`` (0.999), ``log_interval`` (10),
                ``output_dir`` ("runs/similarity") and
                ``early_stopping_patience`` (20).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )

        # Created lazily by train(); stays None when train_epoch() is driven
        # directly, in which case TensorBoard logging is skipped.
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _prepare_batch(self, batch: dict):
        """Move one batch to the device and return (google, yandex, target)."""
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        # unsqueeze(1) adds a trailing axis to the label tensor
        # (presumably to match the model output shape -- TODO confirm).
        target = batch["same_domain"].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    def train_epoch(self, epoch: int) -> dict:
        """Train for one pass over ``train_loader``.

        Args:
            epoch: Epoch number, used for the progress-bar label and for the
                TensorBoard global step.

        Returns:
            ``{"loss": <sample-weighted mean training loss>}``.

        Raises:
            ValueError: If the training loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        # A log_interval of 0 disables per-step logging instead of raising on
        # the modulo below.
        log_interval = self.config.get("log_interval", 10)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._prepare_batch(batch)
            batch_size = google_img.size(0)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # Read the scalar once and reuse it for bookkeeping and logging.
            loss_value = loss.item()
            total_loss += loss_value * batch_size
            total_samples += batch_size

            if log_interval and batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                pbar.set_postfix(
                    {
                        "loss": loss_value,
                        "acc": metrics["accuracy"],
                    }
                )
                if self.writer:
                    global_step = epoch * len(self.train_loader) + batch_idx
                    self.writer.add_scalar("train/loss", loss_value, global_step)
                    self.writer.add_scalar(
                        "train/accuracy", metrics["accuracy"], global_step
                    )

        if total_samples == 0:
            raise ValueError(
                "Training loader produced no samples; check the dataset and split."
            )
        return {"loss": total_loss / total_samples}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without gradient tracking.

        The loss and every metric returned by
        ``criterion.compute_metrics`` are averaged weighted by batch size, so
        a shorter final batch does not get the same weight as a full one.
        For metrics that are not per-sample means this remains an
        approximation of the dataset-level value.

        Returns:
            ``{"loss": <mean loss>, **<mean of each metric>}``.

        Raises:
            ValueError: If the validation loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        metric_sums: dict = {}

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._prepare_batch(batch)
                batch_size = google_img.size(0)

                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                total_loss += loss.item() * batch_size
                total_samples += batch_size

                batch_metrics = self.criterion.compute_metrics(output, target)
                for key, value in batch_metrics.items():
                    metric_sums[key] = metric_sums.get(key, 0.0) + value * batch_size

        if total_samples == 0:
            raise ValueError(
                "Validation loader produced no samples; check the dataset and split."
            )

        avg_metrics = {
            key: weighted_sum / total_samples
            for key, weighted_sum in metric_sums.items()
        }
        return {"loss": total_loss / total_samples, **avg_metrics}

    def train(self, num_epochs: int):
        """Run the full training loop with logging and early stopping.

        Each epoch trains, validates, logs epoch-level scalars and writes a
        checkpoint. Training stops early once the validation loss has not
        improved for ``early_stopping_patience`` consecutive epochs.

        Args:
            num_epochs: Maximum number of epochs to run (epochs are 1-based).
        """
        log_dir = self.config.get("output_dir", "runs/similarity")
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        patience = self.config.get("early_stopping_patience", 20)

        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        # try/finally so the event file is flushed and closed even when an
        # epoch raises (e.g. out-of-memory or KeyboardInterrupt).
        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")
                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")
                print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
                print(f"Val F1: {val_metrics['f1']:.4f}")

                self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
                self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
                self.writer.add_scalar(
                    "epoch/val_accuracy", val_metrics["accuracy"], epoch
                )

                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
                    print(
                        f"New best model saved with val loss: {val_metrics['loss']:.4f}"
                    )
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(
                        f"Early stopping triggered after {patience} epochs without improvement"
                    )
                    break
        finally:
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state for *epoch* under ``<output_dir>/checkpoints``.

        Args:
            epoch: Epoch number, stored in the checkpoint and used in the
                file name ``checkpoint_epoch_<epoch>.pt``.
            val_loss: Validation loss stored alongside the weights.
            is_best: When true, the same checkpoint is also written to
                ``best_model.pt``.
        """
        # NOTE(review): config["save_interval"] is set in main() but is not
        # consulted here; a checkpoint file is written on every call.
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/similarity"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a checkpoint file.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            Tuple ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        # SECURITY: torch.load deserializes with pickle; only load checkpoint
        # files from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]


def main():
    """Parse command-line options, build loaders and model, and train."""
    parser = argparse.ArgumentParser(description="Train similarity estimation model")
    parser.add_argument(
        "--data_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--train_split", type=float, default=0.8)
    parser.add_argument("--output_dir", type=str, default="runs/similarity")
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    args = parser.parse_args()

    config = {
        "data_dir": args.data_dir,
        "batch_size": args.batch_size,
        "epochs": args.epochs,
        "learning_rate": args.learning_rate,
        # The single --image_size value is used for both dimensions.
        "image_size": (args.image_size, args.image_size),
        "train_split": args.train_split,
        "output_dir": args.output_dir,
        "num_workers": args.num_workers,
        "log_interval": 10,
        "save_interval": 5,
        "early_stopping_patience": 20,
        "beta1": 0.5,
        "beta2": 0.999,
    }

    device = torch.device(args.device)
    print(f"Using device: {device}")

    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=config["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    model = create_similarity_model(
        model_type="cnn",
        input_size=config["image_size"],
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    trainer = SimilarityTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    print("Starting training...")
    trainer.train(config["epochs"])
    print("Training completed!")


if __name__ == "__main__":
    main()