213 lines
7.0 KiB
Python
213 lines
7.0 KiB
Python
import os
|
|
import time
|
|
from datetime import datetime
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from dataloader import config, create_data_loaders
|
|
from model import HomographyCNN, HomographyLoss, create_homography_model
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
|
|
class HomographyTrainer:
    """Train and validate a homography-regression model.

    Each epoch runs one optimisation pass and one validation pass. The
    trainer logs losses to TensorBoard and writes a checkpoint every epoch.
    It also copies the checkpoint with the lowest validation loss so far to
    ``best_model.pt``. Training stops early after ``early_stopping_patience``
    consecutive epochs without a new best validation loss.

    Optional ``config`` keys (defaults in brackets):
        learning_rate [2e-4], beta1 [0.5], beta2 [0.999], log_interval [10],
        output_dir ["runs/homography"], early_stopping_patience [20].
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move ``model`` to ``device`` and build the loss and Adam optimizer.

        Args:
            model: Network invoked as ``model(google_img, yandex_img)``.
            train_loader: Yields dict batches with ``"google_img"``,
                ``"yandex_img"`` and ``"homography"`` tensors.
            val_loader: Same batch layout as ``train_loader``.
            device: Device the model and every batch are moved to.
            config: Hyper-parameter dict; see the class docstring for keys.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        self.criterion = HomographyLoss()
        # Optimise the parameters of the module we actually train.
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )

        # The SummaryWriter is only created in train(); until then the
        # `if self.writer` guards skip TensorBoard logging.
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _output_dir(self) -> str:
        """Return the configured output directory (logs and checkpoints)."""
        return self.config.get("output_dir", "runs/homography")

    def _batch_to_device(self, batch: dict) -> tuple:
        """Return ``(google_img, yandex_img, target)`` from ``batch`` on ``self.device``."""
        return (
            batch["google_img"].to(self.device),
            batch["yandex_img"].to(self.device),
            batch["homography"].to(self.device),
        )

    @staticmethod
    def _average_loss(total_loss: float, total_samples: int, split: str) -> float:
        """Return the per-sample mean loss.

        Raises:
            ValueError: If ``total_samples`` is 0 (the ``split`` loader was
                empty). This replaces a bare ZeroDivisionError.
        """
        if total_samples == 0:
            raise ValueError(
                f"{split} loader yielded no samples; cannot compute mean loss"
            )
        return total_loss / total_samples

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: 1-based epoch number. Used for the progress-bar label and
                to derive the TensorBoard global step.

        Returns:
            ``{"loss": mean}``, the sample-weighted mean training loss.

        Raises:
            ValueError: If ``train_loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        log_interval = self.config.get("log_interval", 10)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._batch_to_device(batch)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # .item() forces a host/device sync, so read the scalar once per batch.
            batch_loss = loss.item()
            batch_size = google_img.size(0)
            total_loss += batch_loss * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                pbar.set_postfix({"loss": batch_loss})
                if self.writer:
                    # NOTE: with 1-based epochs the first logged step is
                    # len(train_loader), not 0.
                    self.writer.add_scalar(
                        "train/loss",
                        batch_loss,
                        epoch * len(self.train_loader) + batch_idx,
                    )

        return {"loss": self._average_loss(total_loss, total_samples, "train")}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without tracking gradients.

        Returns:
            ``{"loss": mean}``, the sample-weighted mean validation loss.

        Raises:
            ValueError: If ``val_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._batch_to_device(batch)
                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

        return {"loss": self._average_loss(total_loss, total_samples, "validation")}

    def train(self, num_epochs: int):
        """Train for up to ``num_epochs`` epochs with early stopping.

        Setup:
            Creates the output directory and a SummaryWriter logging into it.

        Each epoch:
            Trains and validates, logs both losses, and saves a checkpoint.
            The checkpoint is also written as ``best_model.pt`` when the
            validation loss improves on the best so far.

        The writer is closed even if training raises or is interrupted.
        """
        log_dir = self._output_dir()
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        patience = self.config.get("early_stopping_patience", 20)
        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")

                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")

                if self.writer:
                    self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
                    self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)

                val_loss = val_metrics["loss"]
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_loss, is_best=True)
                    print(f"New best model saved with val loss: {val_metrics['loss']:.4f}")
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_loss, is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(f"Early stopping triggered after {patience} epochs without improvement")
                    break
        finally:
            # Flush pending TensorBoard events even on exceptions / Ctrl-C.
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write ``checkpoint_epoch_{epoch}.pt`` under ``<output_dir>/checkpoints``.

        The file stores the epoch, the model and optimizer state dicts, the
        validation loss and the config. When ``is_best`` is true the same
        data is also written to ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(self._output_dir(), "checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }

        torch.save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt"))
        if is_best:
            torch.save(checkpoint, os.path.join(checkpoint_dir, "best_model.pt"))

    def load_checkpoint(self, checkpoint_path: str) -> tuple:
        """Restore model and optimizer state from a file written by save_checkpoint.

        ``best_val_loss`` and ``epochs_without_improvement`` are NOT restored.

        NOTE: ``torch.load`` is pickle-based; only load checkpoints from
        trusted sources.

        Returns:
            ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]
|
|
|
|
|
|
def main():
    """Build loaders, model and trainer from the imported ``config``, then train.

    Works on ``config.copy()`` so key assignments made here (e.g. the
    ``image_size`` normalisation) are not written back to the imported object.
    """
    settings = config.copy()

    image_size = settings.get("image_size")
    if isinstance(image_size, list):
        # Downstream code receives image_size as a tuple.
        settings["image_size"] = tuple(image_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    print("Creating data loaders...")
    loader_kwargs = {
        "root_dir": settings["data_dir"],
        "batch_size": settings["batch_size"],
        "train_split": settings["train_split"],
        "num_workers": settings["num_workers"],
        "image_size": settings["image_size"],
        "augment_train": True,
        "augment_val": False,
        "device": device,
    }
    train_loader, val_loader = create_data_loaders(**loader_kwargs)

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    model_kwargs = {
        "model_type": "backbone",
        "input_channels": 3,
        "backbone_name": "resnet18",
        "pretrained": True,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
    }
    model = create_homography_model(**model_kwargs)

    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model parameters: {n_params:,}")

    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=settings,
    )

    print("Starting training...")
    trainer.train(settings["epochs"])
    print("Training completed!")


if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()
|