276 lines
9.4 KiB
Python
276 lines
9.4 KiB
Python
"""
Training script for image similarity estimation.
"""
|
|
|
|
import argparse
|
|
import os
|
|
import time
|
|
from datetime import datetime
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from dataloader import create_data_loaders
|
|
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
|
|
class SimilarityTrainer:
    """Train, validate, and checkpoint a two-input similarity model.

    The trainer builds a ``SimilarityLoss`` criterion and an Adam optimizer,
    logs scalars to TensorBoard while :meth:`train` runs, and writes
    checkpoints under ``<output_dir>/checkpoints``.

    Optional ``config`` keys read by this class (defaults in parentheses):
    ``learning_rate`` (2e-4), ``beta1`` (0.5), ``beta2`` (0.999),
    ``log_interval`` (10), ``output_dir`` ("runs/similarity") and
    ``early_stopping_patience`` (20).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move ``model`` to ``device`` and set up the loss and optimizer.

        Args:
            model: Network called as ``model(google_img, yandex_img)``.
            train_loader: Iterable of training batches (dicts, see
                :meth:`_unpack_batch`).
            val_loader: Iterable of validation batches in the same format.
            device: Device that the model and every batch are moved to.
            config: Hyper-parameter and path dictionary (see class docstring).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )

        # The TensorBoard writer is only created once train() starts.
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _unpack_batch(self, batch: dict):
        """Move one batch to the training device.

        Args:
            batch: Mapping with the keys ``"google_img"``, ``"yandex_img"``
                and ``"same_domain"``.

        Returns:
            Tuple ``(google_img, yandex_img, target)`` on ``self.device``.
        """
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        # Labels are cast to float and get an extra axis at position 1
        # (presumably (batch,) -> (batch, 1) to match the model output;
        # confirm against the dataloader).
        target = batch["same_domain"].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    @staticmethod
    def _mean_per_sample(weighted_sums: dict, total_samples: int, loader_name: str) -> dict:
        """Turn sample-weighted sums into per-sample means.

        Raises:
            ValueError: If ``total_samples`` is zero, i.e. the loader was
                empty. This replaces an opaque ``ZeroDivisionError``.
        """
        if total_samples == 0:
            raise ValueError(
                f"{loader_name} yielded no samples; cannot compute averages"
            )
        return {key: value / total_samples for key, value in weighted_sums.items()}

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimization pass over ``train_loader``.

        Args:
            epoch: Epoch number, used for the progress-bar label and for the
                TensorBoard global step.

        Returns:
            ``{"loss": mean_loss}`` where the mean is weighted by batch size.

        Raises:
            ValueError: If ``train_loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        # Loop invariants, looked up once instead of on every batch.
        log_interval = self.config.get("log_interval", 10)
        steps_per_epoch = len(self.train_loader)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._unpack_batch(batch)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # .item() forces a device sync, so read the scalar only once.
            batch_loss = loss.item()
            batch_size = google_img.size(0)
            total_loss += batch_loss * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                pbar.set_postfix({"loss": batch_loss, "acc": metrics["accuracy"]})

                if self.writer:
                    global_step = epoch * steps_per_epoch + batch_idx
                    self.writer.add_scalar("train/loss", batch_loss, global_step)
                    self.writer.add_scalar(
                        "train/accuracy", metrics["accuracy"], global_step
                    )

        return self._mean_per_sample(
            {"loss": total_loss}, total_samples, "train_loader"
        )

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            ``{"loss": mean_loss, **mean_metrics}``. Both the loss and every
            metric from ``criterion.compute_metrics`` are averaged with
            batch-size weights, so a short final batch is not over-weighted.
            For metrics that are not simple per-sample means (e.g. F1) this
            is still an approximation of the dataset-level value.

        Raises:
            ValueError: If ``val_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        metric_sums = {}

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._unpack_batch(batch)

                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                batch_metrics = self.criterion.compute_metrics(output, target)
                for key, value in batch_metrics.items():
                    metric_sums[key] = metric_sums.get(key, 0.0) + value * batch_size

        avg_loss = self._mean_per_sample(
            {"loss": total_loss}, total_samples, "val_loader"
        )["loss"]
        avg_metrics = self._mean_per_sample(metric_sums, total_samples, "val_loader")
        return {"loss": avg_loss, **avg_metrics}

    def train(self, num_epochs: int):
        """Run the full loop: train, validate, log, checkpoint, early-stop.

        A checkpoint is written after every epoch. The epoch with the lowest
        validation loss so far is also saved as ``best_model.pt``. Training
        stops early once ``early_stopping_patience`` consecutive epochs bring
        no improvement.

        Args:
            num_epochs: Maximum number of epochs (numbered from 1).
        """
        log_dir = self.config.get("output_dir", "runs/similarity")
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        patience = self.config.get("early_stopping_patience", 20)

        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        # try/finally so pending TensorBoard events are flushed even when
        # training is interrupted (exception, Ctrl-C).
        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")

                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()
                val_loss = val_metrics["loss"]

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_loss:.4f}")
                print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
                print(f"Val F1: {val_metrics['f1']:.4f}")

                self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
                self.writer.add_scalar("epoch/val_loss", val_loss, epoch)
                self.writer.add_scalar(
                    "epoch/val_accuracy", val_metrics["accuracy"], epoch
                )

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_loss, is_best=True)
                    print(f"New best model saved with val loss: {val_loss:.4f}")
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_loss, is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(
                        f"Early stopping triggered after {patience} epochs without improvement"
                    )
                    break
        finally:
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state to ``<output_dir>/checkpoints``.

        Args:
            epoch: Epoch number, stored in the checkpoint and used in the
                file name ``checkpoint_epoch_<epoch>.pt``.
            val_loss: Validation loss stored alongside the weights.
            is_best: When true, the same payload is also written to
                ``best_model.pt``.
        """
        # NOTE(review): a file is written on every call; the "save_interval"
        # entry that main() puts into the config is not consulted here.
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/similarity"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a checkpoint file.

        Args:
            checkpoint_path: File previously written by :meth:`save_checkpoint`.

        Returns:
            Tuple ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        # SECURITY: torch.load unpickles the file. Only load checkpoints from
        # a trusted source; consider weights_only=True where the installed
        # torch version supports it.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]
