Files
autopilot/models/GAN/trainer.py

191 lines
6.3 KiB
Python

"""Trainer for GAN model."""
import json
import time
from pathlib import Path
from typing import Any, Dict, Tuple
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
class GANTrainer:
    """Alternating discriminator/generator trainer for a paired-image GAN.

    Each batch is a mapping with ``"yandex_img"`` and ``"google_img"`` tensors.
    The generator is fed ``yandex_img`` (presumably translating it towards
    ``google_img`` -- confirm against the model).

    Expected ``model`` interface (as used below):
      * ``generator`` and ``discriminator`` submodules (each gets its own Adam);
      * ``generator_step(yandex_img, google_img)`` and
        ``discriminator_step(yandex_img, google_img, fake_img)``, each returning
        an indexable whose first element is a scalar loss tensor;
      * optionally a ``device`` attribute; if absent, the device of the first
        parameter (or CPU for a parameter-less model) is used.

    Config keys read (defaults in parentheses): ``learning_rate`` (2e-4),
    ``beta1`` (0.5), ``beta2`` (0.999), ``output_dir`` ("runs/gan"),
    ``save_interval`` (5; 0 disables periodic checkpoints) and
    ``early_stopping_patience`` (0 = disabled).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict[str, Any],
    ):
        """Set up optimizers and training state, and persist ``config``.

        Side effects: creates ``<output_dir>/checkpoints`` (with parents) and
        writes ``<output_dir>/config.json``; ``config`` must therefore be
        JSON-serializable.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = self._resolve_device(model)

        # One Adam optimizer per network, sharing the same hyper-parameters.
        lr = config.get("learning_rate", 2e-4)
        betas = (config.get("beta1", 0.5), config.get("beta2", 0.999))
        self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=betas)
        self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=betas)

        # Training state. The per-epoch loss histories are appended to by
        # train_epoch() / validate().
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses = []
        self.d_losses = []
        self.val_g_losses = []
        self.val_d_losses = []

        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        (self.output_dir / "checkpoints").mkdir(parents=True, exist_ok=True)

        with open(self.output_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def _resolve_device(model: torch.nn.Module) -> torch.device:
        """Return ``model.device`` if present, else the first parameter's device.

        Plain ``torch.nn.Module`` has no ``device`` attribute, so fall back to
        the parameters; a model with no parameters is assumed to live on CPU.
        """
        device = getattr(model, "device", None)
        if device is not None:
            return device
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    @staticmethod
    def _mean(total: float, count: int) -> float:
        """Average ``total`` over ``count`` items; NaN when ``count`` is 0."""
        return total / count if count else float("nan")

    def train_epoch(self) -> Tuple[float, float]:
        """Run one pass over ``train_loader``, updating D then G on each batch.

        Returns:
            ``(avg_g_loss, avg_d_loss)`` over the batches actually seen (NaN if
            the loader yielded none). The averages are also appended to
            ``g_losses`` / ``d_losses``.
        """
        self.model.train()
        total_g = total_d = 0.0
        num_batches = 0
        # Context manager closes the progress bar even if a step raises.
        with tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") as pbar:
            for batch in pbar:
                yandex_img = batch["yandex_img"].to(self.device)
                google_img = batch["google_img"].to(self.device)

                # Discriminator update. The fake is produced under no_grad, so
                # d_loss.backward() cannot reach the generator's weights.
                self.opt_D.zero_grad()
                with torch.no_grad():
                    fake_img = self.model.generator(yandex_img)
                d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
                d_loss.backward()
                self.opt_D.step()

                # Generator update. If generator_step runs the discriminator,
                # this backward also writes D grads; those are cleared by
                # opt_D.zero_grad() at the start of the next batch.
                self.opt_G.zero_grad()
                g_loss = self.model.generator_step(yandex_img, google_img)[0]
                g_loss.backward()
                self.opt_G.step()

                # .item() forces a device sync, so do it once per loss.
                g_val, d_val = g_loss.item(), d_loss.item()
                total_g += g_val
                total_d += d_val
                num_batches += 1
                pbar.set_postfix({"g_loss": g_val, "d_loss": d_val})

        avg_g = self._mean(total_g, num_batches)
        avg_d = self._mean(total_d, num_batches)
        self.g_losses.append(avg_g)
        self.d_losses.append(avg_d)
        return avg_g, avg_d

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Compute mean G/D losses over ``val_loader`` without updating weights.

        Runs with gradients disabled and the model in eval mode.

        Returns:
            ``(avg_g_loss, avg_d_loss)`` (NaN if the loader yielded no
            batches). The averages are also appended to ``val_g_losses`` /
            ``val_d_losses``.
        """
        self.model.eval()
        total_g = total_d = 0.0
        num_batches = 0
        for batch in tqdm(self.val_loader, desc="Val"):
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)
            fake_img = self.model.generator(yandex_img)
            g_loss = self.model.generator_step(yandex_img, google_img)[0]
            d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
            total_g += g_loss.item()
            total_d += d_loss.item()
            num_batches += 1

        avg_g = self._mean(total_g, num_batches)
        avg_d = self._mean(total_d, num_batches)
        self.val_g_losses.append(avg_g)
        self.val_d_losses.append(avg_d)
        return avg_g, avg_d

    def _should_save_periodic(self, epoch: int) -> bool:
        """True if zero-based ``epoch`` ends a ``save_interval`` period.

        An interval of 0 (or None) disables periodic checkpoints instead of
        raising ``ZeroDivisionError``.
        """
        interval = self.config.get("save_interval", 5)
        return bool(interval) and (epoch + 1) % interval == 0

    def _should_stop_early(self, patience: int) -> bool:
        """True if none of the last ``patience`` val G losses beat the earlier best.

        Disabled when ``patience <= 0`` or when there is not yet any history
        older than the patience window.
        """
        history = self.val_g_losses
        if patience <= 0 or len(history) <= patience:
            return False
        best_before = min(history[:-patience])
        return all(loss >= best_before for loss in history[-patience:])

    def train(self, num_epochs: int):
        """Train for up to ``num_epochs`` epochs, checkpointing along the way.

        Each epoch:
          * trains, then validates;
          * saves ``best`` when ``val_g + val_d`` improves;
          * saves ``epoch_<n>`` every ``save_interval`` epochs;
          * may stop early.

        A ``final`` checkpoint is always written at the end.

        NOTE(review): early stopping monitors ``val_g`` only, whereas the
        ``best`` checkpoint uses ``val_g + val_d``. Confirm this is intended.
        """
        print(f"Training for {num_epochs} epochs...")
        patience = self.config.get("early_stopping_patience", 0)
        for epoch in range(num_epochs):
            self.current_epoch = epoch

            train_g, train_d = self.train_epoch()
            val_g, val_d = self.validate()

            val_total = val_g + val_d
            if val_total < self.best_val_loss:
                self.best_val_loss = val_total
                self.save_checkpoint("best")

            if self._should_save_periodic(epoch):
                self.save_checkpoint(f"epoch_{epoch + 1}")

            print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")

            if self._should_stop_early(patience):
                print(f"Early stopping at epoch {epoch + 1}")
                break

        self.save_checkpoint("final")
        print(f"Training finished. Best val loss: {self.best_val_loss:.4f}")

    def save_checkpoint(self, name: str):
        """Write ``<output_dir>/checkpoints/<name>.pth``, overwriting any existing file.

        Stores the following:
          * the zero-based ``current_epoch``;
          * generator/discriminator and optimizer state dicts;
          * ``best_val_loss``, so best-model tracking can be restored on resume.
        """
        path = self.output_dir / "checkpoints" / f"{name}.pth"
        torch.save({
            "epoch": self.current_epoch,
            "generator": self.model.generator.state_dict(),
            "discriminator": self.model.discriminator.state_dict(),
            "opt_G": self.opt_G.state_dict(),
            "opt_D": self.opt_D.state_dict(),
            "best_val_loss": self.best_val_loss,
        }, path)
def create_trainer(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: Dict[str, Any],
) -> GANTrainer:
    """Build a ``GANTrainer`` for ``model``.

    Equivalent to calling ``GANTrainer(model, train_loader, val_loader, config)``
    directly. Constructing the trainer has side effects: it creates the output
    directory and writes ``config.json`` there.

    Returns:
        The newly constructed trainer.
    """
    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
    )
    return trainer
if __name__ == "__main__":
    # Smoke test: run ONE training epoch on the configured data to verify the
    # pipeline is wired together. Relies on project-local modules, so run it
    # from this package's directory.
    import sys
    import traceback

    from config import create_config
    from dataloader import create_data_loaders
    from model import create_gan

    config = create_config()
    # use_cuda=False: presumably keeps the smoke test on CPU regardless of
    # GPU availability (confirm against create_gan).
    model = create_gan(use_cuda=False)
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=4,
        image_size=tuple(config["image_size"]),
        num_workers=0,
    )
    trainer = create_trainer(model, train_loader, val_loader, config)

    print("Testing one training step...")
    try:
        g_loss, d_loss = trainer.train_epoch()
    except Exception as e:
        # Top-level boundary: report with full traceback and exit non-zero so
        # scripts/CI notice the failure instead of seeing a clean exit.
        print(f"Training step failed: {e}")
        traceback.print_exc()
        sys.exit(1)
    print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")