# File-viewer header residue (not part of the module): 416 lines, 15 KiB, Python
import json
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
from torch.utils.data import DataLoader
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
from tqdm import tqdm
|
||
|
||
# Type aliases: short names for the torch types used in annotations.
ModuleType = nn.Module
TensorType = torch.Tensor
|
||
|
||
|
||
class GANTrainer:
    """Trainer class for GAN model.

    The wrapped model is expected to expose ``generator`` and
    ``discriminator`` sub-modules, plus two loss methods:

    * ``generator_step(yandex_img, google_img)``
    * ``discriminator_step(yandex_img, google_img, fake_google_img)``

    Each is unpacked into three scalar loss tensors, total loss first.
    Batches are dictionaries holding ``"yandex_img"`` (fed to the generator)
    and ``"google_img"`` (the reference image) tensors.

    Config keys read by this class (defaults in parentheses):
    ``learning_rate`` (2e-4), ``beta1`` (0.5), ``beta2`` (0.999),
    ``output_dir`` ("runs/gan"), ``early_stopping_patience`` (0 = disabled).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict[str, Any],
    ) -> None:
        """
        Initialize the GAN trainer.

        Creates the optimizers, the output directory, the TensorBoard writer
        and writes ``config.json`` into the output directory.

        Args:
            model: GAN model (ImageGAN)
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary (must be JSON-serializable,
                since it is dumped to ``config.json``)
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Optimizers
        lr = config.get("learning_rate", 2e-4)
        beta1 = config.get("beta1", 0.5)
        beta2 = config.get("beta2", 0.999)

        # Separate optimizers for generator and discriminator
        # Note: self.model is expected to have .generator and .discriminator attributes
        self.optimizer_G = optim.Adam(
            self.model.generator.parameters(), lr=lr, betas=(beta1, beta2)
        )
        self.optimizer_D = optim.Adam(
            self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)
        )

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses: List[float] = []
        self.d_losses: List[float] = []
        self.val_g_losses: List[float] = []
        self.val_d_losses: List[float] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")

        # Save configuration
        config_path = self.output_dir / "config.json"
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)

        print(f"Training configuration saved to {config_path}")

        # Access parameters through the model's generator and discriminator
        generator_params = sum(p.numel() for p in self.model.generator.parameters())
        discriminator_params = sum(
            p.numel() for p in self.model.discriminator.parameters()
        )

        print(f"Generator has {generator_params:,} parameters")
        print(f"Discriminator has {discriminator_params:,} parameters")

    def train_epoch(self) -> Tuple[float, float]:
        """
        Train for one epoch.

        For every batch the discriminator is updated first (on a fake image
        produced without gradient tracking), then the generator. Per-batch
        losses are logged to TensorBoard and the epoch averages are appended
        to ``self.g_losses`` / ``self.d_losses``.

        Returns:
            Tuple of (average generator loss, average discriminator loss)
        """
        self.model.train()
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_batches = len(self.train_loader)

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in enumerate(progress_bar):
            # Move data to device
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            # ========== Train Discriminator ==========
            self.optimizer_D.zero_grad()

            # The fake image is generated under no_grad so the discriminator
            # update does not build a graph through the generator.
            with torch.no_grad():
                fake_google_img = self.model.generator(yandex_img)

            d_loss, d_real_loss, d_fake_loss = self.model.discriminator_step(
                yandex_img, google_img, fake_google_img
            )
            d_loss.backward()
            self.optimizer_D.step()

            # ========== Train Generator ==========
            self.optimizer_G.zero_grad()

            # NOTE(review): an extra generator forward pass used to run here,
            # but its output was never used. generator_step only receives the
            # input/reference pair, so it presumably produces its own fake
            # image internally -- confirm against the model implementation.
            g_loss, _g_gan_loss, g_l1_loss = self.model.generator_step(
                yandex_img, google_img
            )
            g_loss.backward()
            self.optimizer_G.step()

            # Read each scalar once; every .item() call is a device->host copy.
            g_value = g_loss.item()
            d_value = d_loss.item()
            g_l1_value = g_l1_loss.item()
            d_real_value = d_real_loss.item()
            d_fake_value = d_fake_loss.item()

            # Update statistics
            total_g_loss += g_value
            total_d_loss += d_value

            # Update progress bar
            progress_bar.set_postfix(
                {
                    "g_loss": g_value,
                    "d_loss": d_value,
                    "g_l1": g_l1_value,
                    "d_real": d_real_value,
                    "d_fake": d_fake_value,
                }
            )

            # Log batch losses to TensorBoard
            global_step = self.current_epoch * num_batches + batch_idx
            self.writer.add_scalar("train/batch_g_loss", g_value, global_step)
            self.writer.add_scalar("train/batch_d_loss", d_value, global_step)
            self.writer.add_scalar("train/batch_g_l1_loss", g_l1_value, global_step)
            self.writer.add_scalar(
                "train/batch_d_real_loss", d_real_value, global_step
            )
            self.writer.add_scalar(
                "train/batch_d_fake_loss", d_fake_value, global_step
            )

        avg_g_loss = total_g_loss / num_batches
        avg_d_loss = total_d_loss / num_batches
        self.g_losses.append(avg_g_loss)
        self.d_losses.append(avg_d_loss)

        return avg_g_loss, avg_d_loss

    def _eval_losses(
        self, loader: DataLoader, desc: str, show_postfix: bool = False
    ) -> Tuple[float, float]:
        """Return the mean (generator, discriminator) loss over ``loader``.

        Shared by ``validate`` and ``evaluate``. Puts the model in eval mode
        and computes every batch without gradient tracking.

        Args:
            loader: Data loader to iterate over.
            desc: Label shown on the progress bar.
            show_postfix: Whether to display running losses on the progress bar.
        """
        self.model.eval()
        total_g_loss = 0.0
        total_d_loss = 0.0

        progress_bar = tqdm(loader, desc=desc)
        for batch in progress_bar:
            # Move data to device
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            with torch.no_grad():
                # The fake image is needed as input for the discriminator loss.
                fake_google_img = self.model.generator(yandex_img)

                g_loss, _g_gan_loss, _g_l1_loss = self.model.generator_step(
                    yandex_img, google_img
                )
                d_loss, _d_real_loss, _d_fake_loss = self.model.discriminator_step(
                    yandex_img, google_img, fake_google_img
                )

            g_value = g_loss.item()
            d_value = d_loss.item()
            total_g_loss += g_value
            total_d_loss += d_value

            if show_postfix:
                progress_bar.set_postfix({"g_loss": g_value, "d_loss": d_value})

        num_batches = len(loader)
        return total_g_loss / num_batches, total_d_loss / num_batches

    def validate(self) -> Tuple[float, float]:
        """
        Validate the model.

        Appends the averages to ``self.val_g_losses`` / ``self.val_d_losses``.

        Returns:
            Tuple of (average generator validation loss, average discriminator validation loss)
        """
        avg_g_loss, avg_d_loss = self._eval_losses(
            self.val_loader, desc="Validation", show_postfix=True
        )
        self.val_g_losses.append(avg_g_loss)
        self.val_d_losses.append(avg_d_loss)

        return avg_g_loss, avg_d_loss

    def save_checkpoint(self, is_best: bool = False) -> None:
        """
        Save training checkpoint.

        Writes ``checkpoint_epoch_<n>.pth`` (1-based epoch number) into the
        output directory and, when ``is_best`` is set, also ``model_best.pth``.
        The stored ``"epoch"`` value is the 0-based ``self.current_epoch``.

        Args:
            is_best: Whether this is the best model so far
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_G_state_dict": self.optimizer_G.state_dict(),
            "optimizer_D_state_dict": self.optimizer_D.state_dict(),
            "g_losses": self.g_losses,
            "d_losses": self.d_losses,
            "val_g_losses": self.val_g_losses,
            "val_d_losses": self.val_d_losses,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }

        # Save regular checkpoint
        checkpoint_path = (
            self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
        )
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = self.output_dir / "model_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(
        self, checkpoint_path: str, resume_training: bool = False
    ) -> None:
        """
        Load training checkpoint.

        Restores model and optimizer state, the loss histories, the best
        validation loss and ``self.current_epoch`` (the 0-based index of the
        epoch that produced the checkpoint). ``train`` does not read
        ``self.current_epoch``; to continue after the saved epoch, pass
        ``start_epoch=self.current_epoch + 1`` to ``train``.

        Args:
            checkpoint_path: Path to checkpoint file
            resume_training: If True, continue training from the saved epoch.
                In this method the flag only selects the log message.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
        self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        self.g_losses = checkpoint["g_losses"]
        self.d_losses = checkpoint["d_losses"]
        self.val_g_losses = checkpoint["val_g_losses"]
        self.val_d_losses = checkpoint["val_d_losses"]
        self.best_val_loss = checkpoint["best_val_loss"]

        if resume_training:
            print(f"Resuming training from epoch {self.current_epoch + 1}")
        else:
            print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")

    def _should_stop_early(self) -> bool:
        """Return True when validation loss has not improved for ``patience`` epochs.

        The combined loss is generator + discriminator validation loss per
        recorded epoch. The distance to the best epoch is measured inside the
        recorded history rather than against the absolute epoch number, so a
        run started at ``start_epoch > 0`` with a shorter history does not
        stop prematurely.
        """
        patience = self.config.get("early_stopping_patience", 0)
        if patience <= 0:
            return False

        combined = [g + d for g, d in zip(self.val_g_losses, self.val_d_losses)]
        if not combined:
            return False

        epochs_since_best = len(combined) - 1 - int(np.argmin(combined))
        return epochs_since_best >= patience

    def train(self, num_epochs: int, start_epoch: int = 0) -> None:
        """
        Train the model for specified number of epochs.

        After each epoch the model is validated, a checkpoint is saved, and
        early stopping is checked. When the loop ends, the final weights
        (``model_final.pth``) and ``training_history.json`` are written. The
        TensorBoard writer is closed even if training raises.

        Args:
            num_epochs: Number of epochs to train
            start_epoch: Starting epoch (useful when resuming training)
        """
        print(
            f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
        )
        start_time = time.time()
        end_epoch = start_epoch + num_epochs

        try:
            for epoch in range(start_epoch, end_epoch):
                self.current_epoch = epoch

                # Train for one epoch
                train_g_loss, train_d_loss = self.train_epoch()

                # Validate
                val_g_loss, val_d_loss = self.validate()

                # Log to TensorBoard
                self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
                self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
                self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
                self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)

                # Print epoch summary; the denominator is the last epoch
                # number so resumed runs do not show e.g. "Epoch 11/10".
                print(f"\nEpoch {epoch + 1}/{end_epoch}:")
                print(" Generator:")
                print(f" Train Loss: {train_g_loss:.6f}")
                print(f" Val Loss: {val_g_loss:.6f}")
                print(" Discriminator:")
                print(f" Train Loss: {train_d_loss:.6f}")
                print(f" Val Loss: {val_d_loss:.6f}")

                # Save checkpoint; the model-selection metric is the sum of
                # both validation losses.
                val_total_loss = val_g_loss + val_d_loss
                is_best = val_total_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_total_loss

                self.save_checkpoint(is_best=is_best)

                # Early stopping
                if self._should_stop_early():
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

            # Training completed
            training_time = time.time() - start_time
            print(f"\nTraining completed in {training_time:.2f} seconds")
            print(f"Best validation total loss: {self.best_val_loss:.6f}")

            # Save final model
            final_model_path = self.output_dir / "model_final.pth"
            torch.save(self.model.state_dict(), final_model_path)
            print(f"Final model saved to {final_model_path}")

            # Save training history
            history_path = self.output_dir / "training_history.json"
            history = {
                "g_losses": self.g_losses,
                "d_losses": self.d_losses,
                "val_g_losses": self.val_g_losses,
                "val_d_losses": self.val_d_losses,
                "best_val_loss": self.best_val_loss,
                "total_epochs": self.current_epoch + 1,
            }
            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)
            print(f"Training history saved to {history_path}")
        finally:
            # Close TensorBoard writer so buffered events are flushed even
            # when training is interrupted by an exception.
            self.writer.close()

    def evaluate(self, test_loader: DataLoader) -> Dict[str, float]:
        """
        Evaluate the model on test data.

        Args:
            test_loader: Test data loader

        Returns:
            Dictionary with evaluation metrics: ``generator_loss``,
            ``discriminator_loss`` and ``total_loss`` (their sum)
        """
        print("Evaluating model on test set...")

        avg_g_loss, avg_d_loss = self._eval_losses(test_loader, desc="Evaluation")

        metrics = {
            "generator_loss": avg_g_loss,
            "discriminator_loss": avg_d_loss,
            "total_loss": avg_g_loss + avg_d_loss,
        }

        print("\nTest Results:")
        print(f" Generator Loss: {avg_g_loss:.6f}")
        print(f" Discriminator Loss: {avg_d_loss:.6f}")
        print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")

        return metrics
|