"""Trainer for GAN model.""" import json import time from pathlib import Path from typing import Any, Dict, Tuple import torch from torch.utils.data import DataLoader from tqdm import tqdm class GANTrainer: """Simple GAN trainer.""" def __init__( self, model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader, config: Dict[str, Any], ): self.model = model self.train_loader = train_loader self.val_loader = val_loader self.config = config self.device = model.device # Optimizers lr = config.get("learning_rate", 2e-4) beta1 = config.get("beta1", 0.5) beta2 = config.get("beta2", 0.999) self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2)) self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)) # Training state self.current_epoch = 0 self.best_val_loss = float("inf") self.g_losses = [] self.d_losses = [] self.val_g_losses = [] self.val_d_losses = [] # Output dir self.output_dir = Path(config.get("output_dir", "runs/gan")) self.output_dir.mkdir(parents=True, exist_ok=True) (self.output_dir / "checkpoints").mkdir(exist_ok=True) # Save config with open(self.output_dir / "config.json", "w") as f: json.dump(config, f, indent=2) def train_epoch(self) -> Tuple[float, float]: """Train for one epoch.""" self.model.train() total_g = total_d = 0.0 num_batches = len(self.train_loader) pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") for batch in pbar: yandex_img = batch["yandex_img"].to(self.device) google_img = batch["google_img"].to(self.device) # Train D self.opt_D.zero_grad() with torch.no_grad(): fake_img = self.model.generator(yandex_img) d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0] d_loss.backward() self.opt_D.step() # Train G self.opt_G.zero_grad() g_loss = self.model.generator_step(yandex_img, google_img)[0] g_loss.backward() self.opt_G.step() total_g += g_loss.item() total_d += d_loss.item() pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()}) avg_g = total_g / num_batches avg_d = total_d / num_batches self.g_losses.append(avg_g) self.d_losses.append(avg_d) return avg_g, avg_d @torch.no_grad() def validate(self) -> Tuple[float, float]: """Validate the model.""" self.model.eval() total_g = total_d = 0.0 for batch in tqdm(self.val_loader, desc="Val"): yandex_img = batch["yandex_img"].to(self.device) google_img = batch["google_img"].to(self.device) fake_img = self.model.generator(yandex_img) g_loss = self.model.generator_step(yandex_img, google_img)[0] d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0] total_g += g_loss.item() total_d += d_loss.item() avg_g = total_g / len(self.val_loader) avg_d = total_d / len(self.val_loader) self.val_g_losses.append(avg_g) self.val_d_losses.append(avg_d) return avg_g, avg_d def train(self, num_epochs: int): """Train the model.""" print(f"Training for {num_epochs} epochs...") for epoch in range(num_epochs): self.current_epoch = epoch # Train & validate train_g, train_d = self.train_epoch() val_g, val_d = self.validate() # Save best checkpoint val_total = val_g + val_d if val_total < self.best_val_loss: self.best_val_loss = val_total self.save_checkpoint("best") # Periodic checkpoint if (epoch + 1) % self.config.get("save_interval", 5) == 0: self.save_checkpoint(f"epoch_{epoch + 1}") print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}") # Early stopping patience = self.config.get("early_stopping_patience", 0) if patience > 0 and len(self.val_g_losses) > patience: recent = self.val_g_losses[-patience:] if all(l >= min(self.val_g_losses[:-patience]) for l in recent): print(f"Early stopping at epoch {epoch + 1}") break # Save final self.save_checkpoint("final") print(f"Training finished. Best val loss: {self.best_val_loss:.4f}") def save_checkpoint(self, name: str): """Save model checkpoint.""" path = self.output_dir / "checkpoints" / f"{name}.pth" torch.save({ "epoch": self.current_epoch, "generator": self.model.generator.state_dict(), "discriminator": self.model.discriminator.state_dict(), "opt_G": self.opt_G.state_dict(), "opt_D": self.opt_D.state_dict(), }, path) def create_trainer( model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader, config: Dict[str, Any], ) -> GANTrainer: """Create a trainer instance.""" return GANTrainer(model, train_loader, val_loader, config) if __name__ == "__main__": # Quick test from config import create_config from dataloader import create_data_loaders from model import create_gan config = create_config() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = create_gan(use_cuda=False) train_loader, val_loader = create_data_loaders( root_dir=config["data_dir"], batch_size=4, image_size=tuple(config["image_size"]), num_workers=0, ) trainer = create_trainer(model, train_loader, val_loader, config) # Test one training step (just to verify no errors) print("Testing one training step...") try: g_loss, d_loss = trainer.train_epoch() print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}") except Exception as e: print(f"Training step failed: {e}")