feat: add models
This commit is contained in:
415
models/GAN/trainer.py
Normal file
415
models/GAN/trainer.py
Normal file
@@ -0,0 +1,415 @@
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
# Type aliases
|
||||
ModuleType = nn.Module
|
||||
TensorType = torch.Tensor
|
||||
|
||||
|
||||
class GANTrainer:
|
||||
"""Trainer class for GAN model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: ModuleType,
|
||||
train_loader: DataLoader,
|
||||
val_loader: DataLoader,
|
||||
device: torch.device,
|
||||
config: Dict[str, Any],
|
||||
):
|
||||
"""
|
||||
Initialize the GAN trainer.
|
||||
|
||||
Args:
|
||||
model: GAN model (ImageGAN)
|
||||
train_loader: Training data loader
|
||||
val_loader: Validation data loader
|
||||
device: Device to run training on
|
||||
config: Training configuration dictionary
|
||||
"""
|
||||
self.model = model.to(device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
# Optimizers
|
||||
lr = config.get("learning_rate", 2e-4)
|
||||
beta1 = config.get("beta1", 0.5)
|
||||
beta2 = config.get("beta2", 0.999)
|
||||
|
||||
# Separate optimizers for generator and discriminator
|
||||
# Note: self.model is expected to have .generator and .discriminator attributes
|
||||
self.optimizer_G = optim.Adam(
|
||||
self.model.generator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||
)
|
||||
self.optimizer_D = optim.Adam(
|
||||
self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||
)
|
||||
|
||||
# Training state
|
||||
self.current_epoch = 0
|
||||
self.best_val_loss = float("inf")
|
||||
self.g_losses: List[float] = []
|
||||
self.d_losses: List[float] = []
|
||||
self.val_g_losses: List[float] = []
|
||||
self.val_d_losses: List[float] = []
|
||||
|
||||
# Create output directory
|
||||
self.output_dir = Path(config.get("output_dir", "runs/gan"))
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TensorBoard writer
|
||||
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||
|
||||
# Save configuration
|
||||
config_path = self.output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Training configuration saved to {config_path}")
|
||||
# Access parameters through the model's generator and discriminator
|
||||
generator_params = sum(p.numel() for p in self.model.generator.parameters())
|
||||
discriminator_params = sum(
|
||||
p.numel() for p in self.model.discriminator.parameters()
|
||||
)
|
||||
|
||||
print(f"Generator has {generator_params:,} parameters")
|
||||
print(f"Discriminator has {discriminator_params:,} parameters")
|
||||
|
||||
def train_epoch(self) -> Tuple[float, float]:
|
||||
"""
|
||||
Train for one epoch.
|
||||
|
||||
Returns:
|
||||
Tuple of (average generator loss, average discriminator loss)
|
||||
"""
|
||||
self.model.train()
|
||||
total_g_loss = 0.0
|
||||
total_d_loss = 0.0
|
||||
num_batches = len(self.train_loader)
|
||||
|
||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||
for batch_idx, batch in enumerate(progress_bar):
|
||||
# Move data to device
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
|
||||
# ========== Train Discriminator ==========
|
||||
self.optimizer_D.zero_grad()
|
||||
|
||||
# Generate fake image
|
||||
with torch.no_grad():
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Discriminator loss - returns tuple of tensors
|
||||
d_loss_tuple = self.model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||
|
||||
# Backward and optimize discriminator
|
||||
d_loss.backward()
|
||||
self.optimizer_D.step()
|
||||
|
||||
# ========== Train Generator ==========
|
||||
self.optimizer_G.zero_grad()
|
||||
|
||||
# Generate fake image
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Generator loss - returns tuple of tensors
|
||||
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||
|
||||
# Backward and optimize generator
|
||||
g_loss.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
# Update statistics
|
||||
total_g_loss += g_loss.item()
|
||||
total_d_loss += d_loss.item()
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix(
|
||||
{
|
||||
"g_loss": g_loss.item(),
|
||||
"d_loss": d_loss.item(),
|
||||
"g_l1": g_l1_loss.item(),
|
||||
"d_real": d_real_loss.item(),
|
||||
"d_fake": d_fake_loss.item(),
|
||||
}
|
||||
)
|
||||
|
||||
# Log batch losses to TensorBoard
|
||||
global_step = self.current_epoch * num_batches + batch_idx
|
||||
self.writer.add_scalar("train/batch_g_loss", g_loss.item(), global_step)
|
||||
self.writer.add_scalar("train/batch_d_loss", d_loss.item(), global_step)
|
||||
self.writer.add_scalar(
|
||||
"train/batch_g_l1_loss", g_l1_loss.item(), global_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/batch_d_real_loss", d_real_loss.item(), global_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/batch_d_fake_loss", d_fake_loss.item(), global_step
|
||||
)
|
||||
|
||||
avg_g_loss = total_g_loss / num_batches
|
||||
avg_d_loss = total_d_loss / num_batches
|
||||
self.g_losses.append(avg_g_loss)
|
||||
self.d_losses.append(avg_d_loss)
|
||||
|
||||
return avg_g_loss, avg_d_loss
|
||||
|
||||
def validate(self) -> Tuple[float, float]:
|
||||
"""
|
||||
Validate the model.
|
||||
|
||||
Returns:
|
||||
Tuple of (average generator validation loss, average discriminator validation loss)
|
||||
"""
|
||||
self.model.eval()
|
||||
total_g_loss = 0.0
|
||||
total_d_loss = 0.0
|
||||
|
||||
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||
for batch in progress_bar:
|
||||
# Move data to device
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Generate fake image
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Generator loss - returns tuple of tensors
|
||||
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||
|
||||
# Discriminator loss - returns tuple of tensors
|
||||
d_loss_tuple = self.model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||
|
||||
# Update statistics
|
||||
total_g_loss += g_loss.item()
|
||||
total_d_loss += d_loss.item()
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
|
||||
|
||||
avg_g_loss = total_g_loss / len(self.val_loader)
|
||||
avg_d_loss = total_d_loss / len(self.val_loader)
|
||||
self.val_g_losses.append(avg_g_loss)
|
||||
self.val_d_losses.append(avg_d_loss)
|
||||
|
||||
return avg_g_loss, avg_d_loss
|
||||
|
||||
def save_checkpoint(self, is_best: bool = False):
|
||||
"""
|
||||
Save training checkpoint.
|
||||
|
||||
Args:
|
||||
is_best: Whether this is the best model so far
|
||||
"""
|
||||
checkpoint = {
|
||||
"epoch": self.current_epoch,
|
||||
"model_state_dict": self.model.state_dict(),
|
||||
"optimizer_G_state_dict": self.optimizer_G.state_dict(),
|
||||
"optimizer_D_state_dict": self.optimizer_D.state_dict(),
|
||||
"g_losses": self.g_losses,
|
||||
"d_losses": self.d_losses,
|
||||
"val_g_losses": self.val_g_losses,
|
||||
"val_d_losses": self.val_d_losses,
|
||||
"best_val_loss": self.best_val_loss,
|
||||
"config": self.config,
|
||||
}
|
||||
|
||||
# Save regular checkpoint
|
||||
checkpoint_path = (
|
||||
self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
|
||||
)
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
|
||||
# Save best model
|
||||
if is_best:
|
||||
best_path = self.output_dir / "model_best.pth"
|
||||
torch.save(checkpoint, best_path)
|
||||
print(f"Best model saved to {best_path}")
|
||||
|
||||
def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False):
|
||||
"""
|
||||
Load training checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_path: Path to checkpoint file
|
||||
resume_training: Если True, продолжить обучение с сохраненной эпохи
|
||||
"""
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
|
||||
self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])
|
||||
|
||||
self.current_epoch = checkpoint["epoch"]
|
||||
self.g_losses = checkpoint["g_losses"]
|
||||
self.d_losses = checkpoint["d_losses"]
|
||||
self.val_g_losses = checkpoint["val_g_losses"]
|
||||
self.val_d_losses = checkpoint["val_d_losses"]
|
||||
self.best_val_loss = checkpoint["best_val_loss"]
|
||||
|
||||
if resume_training:
|
||||
print(f"Resuming training from epoch {self.current_epoch + 1}")
|
||||
else:
|
||||
print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")
|
||||
|
||||
def train(self, num_epochs: int, start_epoch: int = 0):
|
||||
"""
|
||||
Train the model for specified number of epochs.
|
||||
|
||||
Args:
|
||||
num_epochs: Number of epochs to train
|
||||
start_epoch: Starting epoch (useful when resuming training)
|
||||
"""
|
||||
print(
|
||||
f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + num_epochs):
|
||||
self.current_epoch = epoch
|
||||
|
||||
# Train for one epoch
|
||||
train_g_loss, train_d_loss = self.train_epoch()
|
||||
|
||||
# Validate
|
||||
val_g_loss, val_d_loss = self.validate()
|
||||
|
||||
# Log to TensorBoard
|
||||
self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
|
||||
self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
|
||||
self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
|
||||
self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)
|
||||
|
||||
# Print epoch summary
|
||||
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||
print(" Generator:")
|
||||
print(f" Train Loss: {train_g_loss:.6f}")
|
||||
print(f" Val Loss: {val_g_loss:.6f}")
|
||||
print(" Discriminator:")
|
||||
print(f" Train Loss: {train_d_loss:.6f}")
|
||||
print(f" Val Loss: {val_d_loss:.6f}")
|
||||
|
||||
# Save checkpoint
|
||||
val_total_loss = val_g_loss + val_d_loss
|
||||
is_best = val_total_loss < self.best_val_loss
|
||||
if is_best:
|
||||
self.best_val_loss = val_total_loss
|
||||
|
||||
self.save_checkpoint(is_best=is_best)
|
||||
|
||||
# Early stopping
|
||||
if self.config.get("early_stopping_patience", 0) > 0:
|
||||
val_losses = [
|
||||
g + d for g, d in zip(self.val_g_losses, self.val_d_losses)
|
||||
]
|
||||
if (
|
||||
epoch - np.argmin(val_losses)
|
||||
>= self.config["early_stopping_patience"]
|
||||
):
|
||||
print(f"Early stopping at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
# Training completed
|
||||
training_time = time.time() - start_time
|
||||
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||
print(f"Best validation total loss: {self.best_val_loss:.6f}")
|
||||
|
||||
# Save final model
|
||||
final_model_path = self.output_dir / "model_final.pth"
|
||||
torch.save(self.model.state_dict(), final_model_path)
|
||||
print(f"Final model saved to {final_model_path}")
|
||||
|
||||
# Save training history
|
||||
history_path = self.output_dir / "training_history.json"
|
||||
history = {
|
||||
"g_losses": self.g_losses,
|
||||
"d_losses": self.d_losses,
|
||||
"val_g_losses": self.val_g_losses,
|
||||
"val_d_losses": self.val_d_losses,
|
||||
"best_val_loss": self.best_val_loss,
|
||||
"total_epochs": self.current_epoch + 1,
|
||||
}
|
||||
with open(history_path, "w") as f:
|
||||
json.dump(history, f, indent=2)
|
||||
print(f"Training history saved to {history_path}")
|
||||
|
||||
# Close TensorBoard writer
|
||||
self.writer.close()
|
||||
|
||||
def evaluate(self, test_loader: DataLoader) -> Dict:
|
||||
"""
|
||||
Evaluate the model on test data.
|
||||
|
||||
Args:
|
||||
test_loader: Test data loader
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation metrics
|
||||
"""
|
||||
self.model.eval()
|
||||
total_g_loss = 0.0
|
||||
total_d_loss = 0.0
|
||||
|
||||
print("Evaluating model on test set...")
|
||||
|
||||
for batch in tqdm(test_loader, desc="Evaluation"):
|
||||
# Move data to device
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Generate fake image
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Generator loss - returns tuple of tensors
|
||||
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||
|
||||
# Discriminator loss - returns tuple of tensors
|
||||
d_loss_tuple = self.model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||
|
||||
# Update statistics
|
||||
total_g_loss += g_loss.item()
|
||||
total_d_loss += d_loss.item()
|
||||
|
||||
avg_g_loss = total_g_loss / len(test_loader)
|
||||
avg_d_loss = total_d_loss / len(test_loader)
|
||||
|
||||
metrics = {
|
||||
"generator_loss": avg_g_loss,
|
||||
"discriminator_loss": avg_d_loss,
|
||||
"total_loss": avg_g_loss + avg_d_loss,
|
||||
}
|
||||
|
||||
print("\nTest Results:")
|
||||
print(f" Generator Loss: {avg_g_loss:.6f}")
|
||||
print(f" Discriminator Loss: {avg_d_loss:.6f}")
|
||||
print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")
|
||||
|
||||
return metrics
|
||||
Reference in New Issue
Block a user