"""Training script for image similarity estimation."""

import os
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import config, create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class SimilarityTrainer:
    """Train and validate a pairwise image-similarity model.

    Every batch produced by the loaders must be a mapping with the entries
    ``"google_img"``, ``"yandex_img"`` and ``"same_domain"``. The model is
    called as ``model(google_img, yandex_img)`` and scored with
    ``SimilarityLoss``.

    Args:
        model: Network to train; moved to ``device`` on construction.
        train_loader: Loader yielding training batches.
        val_loader: Loader yielding validation batches.
        device: Device that the model and all batch tensors are moved to.
        config: Hyper-parameters. Keys read by this class (with defaults):
            ``learning_rate`` (2e-4), ``beta1`` (0.5), ``beta2`` (0.999),
            ``log_interval`` (10), ``output_dir`` ("runs/similarity"),
            ``early_stopping_patience`` (20).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )

        # Created lazily in train(); the train/validate passes skip logging
        # while it is None.
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _unpack_batch(self, batch):
        """Move one batch to ``self.device`` and return ``(google, yandex, target)``.

        ``target`` is ``batch["same_domain"]`` cast to float with a trailing
        dimension added via ``unsqueeze(1)`` (presumably labels of shape
        ``(N,)`` becoming ``(N, 1)`` to match the model output -- confirm).
        """
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target = batch["same_domain"].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Epoch index (``train`` passes 1-based values). It is used
                for the progress-bar label and to derive the TensorBoard
                global step.

        Returns:
            ``{"loss": mean per-sample training loss}``.

        Raises:
            ValueError: If ``train_loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        log_interval = self.config.get("log_interval", 10)
        steps_per_epoch = len(self.train_loader)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._unpack_batch(batch)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            batch_size = google_img.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += batch_loss * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                # Metrics are for monitoring only, so keep them off the
                # autograd graph.
                metrics = self.criterion.compute_metrics(output.detach(), target)
                pbar.set_postfix(
                    {
                        "loss": batch_loss,
                        "acc": metrics["accuracy"],
                    }
                )

                if self.writer:
                    global_step = epoch * steps_per_epoch + batch_idx
                    self.writer.add_scalar("train/loss", batch_loss, global_step)
                    self.writer.add_scalar(
                        "train/accuracy", metrics["accuracy"], global_step
                    )

        if total_samples == 0:
            raise ValueError("train_loader produced no samples; cannot compute loss")

        return {"loss": total_loss / total_samples}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            A dict with ``"loss"`` (mean per-sample loss) and every key
            returned by ``criterion.compute_metrics``. Each metric is averaged
            across batches, weighted by batch size so it is consistent with
            the loss. Non-additive metrics such as F1 are still per-batch
            averages, not dataset-level values.

        Raises:
            ValueError: If ``val_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        # Running sum of metric * batch_size for each metric key.
        weighted_metrics = {}

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._unpack_batch(batch)

                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                metrics = self.criterion.compute_metrics(output, target)
                for key, value in metrics.items():
                    weighted_metrics[key] = (
                        weighted_metrics.get(key, 0.0) + value * batch_size
                    )

        if total_samples == 0:
            raise ValueError("val_loader produced no samples; cannot compute metrics")

        avg_metrics = {
            key: value / total_samples for key, value in weighted_metrics.items()
        }
        return {"loss": total_loss / total_samples, **avg_metrics}

    def train(self, num_epochs: int):
        """Train for up to ``num_epochs`` epochs with checkpointing and early stopping.

        TensorBoard events and checkpoints go under ``config["output_dir"]``.
        A checkpoint is written every epoch. ``best_model.pt`` is refreshed
        whenever the validation loss improves. Training stops after
        ``early_stopping_patience`` consecutive epochs without improvement.
        The TensorBoard writer is closed even if training raises.
        """
        log_dir = self.config.get("output_dir", "runs/similarity")
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        patience = self.config.get("early_stopping_patience", 20)

        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")

                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")
                print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
                print(f"Val F1: {val_metrics['f1']:.4f}")

                if self.writer:
                    self.writer.add_scalar(
                        "epoch/train_loss", train_metrics["loss"], epoch
                    )
                    self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
                    self.writer.add_scalar(
                        "epoch/val_accuracy", val_metrics["accuracy"], epoch
                    )

                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
                    print(
                        f"New best model saved with val loss: {val_metrics['loss']:.4f}"
                    )
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(
                        f"Early stopping triggered after {patience} epochs without improvement"
                    )
                    break
        finally:
            # Flush and release the event file even on error or interrupt.
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state for ``epoch`` to ``<output_dir>/checkpoints``.

        Always writes ``checkpoint_epoch_{epoch}.pt``. When ``is_best`` is
        true, also writes the same payload to ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/similarity"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a file written by ``save_checkpoint``.

        Returns:
            ``(epoch, val_loss)`` stored in the checkpoint. ``best_val_loss``
            and the early-stopping counter are NOT restored.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On PyTorch versions where weights_only defaults to True,
        # the stored "config" dict must contain only allow-listed types --
        # confirm against the installed torch version.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]


def main():
    """Build loaders, model and trainer from ``dataloader.config``, then train."""
    # Use config from dataloader.py
    config_dict = config.copy()

    # Ensure image_size is tuple
    if isinstance(config_dict.get("image_size"), list):
        config_dict["image_size"] = tuple(config_dict["image_size"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        root_dir=config_dict["data_dir"],
        batch_size=config_dict["batch_size"],
        train_split=config_dict["train_split"],
        num_workers=config_dict["num_workers"],
        image_size=config_dict["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    image_size = config_dict["image_size"]
    # The model takes a single side length; use the first dimension of a
    # (H, W)-style sequence, or the scalar as-is.
    input_size = image_size[0] if isinstance(image_size, (tuple, list)) else image_size
    model = create_similarity_model(
        model_type="cnn",
        input_size=input_size,
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    trainer = SimilarityTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config_dict,
    )

    print("Starting training...")
    trainer.train(config_dict["epochs"])
    print("Training completed!")


if __name__ == "__main__":
    main()