191 lines
6.3 KiB
Python
191 lines
6.3 KiB
Python
"""Trainer for GAN model."""
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
|
|
|
|
class GANTrainer:
    """Simple GAN trainer with alternating discriminator/generator updates.

    The wrapped ``model`` must expose ``generator`` and ``discriminator``
    sub-modules plus ``generator_step(yandex_img, google_img)`` and
    ``discriminator_step(yandex_img, google_img, fake_img)``; the first element
    of what those steps return is used as the scalar loss tensor. Each batch is
    indexed with the keys ``"yandex_img"`` and ``"google_img"``.

    Optional ``config`` keys read here:
        learning_rate (default 2e-4), beta1 (0.5), beta2 (0.999):
            Adam hyper-parameters shared by both optimizers.
        output_dir (default "runs/gan"):
            where ``config.json`` and ``checkpoints/<name>.pth`` are written.
        save_interval (default 5):
            save ``epoch_<n>.pth`` every this many epochs; 0 or None disables it.
        early_stopping_patience (default 0):
            stop once validation generator loss has not improved for this many
            epochs; 0 or None disables early stopping.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict[str, Any],
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = self._resolve_device(model)

        # Optimizers: one Adam per sub-network, sharing hyper-parameters.
        lr = config.get("learning_rate", 2e-4)
        betas = (config.get("beta1", 0.5), config.get("beta2", 0.999))
        self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=betas)
        self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=betas)

        # Training state: per-epoch mean losses, appended by train_epoch/validate.
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses = []
        self.d_losses = []
        self.val_g_losses = []
        self.val_d_losses = []

        # Output layout: <output_dir>/config.json and <output_dir>/checkpoints/.
        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        self.output_dir.mkdir(parents=True, exist_ok=True)
        (self.output_dir / "checkpoints").mkdir(exist_ok=True)

        # Save the config next to the checkpoints it produced.
        with open(self.output_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def _resolve_device(model: torch.nn.Module):
        """Return the device that batches are moved to.

        Uses ``model.device`` when the model defines one (plain ``nn.Module``
        does not, so this is presumably provided by the project's GAN wrapper);
        otherwise uses the device of the first parameter, falling back to CPU
        for a parameter-less model.
        """
        device = getattr(model, "device", None)
        if device is not None:
            return device
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    @staticmethod
    def _mean(total: float, count: int, loader_name: str) -> float:
        """Return ``total / count``, raising ValueError when no batches were seen."""
        if count == 0:
            raise ValueError(f"{loader_name} yielded no batches; cannot average losses")
        return total / count

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch over ``train_loader``.

        Returns:
            (mean generator loss, mean discriminator loss) over the batches seen;
            the values are also appended to ``g_losses`` / ``d_losses``.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()
        total_g = total_d = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch in pbar:
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            # Train D. The fake is produced under no_grad, so this update
            # does not backpropagate into the generator.
            self.opt_D.zero_grad()
            with torch.no_grad():
                fake_img = self.model.generator(yandex_img)
            d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
            d_loss.backward()
            self.opt_D.step()

            # Train G. Any discriminator grads this backward leaves behind are
            # cleared by opt_D.zero_grad() before the next D step.
            self.opt_G.zero_grad()
            g_loss = self.model.generator_step(yandex_img, google_img)[0]
            g_loss.backward()
            self.opt_G.step()

            # .item() synchronizes with the device; read each loss only once.
            g_value = g_loss.item()
            d_value = d_loss.item()
            total_g += g_value
            total_d += d_value
            num_batches += 1
            pbar.set_postfix({"g_loss": g_value, "d_loss": d_value})

        avg_g = self._mean(total_g, num_batches, "train_loader")
        avg_d = self._mean(total_d, num_batches, "train_loader")
        self.g_losses.append(avg_g)
        self.d_losses.append(avg_d)
        return avg_g, avg_d

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Evaluate mean G and D losses over ``val_loader`` without gradients.

        Returns:
            (mean generator loss, mean discriminator loss); the values are also
            appended to ``val_g_losses`` / ``val_d_losses``.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        self.model.eval()
        total_g = total_d = 0.0
        num_batches = 0

        for batch in tqdm(self.val_loader, desc="Val"):
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)
            fake_img = self.model.generator(yandex_img)
            g_loss = self.model.generator_step(yandex_img, google_img)[0]
            d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
            total_g += g_loss.item()
            total_d += d_loss.item()
            num_batches += 1

        avg_g = self._mean(total_g, num_batches, "val_loader")
        avg_d = self._mean(total_d, num_batches, "val_loader")
        self.val_g_losses.append(avg_g)
        self.val_d_losses.append(avg_d)
        return avg_g, avg_d

    def _should_stop_early(self, patience: int) -> bool:
        """Return True when none of the last ``patience`` validation G losses
        improved on the best value recorded before them."""
        history = self.val_g_losses
        if patience <= 0 or len(history) <= patience:
            return False
        best_before = min(history[:-patience])  # computed once, not per element
        return all(loss >= best_before for loss in history[-patience:])

    def train(self, num_epochs: int):
        """Run up to ``num_epochs`` train/validate cycles, checkpointing as it goes.

        Saves ``best.pth`` whenever the summed validation loss improves,
        ``epoch_<n>.pth`` every ``save_interval`` epochs, and ``final.pth`` on exit.
        """
        print(f"Training for {num_epochs} epochs...")
        save_interval = self.config.get("save_interval", 5)
        patience = self.config.get("early_stopping_patience", 0) or 0

        for epoch in range(num_epochs):
            self.current_epoch = epoch

            # Train & validate
            train_g, train_d = self.train_epoch()
            val_g, val_d = self.validate()

            # Save best checkpoint, judged on the summed G + D validation loss.
            val_total = val_g + val_d
            if val_total < self.best_val_loss:
                self.best_val_loss = val_total
                self.save_checkpoint("best")

            # Periodic checkpoint; 0/None disables it (and avoids modulo by zero).
            if save_interval and (epoch + 1) % save_interval == 0:
                self.save_checkpoint(f"epoch_{epoch + 1}")

            print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")

            # NOTE(review): early stopping watches the generator val loss only,
            # while "best" uses G + D; confirm this split is intended.
            if self._should_stop_early(patience):
                print(f"Early stopping at epoch {epoch + 1}")
                break

        # Save final
        self.save_checkpoint("final")
        print(f"Training finished. Best val loss: {self.best_val_loss:.4f}")

    def save_checkpoint(self, name: str):
        """Write model and optimizer state to ``<output_dir>/checkpoints/<name>.pth``.

        ``"epoch"`` holds the 0-based ``current_epoch`` at save time.
        """
        path = self.output_dir / "checkpoints" / f"{name}.pth"
        torch.save({
            "epoch": self.current_epoch,
            "generator": self.model.generator.state_dict(),
            "discriminator": self.model.discriminator.state_dict(),
            "opt_G": self.opt_G.state_dict(),
            "opt_D": self.opt_D.state_dict(),
        }, path)
|
|
|
|
|
|
def create_trainer(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: Dict[str, Any],
) -> GANTrainer:
    """Factory that builds a GANTrainer from a model, its loaders and a config.

    All arguments are forwarded unchanged to ``GANTrainer``.
    """
    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
    )
    return trainer
|
|
|
|
|
|
if __name__ == "__main__":
    # Quick smoke test: build config, loaders and GAN from the project modules,
    # then run trainer.train_epoch() to check the pipeline wires together.
    import traceback

    from config import create_config
    from dataloader import create_data_loaders
    from model import create_gan

    config = create_config()

    # The model is built with use_cuda=False, i.e. on CPU regardless of
    # whether CUDA is available.
    model = create_gan(use_cuda=False)
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=4,
        image_size=tuple(config["image_size"]),
        num_workers=0,
    )

    trainer = create_trainer(model, train_loader, val_loader, config)

    # NOTE: train_epoch() iterates the whole train_loader, not a single step.
    print("Testing one training step...")
    try:
        g_loss, d_loss = trainer.train_epoch()
    except Exception as e:
        # Report the failure with its traceback and exit non-zero instead of
        # swallowing it, so callers/CI can detect a broken pipeline.
        print(f"Training step failed: {e}")
        traceback.print_exc()
        raise SystemExit(1) from e
    print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")