Files
autopilot/models/GAN/trainer.py
2026-02-20 16:52:02 +03:00

416 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# Shorthand names for the torch types referenced in annotations.
ModuleType = torch.nn.Module
TensorType = torch.Tensor
class GANTrainer:
    """Trainer class for GAN model.

    The wrapped model is expected to expose ``generator`` and
    ``discriminator`` sub-modules as well as ``generator_step`` and
    ``discriminator_step`` methods, each returning a 3-tuple of loss tensors
    whose first element is the total loss that is back-propagated. Batches
    are dictionaries holding ``"yandex_img"`` (input) and ``"google_img"``
    (target) tensors.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict[str, Any],
    ):
        """
        Initialize the GAN trainer.

        Creates the two Adam optimizers, the output directory, the
        TensorBoard writer, and dumps the configuration to ``config.json``.

        Args:
            model: GAN model (ImageGAN) with ``generator`` / ``discriminator``
                attributes
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary. Keys read here:
                ``learning_rate`` (default 2e-4), ``beta1`` (default 0.5),
                ``beta2`` (default 0.999), ``output_dir``
                (default ``"runs/gan"``). Must be JSON-serializable.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Optimizer hyper-parameters
        lr = config.get("learning_rate", 2e-4)
        beta1 = config.get("beta1", 0.5)
        beta2 = config.get("beta2", 0.999)

        # Separate optimizers so each network is updated only by its own loss.
        self.optimizer_G = optim.Adam(
            self.model.generator.parameters(), lr=lr, betas=(beta1, beta2)
        )
        self.optimizer_D = optim.Adam(
            self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)
        )

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses: List[float] = []
        self.d_losses: List[float] = []
        self.val_g_losses: List[float] = []
        self.val_d_losses: List[float] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")

        # Save configuration next to the checkpoints for reproducibility
        config_path = self.output_dir / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
        print(f"Training configuration saved to {config_path}")

        # Report model sizes
        generator_params = sum(p.numel() for p in self.model.generator.parameters())
        discriminator_params = sum(
            p.numel() for p in self.model.discriminator.parameters()
        )
        print(f"Generator has {generator_params:,} parameters")
        print(f"Discriminator has {discriminator_params:,} parameters")

    def train_epoch(self) -> Tuple[float, float]:
        """
        Train for one epoch.

        For every batch the discriminator is updated first (on a generator
        output produced without gradients), then the generator is updated.
        Per-batch losses are logged to TensorBoard.

        Returns:
            Tuple of (average generator loss, average discriminator loss)
        """
        self.model.train()
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_batches = len(self.train_loader)

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in enumerate(progress_bar):
            # Move data to device
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            # ========== Train Discriminator ==========
            self.optimizer_D.zero_grad()
            # The fake sample is built without gradients: this step must not
            # back-propagate into the generator.
            with torch.no_grad():
                fake_google_img = self.model.generator(yandex_img)
            d_loss, d_real_loss, d_fake_loss = self.model.discriminator_step(
                yandex_img, google_img, fake_google_img
            )
            d_loss.backward()
            self.optimizer_D.step()

            # ========== Train Generator ==========
            self.optimizer_G.zero_grad()
            # generator_step only receives the input/target pair, so no
            # separate generator forward pass is run here.
            # NOTE(review): presumably generator_step runs the generator
            # itself -- confirm against the model implementation.
            g_loss, _g_gan_loss, g_l1_loss = self.model.generator_step(
                yandex_img, google_img
            )
            g_loss.backward()
            self.optimizer_G.step()

            # Read every scalar once; each .item() forces a device sync.
            g_loss_value = g_loss.item()
            d_loss_value = d_loss.item()
            g_l1_value = g_l1_loss.item()
            d_real_value = d_real_loss.item()
            d_fake_value = d_fake_loss.item()

            # Update statistics
            total_g_loss += g_loss_value
            total_d_loss += d_loss_value

            # Update progress bar
            progress_bar.set_postfix(
                {
                    "g_loss": g_loss_value,
                    "d_loss": d_loss_value,
                    "g_l1": g_l1_value,
                    "d_real": d_real_value,
                    "d_fake": d_fake_value,
                }
            )

            # Log batch losses to TensorBoard
            global_step = self.current_epoch * num_batches + batch_idx
            self.writer.add_scalar("train/batch_g_loss", g_loss_value, global_step)
            self.writer.add_scalar("train/batch_d_loss", d_loss_value, global_step)
            self.writer.add_scalar("train/batch_g_l1_loss", g_l1_value, global_step)
            self.writer.add_scalar(
                "train/batch_d_real_loss", d_real_value, global_step
            )
            self.writer.add_scalar(
                "train/batch_d_fake_loss", d_fake_value, global_step
            )

        avg_g_loss = total_g_loss / num_batches
        avg_d_loss = total_d_loss / num_batches
        self.g_losses.append(avg_g_loss)
        self.d_losses.append(avg_d_loss)
        return avg_g_loss, avg_d_loss

    def _run_eval(self, loader: DataLoader, desc: str) -> Tuple[float, float]:
        """
        Compute average generator/discriminator losses over a loader.

        Puts the model in eval mode and runs every batch without gradients.
        Shared by ``validate`` and ``evaluate``.

        Args:
            loader: Data loader to iterate over
            desc: Label shown on the progress bar

        Returns:
            Tuple of (average generator loss, average discriminator loss)
        """
        self.model.eval()
        total_g_loss = 0.0
        total_d_loss = 0.0

        progress_bar = tqdm(loader, desc=desc)
        for batch in progress_bar:
            # Move data to device
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            with torch.no_grad():
                fake_google_img = self.model.generator(yandex_img)
                g_loss, _g_gan_loss, _g_l1_loss = self.model.generator_step(
                    yandex_img, google_img
                )
                d_loss, _d_real_loss, _d_fake_loss = self.model.discriminator_step(
                    yandex_img, google_img, fake_google_img
                )

            g_loss_value = g_loss.item()
            d_loss_value = d_loss.item()
            total_g_loss += g_loss_value
            total_d_loss += d_loss_value
            progress_bar.set_postfix({"g_loss": g_loss_value, "d_loss": d_loss_value})

        num_batches = len(loader)
        return total_g_loss / num_batches, total_d_loss / num_batches

    def validate(self) -> Tuple[float, float]:
        """
        Validate the model.

        The averages are also appended to the validation loss history.

        Returns:
            Tuple of (average generator validation loss, average discriminator validation loss)
        """
        avg_g_loss, avg_d_loss = self._run_eval(self.val_loader, "Validation")
        self.val_g_losses.append(avg_g_loss)
        self.val_d_losses.append(avg_d_loss)
        return avg_g_loss, avg_d_loss

    def save_checkpoint(self, is_best: bool = False):
        """
        Save training checkpoint.

        Writes ``checkpoint_epoch_<N>.pth`` (N is the 1-based epoch number)
        into the output directory, containing model and optimizer state, the
        loss history, and the config.

        Args:
            is_best: Whether this is the best model so far; if so the same
                checkpoint is also written to ``model_best.pth``
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_G_state_dict": self.optimizer_G.state_dict(),
            "optimizer_D_state_dict": self.optimizer_D.state_dict(),
            "g_losses": self.g_losses,
            "d_losses": self.d_losses,
            "val_g_losses": self.val_g_losses,
            "val_d_losses": self.val_d_losses,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }

        # Save regular checkpoint
        checkpoint_path = (
            self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
        )
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = self.output_dir / "model_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False):
        """
        Load training checkpoint.

        Restores model weights, both optimizer states, the loss history, the
        best validation loss and the stored epoch index.

        Args:
            checkpoint_path: Path to checkpoint file
            resume_training: If True, training is meant to continue from the
                saved epoch. Within this method the flag only selects the log
                message; ``train`` still takes its own ``start_epoch``.
        """
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
        self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        self.g_losses = checkpoint["g_losses"]
        self.d_losses = checkpoint["d_losses"]
        self.val_g_losses = checkpoint["val_g_losses"]
        self.val_d_losses = checkpoint["val_d_losses"]
        self.best_val_loss = checkpoint["best_val_loss"]

        if resume_training:
            print(f"Resuming training from epoch {self.current_epoch + 1}")
        else:
            print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")

    def train(self, num_epochs: int, start_epoch: int = 0):
        """
        Train the model for specified number of epochs.

        Each epoch runs training and validation, logs epoch-level scalars,
        and saves a checkpoint. The "best" model is the one with the lowest
        sum of generator and discriminator validation losses. After the loop
        the final weights and the loss history are written to the output
        directory. The TensorBoard writer is closed even if training raises.

        Args:
            num_epochs: Number of epochs to train
            start_epoch: Starting epoch (useful when resuming training)
        """
        end_epoch = start_epoch + num_epochs
        # A missing or None patience disables early stopping.
        patience = self.config.get("early_stopping_patience", 0) or 0

        print(
            f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
        )
        start_time = time.time()

        try:
            for epoch in range(start_epoch, end_epoch):
                self.current_epoch = epoch

                # Train for one epoch
                train_g_loss, train_d_loss = self.train_epoch()

                # Validate
                val_g_loss, val_d_loss = self.validate()

                # Log to TensorBoard
                self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
                self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
                self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
                self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)

                # Print epoch summary; the denominator is the absolute last
                # epoch so the display stays consistent when start_epoch > 0.
                print(f"\nEpoch {epoch + 1}/{end_epoch}:")
                print(" Generator:")
                print(f" Train Loss: {train_g_loss:.6f}")
                print(f" Val Loss: {val_g_loss:.6f}")
                print(" Discriminator:")
                print(f" Train Loss: {train_d_loss:.6f}")
                print(f" Val Loss: {val_d_loss:.6f}")

                # Save checkpoint
                val_total_loss = val_g_loss + val_d_loss
                is_best = val_total_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_total_loss
                self.save_checkpoint(is_best=is_best)

                # Early stopping: count epochs since the best entry *within
                # the recorded history*. Using the absolute epoch index here
                # would mis-fire whenever start_epoch does not match the
                # history length.
                if patience > 0:
                    val_losses = [
                        g + d for g, d in zip(self.val_g_losses, self.val_d_losses)
                    ]
                    epochs_since_best = len(val_losses) - 1 - int(np.argmin(val_losses))
                    if epochs_since_best >= patience:
                        print(f"Early stopping at epoch {epoch + 1}")
                        break

            # Training completed
            training_time = time.time() - start_time
            print(f"\nTraining completed in {training_time:.2f} seconds")
            print(f"Best validation total loss: {self.best_val_loss:.6f}")

            # Save final model
            final_model_path = self.output_dir / "model_final.pth"
            torch.save(self.model.state_dict(), final_model_path)
            print(f"Final model saved to {final_model_path}")

            # Save training history
            history_path = self.output_dir / "training_history.json"
            history = {
                "g_losses": self.g_losses,
                "d_losses": self.d_losses,
                "val_g_losses": self.val_g_losses,
                "val_d_losses": self.val_d_losses,
                "best_val_loss": self.best_val_loss,
                "total_epochs": self.current_epoch + 1,
            }
            with open(history_path, "w", encoding="utf-8") as f:
                json.dump(history, f, indent=2)
            print(f"Training history saved to {history_path}")
        finally:
            # Close TensorBoard writer so pending events are flushed even
            # when training is interrupted.
            self.writer.close()

    def evaluate(self, test_loader: DataLoader) -> Dict[str, float]:
        """
        Evaluate the model on test data.

        Args:
            test_loader: Test data loader

        Returns:
            Dictionary with evaluation metrics: ``generator_loss``,
            ``discriminator_loss`` and their sum ``total_loss``
        """
        print("Evaluating model on test set...")
        avg_g_loss, avg_d_loss = self._run_eval(test_loader, "Evaluation")

        metrics = {
            "generator_loss": avg_g_loss,
            "discriminator_loss": avg_d_loss,
            "total_loss": avg_g_loss + avg_d_loss,
        }

        print("\nTest Results:")
        print(f" Generator Loss: {avg_g_loss:.6f}")
        print(f" Discriminator Loss: {avg_d_loss:.6f}")
        print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")
        return metrics