Files

249 lines
8.5 KiB
Python

"""
Training script for image similarity estimation.
"""
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import config, create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class SimilarityTrainer:
    """Run training, validation, checkpointing and early stopping for a similarity model.

    Every batch yielded by the loaders is indexed with the keys
    ``"google_img"``, ``"yandex_img"`` and ``"same_domain"``: the two image
    tensors are fed to the model and ``same_domain`` is used as the target.

    Recognised ``config`` keys (all optional, defaults in parentheses):
    ``learning_rate`` (2e-4), ``beta1`` (0.5), ``beta2`` (0.999),
    ``log_interval`` (10), ``output_dir`` ("runs/similarity") and
    ``early_stopping_patience`` (20).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move ``model`` to ``device`` and set up the loss and Adam optimizer.

        Args:
            model: Network called as ``model(google_img, yandex_img)``.
            train_loader: Loader iterated once per training epoch.
            val_loader: Loader iterated once per validation pass.
            device: Device the model and every batch are moved to.
            config: Hyper-parameter mapping; see the class docstring for keys.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config
        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )
        # The TensorBoard writer is created lazily by train(); while it is
        # None, train_epoch() skips scalar logging.
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _batch_to_device(self, batch):
        """Return ``(google_img, yandex_img, target)`` moved to ``self.device``."""
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        # unsqueeze(1) inserts a dimension at index 1; presumably this turns a
        # (batch,) label vector into (batch, 1) to match the model output --
        # TODO confirm against SimilarityLoss.
        target = batch["same_domain"].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    @staticmethod
    def _average_metrics(batch_metrics, batch_sizes):
        """Return the sample-weighted mean of every metric across batches.

        Args:
            batch_metrics: One metrics mapping per batch.
            batch_sizes: Number of samples in each batch, in the same order.

        Returns:
            Mapping with the keys of the first batch's metrics, each averaged
            with weights proportional to the batch sizes.

        Raises:
            ValueError: If the batches contain no samples at all.
        """
        total = sum(batch_sizes)
        if total == 0:
            raise ValueError("cannot average metrics over zero samples")
        return {
            key: sum(m[key] * n for m, n in zip(batch_metrics, batch_sizes)) / total
            for key in batch_metrics[0]
        }

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Epoch number, used for the progress-bar label and for the
                TensorBoard step ``epoch * len(train_loader) + batch_idx``.

        Returns:
            ``{"loss": <sample-weighted mean training loss>}``.

        Raises:
            ValueError: If the training loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        log_interval = self.config.get("log_interval", 10)
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._batch_to_device(batch)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            batch_size = google_img.size(0)
            loss_value = loss.item()
            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += loss_value * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                pbar.set_postfix({"loss": loss_value, "acc": metrics["accuracy"]})
                if self.writer:
                    # len() is only evaluated when logging, so loaders without
                    # a length still work when no writer is attached.
                    global_step = epoch * len(self.train_loader) + batch_idx
                    self.writer.add_scalar("train/loss", loss_value, global_step)
                    self.writer.add_scalar(
                        "train/accuracy", metrics["accuracy"], global_step
                    )

        if total_samples == 0:
            raise ValueError(
                "train_loader produced no samples; cannot compute the average loss"
            )
        return {"loss": total_loss / total_samples}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without tracking gradients.

        Both the loss and the metrics returned by
        ``criterion.compute_metrics`` are averaged with per-batch sample
        weights, so batches of different sizes contribute proportionally.

        Returns:
            ``{"loss": <mean validation loss>, **<mean metrics>}``.

        Raises:
            ValueError: If the validation loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        batch_sizes = []
        all_metrics = []
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._batch_to_device(batch)
                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                batch_sizes.append(batch_size)
                all_metrics.append(self.criterion.compute_metrics(output, target))

        total_samples = sum(batch_sizes)
        if total_samples == 0:
            raise ValueError(
                "val_loader produced no samples; cannot compute validation metrics"
            )
        avg_metrics = self._average_metrics(all_metrics, batch_sizes)
        return {"loss": total_loss / total_samples, **avg_metrics}

    def train(self, num_epochs: int):
        """Train for up to ``num_epochs`` epochs with early stopping.

        After every epoch the model is validated, epoch-level scalars are
        written to TensorBoard and a checkpoint is saved. An epoch whose
        validation loss beats the best seen so far is also saved as the best
        model. Training stops early once ``early_stopping_patience``
        consecutive epochs bring no improvement.

        Args:
            num_epochs: Maximum number of epochs; epochs are numbered from 1.
        """
        log_dir = self.config.get("output_dir", "runs/similarity")
        os.makedirs(log_dir, exist_ok=True)
        patience = self.config.get("early_stopping_patience", 20)
        self.writer = SummaryWriter(log_dir)
        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")
        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")
                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()
                val_loss = val_metrics["loss"]

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")
                print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
                print(f"Val F1: {val_metrics['f1']:.4f}")

                self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
                self.writer.add_scalar("epoch/val_loss", val_loss, epoch)
                self.writer.add_scalar(
                    "epoch/val_accuracy", val_metrics["accuracy"], epoch
                )

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_loss, is_best=True)
                    print(f"New best model saved with val loss: {val_metrics['loss']:.4f}")
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_loss, is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(
                        f"Early stopping triggered after {patience} epochs without improvement"
                    )
                    break
        finally:
            # Close the writer even when an epoch raises or the run is
            # interrupted, so pending events are not left unwritten.
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state to ``<output_dir>/checkpoints``.

        Args:
            epoch: Epoch number; used in the file name ``checkpoint_epoch_<epoch>.pt``.
            val_loss: Validation loss stored alongside the weights.
            is_best: When true, the same checkpoint is also written to
                ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/similarity"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)
        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a checkpoint file.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            Tuple ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        # SECURITY NOTE(review): torch.load unpickles the file. Only load
        # checkpoints from trusted sources, or pass weights_only=True if the
        # installed torch version and the stored config allow it.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]
def main():
    """Build the data loaders, model and trainer from the shared config, then train."""
    # Start from a copy of the config object imported from dataloader.py.
    settings = config.copy()

    # A list-valued image_size is normalised to a tuple.
    if isinstance(settings.get("image_size"), list):
        settings["image_size"] = tuple(settings["image_size"])

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    print("Creating data loaders...")
    loader_kwargs = {
        "root_dir": settings["data_dir"],
        "batch_size": settings["batch_size"],
        "train_split": settings["train_split"],
        "num_workers": settings["num_workers"],
        "image_size": settings["image_size"],
        "augment_train": True,
        "augment_val": False,
        "device": device,
    }
    train_loader, val_loader = create_data_loaders(**loader_kwargs)
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    # The model takes a single input size: the first element of a
    # tuple/list image_size, or the value itself otherwise.
    image_size = settings["image_size"]
    if isinstance(image_size, (tuple, list)):
        input_size = image_size[0]
    else:
        input_size = image_size
    model = create_similarity_model(
        model_type="cnn",
        input_size=input_size,
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    trainer = SimilarityTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=settings,
    )

    print("Starting training...")
    trainer.train(settings["epochs"])
    print("Training completed!")


if __name__ == "__main__":
    main()