Files
autopilot/models/SiaN/train.py
2026-04-04 20:26:56 +03:00

106 lines
3.9 KiB
Python

import os
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from model import HomographyCNN6, HomographyLoss6
from utils import config
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class HomographyTrainer:
    """Train and validate a two-image homography regression model.

    Bundles the model, its data loaders, a ``HomographyLoss6`` criterion and
    an Adam optimizer (learning rate read from ``config["learning_rate"]``),
    and provides the epoch loop, TensorBoard logging and checkpointing.
    """

    def __init__(self, model, train_loader, val_loader, device):
        """Move ``model`` to ``device`` and create the loss and optimizer.

        Args:
            model: Network invoked as ``model(google_img, yandex_img)``.
            train_loader: Iterable of training batches. Each batch is a
                mapping with ``"google_img"``, ``"yandex_img"`` and
                ``"homography_params"`` tensors.
            val_loader: Iterable of validation batches with the same keys.
            device: Device the model and every batch tensor are moved to.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = HomographyLoss6()
        # Build the optimizer from the module that actually lives on `device`.
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=config["learning_rate"]
        )
        # Created lazily in train() so that constructing a trainer writes nothing.
        self.writer = None
        self.best_val_loss = float("inf")

    @staticmethod
    def _mean_loss(total_loss, total_samples):
        """Return the per-sample mean loss, or NaN when no samples were seen.

        NaN (instead of a ZeroDivisionError) keeps an empty loader from
        crashing the run; NaN never compares less than ``best_val_loss``, so
        it can never be recorded as a new best.
        """
        if total_samples == 0:
            return float("nan")
        return total_loss / total_samples

    def _batch_loss(self, batch):
        """Run the forward pass for one batch.

        Returns:
            Tuple ``(loss, batch_size)`` where ``loss`` is the criterion output
            and ``batch_size`` is ``batch["google_img"].size(0)``.
        """
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target = batch["homography_params"].to(self.device)
        output = self.model(google_img, yandex_img)
        return self.criterion(output, target), google_img.size(0)

    def train_epoch(self, epoch):
        """Run one optimization pass over ``train_loader``.

        Args:
            epoch: Epoch number, used only for the progress-bar label.

        Returns:
            ``{"loss": mean}`` where ``mean`` is the batch-size-weighted mean
            training loss (NaN if the loader produced no samples).
        """
        self.model.train()
        total_loss, total_samples = 0.0, 0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch in pbar:
            self.optimizer.zero_grad()
            loss, batch_size = self._batch_loss(batch)
            loss.backward()
            self.optimizer.step()
            # Read the scalar once; each .item() call copies from the device.
            batch_loss = loss.item()
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += batch_loss * batch_size
            total_samples += batch_size
            pbar.set_postfix({"loss": batch_loss})
        return {"loss": self._mean_loss(total_loss, total_samples)}

    def validate(self):
        """Evaluate the model on ``val_loader`` without tracking gradients.

        Returns:
            ``{"loss": mean}`` where ``mean`` is the batch-size-weighted mean
            validation loss (NaN if the loader produced no samples).
        """
        self.model.eval()
        total_loss, total_samples = 0.0, 0
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                loss, batch_size = self._batch_loss(batch)
                total_loss += loss.item() * batch_size
                total_samples += batch_size
        return {"loss": self._mean_loss(total_loss, total_samples)}

    def train(self, num_epochs):
        """Train for ``num_epochs`` epochs, validating after each one.

        Per-epoch train/validation losses are printed and written to a
        TensorBoard log in ``config["output_dir"]``. Whenever the validation
        loss improves on the best seen so far, a checkpoint is saved.

        Args:
            num_epochs: Number of epochs to run; epochs are numbered from 1.
        """
        log_dir = config["output_dir"]
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        try:
            for epoch in range(1, num_epochs + 1):
                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()
                self.writer.add_scalar("Loss/train", train_metrics["loss"], epoch)
                self.writer.add_scalar("Loss/val", val_metrics["loss"], epoch)
                print(f"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}")
                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.save_checkpoint(epoch, is_best=True)
                    print(f"Best model saved (val loss: {val_metrics['loss']:.4f})")
        finally:
            # Flush and close the event file even if an epoch raised.
            self.writer.close()

    def save_checkpoint(self, epoch, is_best=False):
        """Write a checkpoint under ``config["output_dir"]/checkpoints``.

        The checkpoint holds the epoch number, the model and optimizer state
        dicts, and the best validation loss so far. It is always written as
        ``checkpoint_epoch_{epoch}.pt``.

        Args:
            epoch: Epoch number stored in the checkpoint and in its file name.
            is_best: When true, the checkpoint is also written as
                ``best_model.pt``.
        """
        ckpt_dir = os.path.join(config["output_dir"], "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            # Needed to resume training with the same Adam moment estimates.
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": self.best_val_loss,
        }
        torch.save(ckpt, os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch}.pt"))
        if is_best:
            torch.save(ckpt, os.path.join(ckpt_dir, "best_model.pt"))
def main():
    """Build the data loaders and model from ``config`` and run training.

    Uses CUDA when ``torch.cuda.is_available()`` reports it, otherwise the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=config["image_size"],
    )

    model = HomographyCNN6(
        input_channels=3,
        backbone_name=config["backbone"],
        pretrained=True,
        dropout_rate=config["dropout_rate"],
    )
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    trainer = HomographyTrainer(model, train_loader, val_loader, device)
    trainer.train(config["epochs"])


# Guard the entry point: with num_workers > 0 on spawn-based platforms, each
# DataLoader worker re-imports this module, and unguarded top-level code would
# start the whole training run again inside every worker.
if __name__ == "__main__":
    main()