|
|
|
|
|
|
def main():
    """Command-line entry point for training the similarity model.

    Parses the CLI flags, assembles the run configuration, builds the data
    loaders and the CNN model, then hands everything to ``SimilarityTrainer``
    and runs it for the requested number of epochs.
    """
    cli = argparse.ArgumentParser(description="Train similarity estimation model")
    cli.add_argument(
        "--data_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    )
    # Plain "flag / type / default" options, registered in their original order.
    simple_options = (
        ("--batch_size", int, 32),
        ("--epochs", int, 100),
        ("--learning_rate", float, 2e-4),
        ("--image_size", int, 256),
        ("--train_split", float, 0.8),
        ("--output_dir", str, "runs/similarity"),
        ("--num_workers", int, 0),
    )
    for flag, value_type, fallback in simple_options:
        cli.add_argument(flag, type=value_type, default=fallback)
    # Prefer the GPU when one is visible to torch.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    cli.add_argument("--device", type=str, default=fallback_device)

    args = cli.parse_args()

    # The CLI takes one edge length; the config stores it as a square pair.
    edge = args.image_size
    config = dict(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        image_size=(edge, edge),
        train_split=args.train_split,
        output_dir=args.output_dir,
        num_workers=args.num_workers,
        log_interval=10,
        save_interval=5,
        early_stopping_patience=20,
        beta1=0.5,
        beta2=0.999,
    )

    device = torch.device(args.device)
    print(f"Using device: {device}")

    print("Creating data loaders...")
    loaders = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=config["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )
    train_loader, val_loader = loaders

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    model = create_similarity_model(
        model_type="cnn",
        input_size=config["image_size"],
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )

    parameter_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {parameter_count:,}")

    trainer = SimilarityTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    print("Starting training...")
    trainer.train(config["epochs"])

    print("Training completed!")
|
|
|
|
|
|
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|