import json
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Type aliases
ModuleType = nn.Module
TensorType = torch.Tensor


class GANTrainer:
    """Trainer class for GAN model.

    Alternates discriminator and generator updates on paired batches, runs a
    validation pass every epoch, logs losses to TensorBoard, writes checkpoints
    and a JSON training history under ``config["output_dir"]`` and optionally
    stops early.

    The wrapped model must provide (as used by this class):

    * ``generator`` and ``discriminator`` sub-modules (one Adam optimizer each);
    * ``discriminator_step(yandex_img, google_img, fake_google_img)`` returning
      a 3-tuple of loss tensors, unpacked here as (total, real, fake);
    * ``generator_step(yandex_img, google_img)`` returning a 3-tuple of loss
      tensors, unpacked here as (total, gan, l1).

    Batches must be mappings with ``"yandex_img"`` and ``"google_img"`` tensors.
    ``yandex_img`` is what the generator consumes; ``google_img`` is presumably
    the paired target the fakes imitate (TODO confirm against the dataset).
    """

    # Progress-bar key -> TensorBoard tag for the per-batch training losses.
    # Dict order is the order in which scalars are written each step.
    _BATCH_LOSS_TAGS = {
        "g_loss": "train/batch_g_loss",
        "d_loss": "train/batch_d_loss",
        "g_l1": "train/batch_g_l1_loss",
        "d_real": "train/batch_d_real_loss",
        "d_fake": "train/batch_d_fake_loss",
    }

    def __init__(
        self,
        model: ModuleType,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict[str, Any],
    ):
        """
        Initialize the GAN trainer.

        Moves the model to ``device``, creates one Adam optimizer per network,
        creates the output directory, opens a TensorBoard writer in
        ``<output_dir>/tensorboard`` and saves ``config`` as
        ``<output_dir>/config.json``.

        Args:
            model: GAN model (ImageGAN) with ``generator`` and
                ``discriminator`` attributes
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary. Keys read here:
                ``learning_rate`` (default 2e-4), ``beta1`` (0.5),
                ``beta2`` (0.999), ``output_dir`` ("runs/gan"). ``train``
                additionally reads ``early_stopping_patience``. Must be
                JSON-serializable because it is dumped to ``config.json``.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Both networks share the same Adam hyper-parameters but get separate
        # optimizers so each update step touches only its own parameters.
        lr = config.get("learning_rate", 2e-4)
        betas = (config.get("beta1", 0.5), config.get("beta2", 0.999))
        self.optimizer_G = optim.Adam(
            self.model.generator.parameters(), lr=lr, betas=betas
        )
        self.optimizer_D = optim.Adam(
            self.model.discriminator.parameters(), lr=lr, betas=betas
        )

        # Training state (per-epoch averages; restored by load_checkpoint)
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses: List[float] = []
        self.d_losses: List[float] = []
        self.val_g_losses: List[float] = []
        self.val_d_losses: List[float] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer (log_dir is documented as a string path)
        self.writer = SummaryWriter(log_dir=str(self.output_dir / "tensorboard"))

        # Save configuration
        config_path = self.output_dir / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
        print(f"Training configuration saved to {config_path}")

        generator_params = self._count_parameters(self.model.generator)
        discriminator_params = self._count_parameters(self.model.discriminator)
        print(f"Generator has {generator_params:,} parameters")
        print(f"Discriminator has {discriminator_params:,} parameters")

    @staticmethod
    def _count_parameters(module: ModuleType) -> int:
        """Return the total number of elements across ``module``'s parameters."""
        return sum(p.numel() for p in module.parameters())

    def train_epoch(self) -> Tuple[float, float]:
        """
        Train for one epoch.

        Per batch: one discriminator update on detached fakes, then one
        generator update. Per-batch losses are shown on the progress bar and
        logged to TensorBoard at step ``current_epoch * len(train_loader) +
        batch_idx``.

        Returns:
            Tuple of (average generator loss, average discriminator loss),
            averaged over batches. The values are also appended to
            ``g_losses`` / ``d_losses``.
        """
        self.model.train()
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_batches = len(self.train_loader)

        progress_bar = tqdm(
            self.train_loader, desc=f"Epoch {self.current_epoch + 1}"
        )
        for batch_idx, batch in enumerate(progress_bar):
            # Move data to device
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            # ========== Train Discriminator ==========
            self.optimizer_D.zero_grad()
            # Fakes are generated without autograd so the discriminator update
            # does not build (or backprop through) a generator graph.
            with torch.no_grad():
                fake_google_img = self.model.generator(yandex_img)
            d_loss, d_real_loss, d_fake_loss = self.model.discriminator_step(
                yandex_img, google_img, fake_google_img
            )
            d_loss.backward()
            self.optimizer_D.step()

            # ========== Train Generator ==========
            # generator_step receives only the inputs, so it presumably runs
            # the generator itself. A separate forward pass here would be
            # unused work (and would update BatchNorm stats twice).
            self.optimizer_G.zero_grad()
            g_loss, _g_gan_loss, g_l1_loss = self.model.generator_step(
                yandex_img, google_img
            )
            g_loss.backward()
            self.optimizer_G.step()

            # Convert each loss once: every .item() forces a device sync.
            batch_stats = {
                "g_loss": g_loss.item(),
                "d_loss": d_loss.item(),
                "g_l1": g_l1_loss.item(),
                "d_real": d_real_loss.item(),
                "d_fake": d_fake_loss.item(),
            }
            total_g_loss += batch_stats["g_loss"]
            total_d_loss += batch_stats["d_loss"]

            progress_bar.set_postfix(batch_stats)

            global_step = self.current_epoch * num_batches + batch_idx
            for key, tag in self._BATCH_LOSS_TAGS.items():
                self.writer.add_scalar(tag, batch_stats[key], global_step)

        avg_g_loss = total_g_loss / num_batches
        avg_d_loss = total_d_loss / num_batches
        self.g_losses.append(avg_g_loss)
        self.d_losses.append(avg_d_loss)
        return avg_g_loss, avg_d_loss

    def _average_losses(
        self, loader: DataLoader, desc: str, show_postfix: bool
    ) -> Tuple[float, float]:
        """Return batch-averaged (generator, discriminator) losses over ``loader``.

        Puts the model in eval mode and runs without gradients; no optimizer
        state is touched. Shared by ``validate`` and ``evaluate``.

        Args:
            loader: Data loader yielding ``yandex_img`` / ``google_img`` batches
            desc: Progress-bar description
            show_postfix: Whether to show the current batch losses on the bar
        """
        self.model.eval()
        total_g_loss = 0.0
        total_d_loss = 0.0

        progress_bar = tqdm(loader, desc=desc)
        with torch.no_grad():
            for batch in progress_bar:
                yandex_img = batch["yandex_img"].to(self.device)
                google_img = batch["google_img"].to(self.device)

                # The discriminator loss needs explicit fakes; generator_step
                # produces its own internally.
                fake_google_img = self.model.generator(yandex_img)
                g_loss, _, _ = self.model.generator_step(yandex_img, google_img)
                d_loss, _, _ = self.model.discriminator_step(
                    yandex_img, google_img, fake_google_img
                )

                g_value = g_loss.item()
                d_value = d_loss.item()
                total_g_loss += g_value
                total_d_loss += d_value

                if show_postfix:
                    progress_bar.set_postfix({"g_loss": g_value, "d_loss": d_value})

        num_batches = len(loader)
        return total_g_loss / num_batches, total_d_loss / num_batches

    def validate(self) -> Tuple[float, float]:
        """
        Validate the model.

        Returns:
            Tuple of (average generator validation loss, average discriminator
            validation loss). The values are also appended to
            ``val_g_losses`` / ``val_d_losses``.
        """
        avg_g_loss, avg_d_loss = self._average_losses(
            self.val_loader, desc="Validation", show_postfix=True
        )
        self.val_g_losses.append(avg_g_loss)
        self.val_d_losses.append(avg_d_loss)
        return avg_g_loss, avg_d_loss

    def save_checkpoint(self, is_best: bool = False) -> None:
        """
        Save training checkpoint.

        Writes ``checkpoint_epoch_<current_epoch + 1>.pth`` (1-based epoch
        number) with model, both optimizers, loss histories, best validation
        loss and config. When ``is_best`` is set, the same payload is also
        written to ``model_best.pth``.

        Args:
            is_best: Whether this is the best model so far
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_G_state_dict": self.optimizer_G.state_dict(),
            "optimizer_D_state_dict": self.optimizer_D.state_dict(),
            "g_losses": self.g_losses,
            "d_losses": self.d_losses,
            "val_g_losses": self.val_g_losses,
            "val_d_losses": self.val_d_losses,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }

        # Save regular checkpoint
        checkpoint_path = (
            self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
        )
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = self.output_dir / "model_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(
        self, checkpoint_path: str, resume_training: bool = False
    ) -> None:
        """
        Load training checkpoint.

        Restores model and optimizer state, loss histories, best validation
        loss and ``current_epoch``. ``current_epoch`` is the 0-based index of
        the epoch that had *completed* when the checkpoint was written. To
        continue training, call ``train(..., start_epoch=self.current_epoch + 1)``.

        Args:
            checkpoint_path: Path to checkpoint file
            resume_training: If True, training is about to continue from the
                epoch after the saved one. This only affects the message
                printed here.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
        self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        self.g_losses = checkpoint["g_losses"]
        self.d_losses = checkpoint["d_losses"]
        self.val_g_losses = checkpoint["val_g_losses"]
        self.val_d_losses = checkpoint["val_d_losses"]
        self.best_val_loss = checkpoint["best_val_loss"]

        if resume_training:
            # The saved epoch is already done; training continues with the
            # next one (0-based current_epoch + 1, i.e. 1-based + 2).
            print(f"Resuming training from epoch {self.current_epoch + 2}")
        else:
            print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")

    def _epochs_since_best_val(self) -> int:
        """Return how many recorded validation epochs follow the best total loss.

        "Total" is generator + discriminator validation loss. Positions in the
        history lists are compared with each other rather than with the
        absolute epoch index, so the count stays correct when ``train`` starts
        at ``start_epoch > 0`` without a matching loaded history.
        """
        val_losses = [g + d for g, d in zip(self.val_g_losses, self.val_d_losses)]
        return len(val_losses) - 1 - int(np.argmin(val_losses))

    def _print_epoch_summary(
        self,
        epoch: int,
        end_epoch: int,
        train_g_loss: float,
        train_d_loss: float,
        val_g_loss: float,
        val_d_loss: float,
    ) -> None:
        """Print train/val losses for the 0-based ``epoch`` out of ``end_epoch``."""
        print(f"\nEpoch {epoch + 1}/{end_epoch}:")
        print(" Generator:")
        print(f" Train Loss: {train_g_loss:.6f}")
        print(f" Val Loss: {val_g_loss:.6f}")
        print(" Discriminator:")
        print(f" Train Loss: {train_d_loss:.6f}")
        print(f" Val Loss: {val_d_loss:.6f}")

    def train(self, num_epochs: int, start_epoch: int = 0) -> None:
        """
        Train the model for specified number of epochs.

        Each epoch trains, validates, logs epoch losses to TensorBoard and
        saves a checkpoint. "Best" means the lowest generator + discriminator
        validation loss. If ``config["early_stopping_patience"]`` is a positive
        number, training stops once that many epochs have passed without a new
        best. Afterwards the final model weights and a JSON history are written
        to the output directory. The TensorBoard writer is closed even if
        training raises.

        Args:
            num_epochs: Number of epochs to train
            start_epoch: Starting epoch (0-based; useful when resuming training)
        """
        end_epoch = start_epoch + num_epochs
        print(
            f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
        )
        patience = self.config.get("early_stopping_patience") or 0
        start_time = time.time()

        try:
            for epoch in range(start_epoch, end_epoch):
                self.current_epoch = epoch

                train_g_loss, train_d_loss = self.train_epoch()
                val_g_loss, val_d_loss = self.validate()

                # Log to TensorBoard
                self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
                self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
                self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
                self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)

                self._print_epoch_summary(
                    epoch, end_epoch, train_g_loss, train_d_loss, val_g_loss, val_d_loss
                )

                # Save checkpoint
                val_total_loss = val_g_loss + val_d_loss
                is_best = val_total_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_total_loss
                self.save_checkpoint(is_best=is_best)

                # Early stopping
                if patience > 0 and self._epochs_since_best_val() >= patience:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

            # Training completed
            training_time = time.time() - start_time
            print(f"\nTraining completed in {training_time:.2f} seconds")
            print(f"Best validation total loss: {self.best_val_loss:.6f}")

            # Save final model (weights only, unlike the checkpoints)
            final_model_path = self.output_dir / "model_final.pth"
            torch.save(self.model.state_dict(), final_model_path)
            print(f"Final model saved to {final_model_path}")

            # Save training history
            history_path = self.output_dir / "training_history.json"
            history = {
                "g_losses": self.g_losses,
                "d_losses": self.d_losses,
                "val_g_losses": self.val_g_losses,
                "val_d_losses": self.val_d_losses,
                "best_val_loss": self.best_val_loss,
                "total_epochs": self.current_epoch + 1,
            }
            with open(history_path, "w", encoding="utf-8") as f:
                json.dump(history, f, indent=2)
            print(f"Training history saved to {history_path}")
        finally:
            # Close (and flush) the writer even when training is interrupted,
            # so already-logged scalars are not lost.
            self.writer.close()

    def evaluate(self, test_loader: DataLoader) -> Dict:
        """
        Evaluate the model on test data.

        Runs in eval mode without gradients. Loss histories are not modified.

        Args:
            test_loader: Test data loader

        Returns:
            Dictionary with batch-averaged ``generator_loss``,
            ``discriminator_loss`` and their sum ``total_loss``
        """
        print("Evaluating model on test set...")
        avg_g_loss, avg_d_loss = self._average_losses(
            test_loader, desc="Evaluation", show_postfix=False
        )
        total_loss = avg_g_loss + avg_d_loss

        metrics = {
            "generator_loss": avg_g_loss,
            "discriminator_loss": avg_d_loss,
            "total_loss": total_loss,
        }

        print("\nTest Results:")
        print(f" Generator Loss: {avg_g_loss:.6f}")
        print(f" Discriminator Loss: {avg_d_loss:.6f}")
        print(f" Total Loss: {total_loss:.6f}")

        return metrics