# File-viewer header: 106 lines, 3.9 KiB, Python
import os
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from dataloader import create_data_loaders
|
|
from model import HomographyCNN6, HomographyLoss6
|
|
from utils import config
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
|
|
class HomographyTrainer:
    """Train and validate a two-image homography regressor.

    Each batch is read as a mapping with ``"google_img"``, ``"yandex_img"``
    and ``"homography_params"`` entries. The model is called as
    ``model(google_img, yandex_img)`` and its output is scored against
    ``homography_params`` with ``HomographyLoss6``. The learning rate and
    output directory come from the module-level ``config`` mapping (keys
    ``"learning_rate"`` and ``"output_dir"``).

    Args:
        model: Network to train; moved to ``device`` on construction.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        device: Device the model and every batch tensor are moved to.
    """

    def __init__(self, model, train_loader, val_loader, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = HomographyLoss6()
        # Build the optimizer from the device-resident model object.
        self.optimizer = optim.Adam(self.model.parameters(), lr=config["learning_rate"])
        # TensorBoard writer; opened by train(), None until then.
        self.writer = None
        self.best_val_loss = float("inf")

    def _to_device(self, batch):
        """Return ``(google_img, yandex_img, target)`` from *batch*, moved to ``self.device``."""
        return (
            batch["google_img"].to(self.device),
            batch["yandex_img"].to(self.device),
            batch["homography_params"].to(self.device),
        )

    @staticmethod
    def _mean_loss(total_loss, total_samples, split):
        """Return ``total_loss / total_samples``.

        Raises:
            ValueError: If no samples were seen. This replaces an opaque
                ZeroDivisionError with a message naming the empty *split*.
        """
        if total_samples == 0:
            raise ValueError(f"{split} loader yielded no samples; cannot compute mean loss")
        return total_loss / total_samples

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            epoch: Epoch number; used only in the progress-bar label.

        Returns:
            ``{"loss": mean}``, where each batch's loss is weighted by its
            batch size. This assumes the criterion returns a per-batch mean
            (TODO confirm for HomographyLoss6).

        Raises:
            ValueError: If the training loader yields no samples.
        """
        self.model.train()
        total_loss, total_samples = 0.0, 0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch in pbar:
            google_img, yandex_img, target = self._to_device(batch)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # .item() copies the scalar to the host (a sync on CUDA), so
            # call it once per step and reuse the value.
            batch_loss = loss.item()
            batch_size = google_img.size(0)
            total_loss += batch_loss * batch_size
            total_samples += batch_size
            pbar.set_postfix({"loss": batch_loss})

        return {"loss": self._mean_loss(total_loss, total_samples, "train")}

    def validate(self):
        """Evaluate the model on ``self.val_loader`` without gradient tracking.

        Returns:
            ``{"loss": mean}``, the batch-size-weighted mean validation loss.

        Raises:
            ValueError: If the validation loader yields no samples.
        """
        self.model.eval()
        total_loss, total_samples = 0.0, 0
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._to_device(batch)
                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)
                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size
        return {"loss": self._mean_loss(total_loss, total_samples, "val")}

    def train(self, num_epochs):
        """Train for *num_epochs* epochs, validating after each one.

        Prints per-epoch losses and logs them to TensorBoard under
        ``config["output_dir"]``. When the validation loss improves on
        ``self.best_val_loss``, the model is checkpointed with
        ``save_checkpoint(epoch, is_best=True)``.
        """
        log_dir = config["output_dir"]
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        try:
            for epoch in range(1, num_epochs + 1):
                train_loss = self.train_epoch(epoch)["loss"]
                val_loss = self.validate()["loss"]
                print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

                # Record the epoch's losses in the TensorBoard event file.
                self.writer.add_scalar("Loss/train", train_loss, epoch)
                self.writer.add_scalar("Loss/val", val_loss, epoch)

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(epoch, is_best=True)
                    print(f"Best model saved (val loss: {val_loss:.4f})")
        finally:
            # Flush and release the event file even if training raises.
            self.writer.close()

    def save_checkpoint(self, epoch, is_best=False):
        """Save model and optimizer state under ``<output_dir>/checkpoints``.

        Always writes ``checkpoint_epoch_{epoch}.pt``. When *is_best* is true,
        the same payload is also written to ``best_model.pt``. The stored
        ``"val_loss"`` is ``self.best_val_loss`` at the time of the call.
        """
        ckpt_dir = os.path.join(config["output_dir"], "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            # Optimizer state (e.g. Adam moments) is required to resume training.
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": self.best_val_loss,
        }
        torch.save(ckpt, os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch}.pt"))
        if is_best:
            torch.save(ckpt, os.path.join(ckpt_dir, "best_model.pt"))


# ---------------------------------------------------------------------------
# Script section (executes at import time): pick a device, build the data
# loaders and the model from `config`, then run training.
# NOTE(review): consider wrapping this in `if __name__ == "__main__":` so that
# HomographyTrainer can be imported without starting a training run.
# ---------------------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Data-loader settings, taken verbatim from `config`.
_loader_kwargs = dict(
    root_dir=config["data_dir"],
    batch_size=config["batch_size"],
    train_split=config["train_split"],
    num_workers=config["num_workers"],
    image_size=config["image_size"],
)
train_loader, val_loader = create_data_loaders(**_loader_kwargs)

model = HomographyCNN6(
    input_channels=3,
    backbone_name=config["backbone"],
    pretrained=True,
    dropout_rate=config["dropout_rate"],
)
# Total parameter count (trainable and frozen alike).
_param_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {_param_count:,}")

trainer = HomographyTrainer(model, train_loader, val_loader, device)
trainer.train(config["epochs"])