ref: simplify and modularize GAN codebase

This commit is contained in:
2026-03-22 21:10:05 +03:00
parent 05f8746d58
commit c6df3edab8
11 changed files with 677 additions and 1984 deletions

36
models/GAN/config.py Normal file
View File

@@ -0,0 +1,36 @@
"""Configuration for GAN training."""
def create_config():
"""Create default configuration dictionary."""
return {
# Optimizer params
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Training params
"batch_size": 32,
"epochs": 100,
# GAN params
"gan_mode": "vanilla",
"lambda_L1": 100.0,
# Regularization
"grad_clip": 1.0,
# Early stopping
"early_stopping_patience": 20,
# Output
"output_dir": "runs/gan_training",
# Logging
"log_interval": 10,
"save_interval": 5,
# Data
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
"image_size": [256, 256],
"train_split": 0.8,
"num_workers": 0,
}
if __name__ == "__main__":
config = create_config()
print("Default config:", config)

162
models/GAN/dataloader.py Normal file
View File

@@ -0,0 +1,162 @@
"""Data loader for Yandex-to-Google image translation."""
import os
from typing import Dict, List, Tuple
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class YaGoDataset(Dataset):
"""Dataset loading pairs of Yandex and Google map images."""
def __init__(
self,
root_dir: str,
image_size: Tuple[int, int] = (256, 256),
augment: bool = False,
):
"""
Args:
root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png
image_size: Target image size (height, width)
augment: Whether to apply augmentation (not implemented for simplicity)
"""
self.root_dir = root_dir
self.image_size = image_size
self.augment = augment
# Discover image pairs
self.pairs = self._find_pairs()
# Transform to tensor + normalization
self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)
def _find_pairs(self) -> List[Dict]:
"""Find all matching Google-Yandex image pairs."""
pairs = []
google_files = [f for f in os.listdir(self.root_dir) if f.endswith("_google.png")]
for google_file in sorted(google_files):
idx_str = google_file.split("_")[0]
try:
idx = int(idx_str)
except ValueError:
continue
yandex_file = f"{idx:04d}_yandex.png"
yandex_path = os.path.join(self.root_dir, yandex_file)
if os.path.exists(yandex_path):
pairs.append(
{
"idx": idx,
"google_path": os.path.join(self.root_dir, google_file),
"yandex_path": yandex_path,
}
)
return pairs
def __len__(self) -> int:
return len(self.pairs)
def __getitem__(self, idx: int) -> dict:
pair = self.pairs[idx]
# Load images
google_img = Image.open(pair["google_path"]).convert("RGB")
yandex_img = Image.open(pair["yandex_path"]).convert("RGB")
# Resize
google_img = google_img.resize((self.image_size[1], self.image_size[0]))
yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0]))
# Apply transforms
google_tensor = self.transform(google_img)
yandex_tensor = self.transform(yandex_img)
return {
"google_img": google_tensor,
"yandex_img": yandex_tensor,
"idx": torch.tensor(pair["idx"], dtype=torch.long),
}
def create_data_loaders(
root_dir: str,
batch_size: int = 32,
train_split: float = 0.8,
num_workers: int = 0,
image_size: Tuple[int, int] = (256, 256),
) -> Tuple[DataLoader, DataLoader]:
"""
Create train and validation data loaders.
Args:
root_dir: Directory with image pairs
batch_size: Batch size
train_split: Fraction for training (0.0-1.0)
num_workers: DataLoader workers
image_size: Target image size
Returns:
(train_loader, val_loader)
"""
# Full dataset
dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)
# Split
dataset_size = len(dataset)
train_size = int(train_split * dataset_size)
indices = torch.randperm(dataset_size).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]
# Subsets
from torch.utils.data import Subset
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
# DataLoaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
)
return train_loader, val_loader
if __name__ == "__main__":
# Quick test
from config import create_config
config = create_config()
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=4,
image_size=tuple(config["image_size"]),
)
batch = next(iter(train_loader))
print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 1,
"id": "bb583d80",
"metadata": {},
"outputs": [],
@@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 238,
"execution_count": 2,
"id": "92144cc0",
"metadata": {},
"outputs": [

View File

@@ -1,393 +0,0 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetDownBlock(nn.Module):
"""Блок downsampling для U-Net генератора"""
def __init__(
self,
in_channels: int,
out_channels: int,
normalize: bool = True,
dropout: float = 0.0,
):
super().__init__()
layers = [
nn.Conv2d(
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False,
)
]
if normalize:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2, inplace=True))
if dropout > 0:
layers.append(nn.Dropout2d(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
class UNetUpBlock(nn.Module):
"""Блок upsampling для U-Net генератора"""
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
super().__init__()
layers = [
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
]
if dropout > 0:
layers.append(nn.Dropout2d(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
x = self.model(x)
# Обрезаем skip connection до размера x, если необходимо
if x.shape != skip_input.shape:
diffY = skip_input.size(2) - x.size(2)
diffX = skip_input.size(3) - x.size(3)
x = F.pad(
x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
)
x = torch.cat([x, skip_input], dim=1)
return x
class GeneratorUNet(nn.Module):
"""Генератор на основе U-Net архитектуры для преобразования Yandex → Google"""
def __init__(self, in_channels: int = 3, out_channels: int = 3):
super().__init__()
# Downsampling path
self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
self.down2 = UNetDownBlock(64, 128)
self.down3 = UNetDownBlock(128, 256)
self.down4 = UNetDownBlock(256, 512)
self.down5 = UNetDownBlock(512, 512)
self.down6 = UNetDownBlock(512, 512)
self.down7 = UNetDownBlock(512, 512)
# Bottleneck
self.bottleneck = nn.Sequential(
nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
nn.ReLU(inplace=True),
)
# Upsampling path
self.up1 = UNetUpBlock(512, 512, dropout=0.5)
self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
self.up4 = UNetUpBlock(1024, 512)
self.up5 = UNetUpBlock(1024, 256)
self.up6 = UNetUpBlock(512, 128)
self.up7 = UNetUpBlock(256, 64)
# Final layer
self.final = nn.Sequential(
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
nn.Tanh(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Downsampling
d1 = self.down1(x) # 350x350
d2 = self.down2(d1) # 175x175
d3 = self.down3(d2) # 88x88
d4 = self.down4(d3) # 44x44
d5 = self.down5(d4) # 22x22
d6 = self.down6(d5) # 11x11
d7 = self.down7(d6) # 6x6
# Bottleneck
u = self.bottleneck(d7) # 3x3
# Upsampling with skip connections
u = self.up1(u, d7) # 6x6
u = self.up2(u, d6) # 11x11
u = self.up3(u, d5) # 22x22
u = self.up4(u, d4) # 44x44
u = self.up5(u, d3) # 88x88
u = self.up6(u, d2) # 175x175
u = self.up7(u, d1) # 350x350
# Final output
return self.final(u) # 700x700
class DiscriminatorPatchGAN(nn.Module):
"""Дискриминатор PatchGAN для изображений 700x700"""
def __init__(
self, in_channels: int = 6
): # 3 для реального + 3 для сгенерированного
super().__init__()
def discriminator_block(
in_filters: int, out_filters: int, normalization: bool = True
):
"""Блок дискриминатора"""
layers = [
nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
]
if normalization:
layers.append(nn.BatchNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(in_channels, 64, normalization=False), # 350x350
*discriminator_block(64, 128), # 175x175
*discriminator_block(128, 256), # 88x88
*discriminator_block(256, 512), # 44x44
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), # 41x41
nn.Sigmoid(),
)
def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
"""
Принимает пару изображений (реальное и сгенерированное)
и возвращает вероятность того, что пара реальная
"""
# Объединяем два изображения по каналам
img_input = torch.cat((img_A, img_B), 1)
return self.model(img_input)
class GANLoss(nn.Module):
"""Функция потерь для GAN"""
def __init__(
self,
gan_mode: str = "vanilla",
target_real_label: float = 1.0,
target_fake_label: float = 0.0,
):
super().__init__()
self.register_buffer("real_label", torch.tensor(target_real_label))
self.register_buffer("fake_label", torch.tensor(target_fake_label))
self.gan_mode = gan_mode
if gan_mode == "vanilla":
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode == "lsgan":
self.loss = nn.MSELoss()
elif gan_mode == "wgangp":
self.loss = None
else:
raise NotImplementedError(f"GAN mode {gan_mode} not implemented")
def get_target_tensor(
self, prediction: torch.Tensor, target_is_real: bool
) -> torch.Tensor:
"""Создает тензор меток"""
if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)
def __call__(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
"""Вычисляет потери"""
if self.gan_mode in ["vanilla", "lsgan"]:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == "wgangp":
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
return loss
class ImageGAN(nn.Module):
"""Основной класс GAN для преобразования изображений Yandex → Google"""
def __init__(
self,
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
use_cuda: bool = True,
):
super().__init__()
self.generator = GeneratorUNet(input_channels, output_channels)
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
self.gan_loss = GANLoss(gan_mode)
self.l1_loss = nn.L1Loss()
self.lambda_L1 = lambda_L1
self.device = torch.device(
"cuda" if use_cuda and torch.cuda.is_available() else "cpu"
)
self.to(self.device)
def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
"""Генерация изображения Google из Yandex"""
return self.generator(yandex_image)
def generator_step(
self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Шаг обучения генератора
Returns:
total_loss: общие потери генератора
gan_loss: потери GAN
l1_loss: потери L1
"""
# Генерируем изображение
fake_google_image = self.generator(yandex_image)
# Оцениваем дискриминатором
fake_pred = self.discriminator(yandex_image, fake_google_image)
# Потери GAN (пытаемся обмануть дискриминатор)
gan_loss = self.gan_loss(fake_pred, True)
# Потери L1 для сохранения структуры
l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1
# Общие потери
total_loss = gan_loss + l1_loss
return total_loss, gan_loss, l1_loss
def discriminator_step(
self,
yandex_image: torch.Tensor,
real_google_image: torch.Tensor,
fake_google_image: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Шаг обучения дискриминатора
Returns:
total_loss: общие потери дискриминатора
real_loss: потери на реальных изображениях
fake_loss: потери на сгенерированных изображениях
"""
# Предсказания для реальных пар
real_pred = self.discriminator(yandex_image, real_google_image)
real_loss = self.gan_loss(real_pred, True)
# Предсказания для сгенерированных пар
fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
fake_loss = self.gan_loss(fake_pred, False)
# Общие потери дискриминатора
total_loss = (real_loss + fake_loss) * 0.5
return total_loss, real_loss, fake_loss
def to(self, device):
"""Перемещает модель на устройство"""
self.generator.to(device)
self.discriminator.to(device)
return self
def train_mode(self):
"""Переключает модель в режим обучения"""
self.generator.train()
self.discriminator.train()
def eval_mode(self):
"""Переключает модель в режим оценки"""
self.generator.eval()
self.discriminator.eval()
def save_checkpoint(self, path: str):
"""Сохраняет чекпоинт модели"""
checkpoint = {
"generator_state_dict": self.generator.state_dict(),
"discriminator_state_dict": self.discriminator.state_dict(),
"generator_optimizer_state_dict": getattr(
self.generator, "optimizer_state_dict", None
),
"discriminator_optimizer_state_dict": getattr(
self.discriminator, "optimizer_state_dict", None
),
}
torch.save(checkpoint, path)
def load_checkpoint(self, path: str):
"""Загружает чекпоинт модели"""
checkpoint = torch.load(path, map_location=self.device)
self.generator.load_state_dict(checkpoint["generator_state_dict"])
self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
if checkpoint["generator_optimizer_state_dict"] is not None:
self.generator.optimizer_state_dict = checkpoint[
"generator_optimizer_state_dict"
]
if checkpoint["discriminator_optimizer_state_dict"] is not None:
self.discriminator.optimizer_state_dict = checkpoint[
"discriminator_optimizer_state_dict"
]
def create_image_gan(
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
use_cuda: bool = True,
) -> ImageGAN:
"""
Создает и возвращает модель GAN для преобразования изображений
Args:
input_channels: количество входных каналов (обычно 3 для RGB)
output_channels: количество выходных каналов (обычно 3 для RGB)
gan_mode: режим GAN ('vanilla', 'lsgan', 'wgangp')
lambda_L1: вес L1 потерь
use_cuda: использовать ли CUDA если доступно
Returns:
ImageGAN: модель GAN
"""
return ImageGAN(
input_channels=input_channels,
output_channels=output_channels,
gan_mode=gan_mode,
lambda_L1=lambda_L1,
use_cuda=use_cuda,
)
# Вспомогательные функции для инициализации весов
def weights_init_normal(m):
"""Инициализация весов с нормальным распределением"""
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.batch_norm.bias.data, 0.0)
def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
"""Инициализирует веса генератора и дискриминатора"""
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

30
models/GAN/main.py Normal file
View File

@@ -0,0 +1,30 @@
"""Main entry point for GAN training."""
from config import create_config
from dataloader import create_data_loaders
from model import create_gan
from trainer import create_trainer
def main():
"""Run training pipeline."""
config = create_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create components
model = create_gan(use_cuda=False) # Set to True to use GPU
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
image_size=tuple(config["image_size"]),
num_workers=config["num_workers"],
)
trainer = create_trainer(model, train_loader, val_loader, config)
# Train
trainer.train(config["epochs"])
if __name__ == "__main__":
import torch
main()

View File

@@ -1,136 +0,0 @@
"""
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
Этот пример показывает самый простой способ использования тренера.
"""
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
class SimpleMapDataset(Dataset):
"""Простой датасет с фиктивными данными для примера."""
def __init__(self, num_samples=100, image_size=(256, 256)):
self.num_samples = num_samples
self.image_size = image_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Создаем фиктивные изображения
# В реальном коде замените на загрузку реальных изображений
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
return {"yandex_img": yandex_img, "google_img": google_img}
def main():
"""Основная функция минимального примера."""
print("Минимальный пример использования GAN trainer")
print("=" * 50)
# 1. Конфигурация (минимальный набор параметров)
config = {
"learning_rate": 2e-4,
"batch_size": 4,
"output_dir": "runs/gan_minimal",
}
# 2. Устройство (CPU или GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используемое устройство: {device}")
# 3. Создание модели
print("\nСоздание GAN модели...")
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode="vanilla", # Простейший режим
lambda_L1=100.0, # Стандартный вес L1 потерь
use_cuda=(device.type == "cuda"),
)
# 4. Создание даталоадеров
print("Создание даталоадеров...")
train_dataset = SimpleMapDataset(num_samples=50)
val_dataset = SimpleMapDataset(num_samples=10)
train_loader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
shuffle=True,
num_workers=0,
)
val_loader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
shuffle=False,
num_workers=0,
)
print(f" Обучающих примеров: {len(train_dataset)}")
print(f" Валидационных примеров: {len(val_dataset)}")
# 5. Создание тренера
print("\nСоздание тренера...")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# 6. Обучение на небольшом количестве эпох
print("\nЗапуск обучения (3 эпохи для примера)...")
print("=" * 50)
trainer.train(num_epochs=3)
# 7. Генерация примеров
print("\nГенерация примеров преобразования...")
model.eval()
# Создаем тестовые данные
test_yandex = torch.randn(2, 3, 256, 256).to(device)
with torch.no_grad():
generated_google = model(test_yandex)
print(f"Входные изображения: {test_yandex.shape}")
print(f"Сгенерированные изображения: {generated_google.shape}")
print(
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
)
# 8. Сохранение финальной модели
print("\nСохранение модели...")
model_save_path = "gan_model_minimal.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Модель сохранена в: {model_save_path}")
print("\n" + "=" * 50)
print("Минимальный пример завершен!")
print("\nДля реального использования:")
print("1. Замените SimpleMapDataset на ваш реальный датасет")
print("2. Настройте параметры в config")
print("3. Увеличьте количество эпох (например, до 100)")
print("4. Используйте реальные изображения карт")
print("=" * 50)
if __name__ == "__main__":
main()

256
models/GAN/model.py Normal file
View File

@@ -0,0 +1,256 @@
"""GAN model for image translation Yandex -> Google."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetDownBlock(nn.Module):
"""Downsampling block for U-Net."""
def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):
super().__init__()
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
]
if normalize:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2, inplace=True))
if dropout > 0:
layers.append(nn.Dropout2d(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class UNetUpBlock(nn.Module):
"""Upsampling block for U-Net."""
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
super().__init__()
self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
self.norm = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if dropout > 0:
self.dropout = nn.Dropout2d(dropout)
else:
self.dropout = None
def forward(self, x, skip_input):
print(x.shape)
print(skip_input.shape)
x = self.upconv(x)
# Pad if needed to match skip connection size
if x.shape != skip_input.shape:
diff_h = skip_input.size(2) - x.size(2)
diff_w = skip_input.size(3) - x.size(3)
x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
x = self.norm(x)
x = self.relu(x)
if self.dropout:
x = self.dropout(x)
x = torch.cat([x, skip_input], dim=1)
return x
class GeneratorUNet(nn.Module):
"""U-Net generator for Yandex -> Google translation."""
def __init__(self, in_channels: int = 3, out_channels: int = 3):
super().__init__()
# Downsampling
self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
self.down2 = UNetDownBlock(64, 128)
self.down3 = UNetDownBlock(128, 256)
self.down4 = UNetDownBlock(256, 512)
self.down5 = UNetDownBlock(512, 512)
self.down6 = UNetDownBlock(512, 512)
self.down7 = UNetDownBlock(512, 512)
# Bottleneck
self.bottleneck = nn.Sequential(
nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
nn.ReLU(inplace=True),
)
# Upsampling
self.up1 = UNetUpBlock(1024, 512, dropout=0.5)
self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
self.up4 = UNetUpBlock(1024, 512)
self.up5 = UNetUpBlock(1024, 256)
self.up6 = UNetUpBlock(512, 128)
self.up7 = UNetUpBlock(256, 64)
# Final
self.final = nn.Sequential(
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
nn.Tanh(),
)
def forward(self, x):
# Down
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
d5 = self.down5(d4)
d6 = self.down6(d5)
d7 = self.down7(d6)
# Bottleneck
u = self.bottleneck(d7)
# Up with skip connections
u = self.up1(u, d7)
u = self.up2(u, d6)
u = self.up3(u, d5)
u = self.up4(u, d4)
u = self.up5(u, d3)
u = self.up6(u, d2)
u = self.up7(u, d1)
return self.final(u)
class DiscriminatorPatchGAN(nn.Module):
"""PatchGAN discriminator."""
def __init__(self, in_channels: int = 6):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
nn.Sigmoid(),
)
def forward(self, img_A, img_B):
x = torch.cat([img_A, img_B], dim=1)
return self.model(x)
class GANLoss(nn.Module):
"""GAN loss supporting different GAN modes."""
def __init__(self, gan_mode: str = "vanilla", target_real: float = 1.0, target_fake: float = 0.0):
super().__init__()
self.gan_mode = gan_mode
self.register_buffer("real_label", torch.tensor(target_real))
self.register_buffer("fake_label", torch.tensor(target_fake))
if gan_mode == "vanilla":
self.loss_fn = nn.BCEWithLogitsLoss()
elif gan_mode == "lsgan":
self.loss_fn = nn.MSELoss()
elif gan_mode == "wgangp":
self.loss_fn = None
else:
raise ValueError(f"Unknown GAN mode: {gan_mode}")
def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
if self.gan_mode in ["vanilla", "lsgan"]:
target = self.real_label if target_is_real else self.fake_label
target = target.expand_as(prediction)
return self.loss_fn(prediction, target)
elif self.gan_mode == "wgangp":
return -prediction.mean() if target_is_real else prediction.mean()
class ImageGAN(nn.Module):
"""Complete GAN model for image translation."""
def __init__(
self,
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
use_cuda: bool = True,
):
super().__init__()
self.generator = GeneratorUNet(input_channels, output_channels)
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
self.gan_loss = GANLoss(gan_mode)
self.l1_loss = nn.L1Loss()
self.lambda_L1 = lambda_L1
self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
self.to(self.device)
def forward(self, yandex_image):
"""Generate Google image from Yandex."""
return self.generator(yandex_image)
def generator_step(self, yandex_img, real_google_img):
"""Compute generator losses."""
fake_google = self.generator(yandex_img)
fake_pred = self.discriminator(yandex_img, fake_google)
gan_loss = self.gan_loss(fake_pred, True)
l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1
total_loss = gan_loss + l1_loss
return total_loss, gan_loss, l1_loss
def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
"""Compute discriminator losses."""
real_pred = self.discriminator(yandex_img, real_google_img)
real_loss = self.gan_loss(real_pred, True)
fake_pred = self.discriminator(yandex_img, fake_google_img.detach())
fake_loss = self.gan_loss(fake_pred, False)
total_loss = (real_loss + fake_loss) * 0.5
return total_loss, real_loss, fake_loss
def create_gan(
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
use_cuda: bool = True,
) -> ImageGAN:
"""Create a GAN model."""
return ImageGAN(
input_channels=input_channels,
output_channels=output_channels,
gan_mode=gan_mode,
lambda_L1=lambda_L1,
use_cuda=use_cuda,
)
def initialize_weights(model: nn.Module):
"""Initialize model weights."""
for m in model.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0.0)
if __name__ == "__main__":
# Quick test
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_gan(use_cuda=False)
print(f"Model created on {model.device}")
# Test forward pass
test_input = torch.randn(2, 3, 256, 256).to(model.device)
output = model(test_input)
print(f"Output shape: {output.shape}")
# Count parameters
gen_params = sum(p.numel() for p in model.generator.parameters())
disc_params = sum(p.numel() for p in model.discriminator.parameters())
print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")

View File

@@ -1,349 +0,0 @@
import os
import sys
import torch
import torch.nn as nn
# Добавляем путь к модулю
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from gan import (
DiscriminatorPatchGAN,
GeneratorUNet,
ImageGAN,
create_image_gan,
initialize_gan_weights,
)
def test_generator():
"""Тестирование генератора"""
print("=" * 60)
print("Тестирование генератора...")
print("=" * 60)
# Создаем генератор
generator = GeneratorUNet(in_channels=3, out_channels=3)
# Инициализируем веса
generator.apply(
lambda m: (
nn.init.normal_(m.weight.data, 0.0, 0.02)
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d)
else None
)
)
# Создаем тестовый входной тензор (Yandex изображение)
batch_size = 2
height, width = 700, 700
yandex_image = torch.randn(batch_size, 3, height, width)
print(f"Размер входного изображения: {yandex_image.shape}")
print(
f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}"
)
# Прямой проход
with torch.no_grad():
generated_image = generator(yandex_image)
print(f"Размер сгенерированного изображения: {generated_image.shape}")
print(
f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]"
)
# Проверка размеров
assert generated_image.shape == (batch_size, 3, height, width), (
f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}"
)
print("✓ Генератор работает корректно!")
return generator
def test_discriminator():
"""Тестирование дискриминатора"""
print("\n" + "=" * 60)
print("Тестирование дискриминатора...")
print("=" * 60)
# Создаем дискриминатор
discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала
# Инициализируем веса
discriminator.apply(
lambda m: (
nn.init.normal_(m.weight.data, 0.0, 0.02)
if isinstance(m, nn.Conv2d)
else None
)
)
# Создаем тестовые тензоры
batch_size = 2
height, width = 700, 700
yandex_image = torch.randn(batch_size, 3, height, width)
google_image = torch.randn(batch_size, 3, height, width)
print(f"Размер Yandex изображения: {yandex_image.shape}")
print(f"Размер Google изображения: {google_image.shape}")
print(
f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}"
)
# Прямой проход
with torch.no_grad():
prediction = discriminator(yandex_image, google_image)
print(f"Размер выхода дискриминатора: {prediction.shape}")
print(
f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]"
)
# Проверка размеров (PatchGAN выдает карту вероятностей)
expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков
expected_width = 43
assert prediction.shape == (batch_size, 1, expected_height, expected_width), (
f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}"
)
print(
f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}"
)
return discriminator
def test_gan_model():
"""Тестирование полной GAN модели"""
print("\n" + "=" * 60)
print("Тестирование полной GAN модели...")
print("=" * 60)
# Создаем GAN модель
gan = ImageGAN(
input_channels=3,
output_channels=3,
gan_mode="vanilla",
lambda_L1=100.0,
use_cuda=False, # Тестируем на CPU для простоты
)
print(f"Устройство модели: {gan.device}")
print(
f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}"
)
print(
f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}"
)
print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}")
# Создаем тестовые данные
batch_size = 2
height, width = 700, 700
yandex_image = torch.randn(batch_size, 3, height, width)
real_google_image = torch.randn(batch_size, 3, height, width)
print(f"\nТестирование прямого прохода...")
with torch.no_grad():
generated_image = gan(yandex_image)
print(f"Размер сгенерированного изображения: {generated_image.shape}")
print(f"\nТестирование шага генератора...")
gan.train_mode()
# Тестируем шаг генератора
total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image)
print(f"Общие потери генератора: {total_loss.item():.6f}")
print(f"Потери GAN: {gan_loss.item():.6f}")
print(f"Потери L1: {l1_loss.item():.6f}")
print(f"\nТестирование шага дискриминатора...")
# Создаем сгенерированное изображение для дискриминатора
with torch.no_grad():
fake_google_image = gan.generator(yandex_image)
total_d_loss, real_loss, fake_loss = gan.discriminator_step(
yandex_image, real_google_image, fake_google_image
)
print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}")
print(f"Потери на реальных изображениях: {real_loss.item():.6f}")
print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}")
print(f"\nТестирование режимов обучения/оценки...")
gan.eval_mode()
print(f"Генератор в режиме eval: {not gan.generator.training}")
print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}")
gan.train_mode()
print(f"Генератор в режиме train: {gan.generator.training}")
print(f"Дискриминатор в режиме train: {gan.discriminator.training}")
print("\n✓ Полная GAN модель работает корректно!")
return gan
def test_factory_function():
"""Тестирование фабричной функции"""
print("\n" + "=" * 60)
print("Тестирование фабричной функции...")
print("=" * 60)
# Тестируем разные режимы GAN
for gan_mode in ["vanilla", "lsgan"]:
print(f"\nСоздание GAN в режиме '{gan_mode}'...")
gan = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=gan_mode,
lambda_L1=100.0,
use_cuda=False,
)
print(f" Режим GAN: {gan.gan_loss.gan_mode}")
print(f" Вес L1 потерь: {gan.lambda_L1}")
print(f" Устройство: {gan.device}")
# Быстрая проверка прямого прохода
batch_size = 1
yandex_image = torch.randn(batch_size, 3, 700, 700)
with torch.no_grad():
generated = gan(yandex_image)
print(f" Размер выхода: {generated.shape}")
print(f" ✓ GAN в режиме '{gan_mode}' создан успешно")
print("\n✓ Фабричная функция работает корректно!")
def test_weights_initialization():
"""Тестирование инициализации весов"""
print("\n" + "=" * 60)
print("Тестирование инициализации весов...")
print("=" * 60)
# Создаем модели
generator = GeneratorUNet(3, 3)
discriminator = DiscriminatorPatchGAN(6)
# Инициализируем веса
initialize_gan_weights(generator, discriminator)
# Проверяем средние значения весов
def check_weights_mean(model, model_name):
conv_weights = []
for name, param in model.named_parameters():
if "weight" in name and (
"conv" in name.lower() or "Conv" in str(param.__class__)
):
conv_weights.append(param.data.mean().item())
if conv_weights:
avg_mean = sum(conv_weights) / len(conv_weights)
print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}")
# Проверяем, что веса инициализированы около 0
assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0"
check_weights_mean(generator, "генераторе")
check_weights_mean(discriminator, "дискриминаторе")
print("✓ Инициализация весов работает корректно!")
def test_memory_usage():
"""Тестирование использования памяти"""
print("\n" + "=" * 60)
print("Тестирование использования памяти...")
print("=" * 60)
import os
import psutil
# Получаем текущее использование памяти
process = psutil.Process(os.getpid())
memory_before = process.memory_info().rss / 1024 / 1024 # в MB
print(f"Память до создания моделей: {memory_before:.2f} MB")
# Создаем несколько моделей
models = []
for i in range(3):
gan = create_image_gan(use_cuda=False)
models.append(gan)
# Делаем тестовый проход
batch_size = 1
yandex_image = torch.randn(batch_size, 3, 700, 700)
real_google_image = torch.randn(batch_size, 3, 700, 700)
with torch.no_grad():
_ = gan(yandex_image)
_ = gan.generator_step(yandex_image, real_google_image)
memory_after = process.memory_info().rss / 1024 / 1024 # в MB
memory_used = memory_after - memory_before
print(f"Память после создания моделей: {memory_after:.2f} MB")
print(f"Использовано памяти: {memory_used:.2f} MB")
# Очищаем модели
del models
import gc
gc.collect()
memory_final = process.memory_info().rss / 1024 / 1024
print(f"Память после очистки: {memory_final:.2f} MB")
print("✓ Тестирование памяти завершено!")
def main():
"""Основная функция тестирования"""
print("Начало тестирования GAN архитектуры для преобразования Yandex → Google")
print("Размер изображения: 700x700 пикселей")
print("=" * 60)
try:
# Запускаем все тесты
test_generator()
test_discriminator()
test_gan_model()
test_factory_function()
test_weights_initialization()
test_memory_usage()
print("\n" + "=" * 60)
print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉")
print("=" * 60)
print("\nАрхитектура GAN готова к использованию для преобразования")
print("изображений из стиля Yandex в стиль Google.")
print("\nОсновные характеристики:")
print(" • Генератор: U-Net архитектура")
print(" • Дискриминатор: PatchGAN (43x43 патчей)")
print(" • Размер входных/выходных изображений: 700x700")
print(" • Поддержка режимов: vanilla, lsgan")
print(" • L1 регуляризация для сохранения структуры")
except Exception as e:
print(f"\n❌ Ошибка при тестировании: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
exit_code = main()
sys.exit(exit_code)

View File

@@ -1,342 +0,0 @@
"""
Тестовый скрипт для проверки GAN trainer.
"""
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
class SimpleDataset(Dataset):
"""Простой датасет для тестирования."""
def __init__(self, num_samples=100, image_size=(256, 256)):
self.num_samples = num_samples
self.image_size = image_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Создаем случайные изображения для тестирования
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
return {"yandex_img": yandex_img, "google_img": google_img}
def test_gan_model():
"""Тестирование GAN модели."""
print("Тестирование GAN модели...")
# Создаем модель
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode="vanilla",
lambda_L1=100.0,
use_cuda=False, # Используем CPU для тестирования
)
# Тестируем forward pass
batch_size = 2
image_size = (256, 256)
yandex_input = torch.randn(batch_size, 3, *image_size)
with torch.no_grad():
output = model(yandex_input)
print(f"Входной размер: {yandex_input.shape}")
print(f"Выходной размер: {output.shape}")
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
print("✓ Forward pass работает корректно")
return model
def test_generator_step():
"""Тестирование шага генератора."""
print("\nТестирование шага генератора...")
model = create_image_gan(use_cuda=False)
model.train_mode()
batch_size = 2
yandex_img = torch.randn(batch_size, 3, 256, 256)
google_img = torch.randn(batch_size, 3, 256, 256)
# Тестируем generator_step
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
print(f"Total loss: {total_loss.item():.6f}")
print(f"GAN loss: {gan_loss.item():.6f}")
print(f"L1 loss: {l1_loss.item():.6f}")
assert total_loss.item() > 0, "Потери должны быть положительными"
print("✓ Шаг генератора работает корректно")
def test_discriminator_step():
"""Тестирование шага дискриминатора."""
print("\nТестирование шага дискриминатора...")
model = create_image_gan(use_cuda=False)
model.train_mode()
batch_size = 2
yandex_img = torch.randn(batch_size, 3, 256, 256)
google_img = torch.randn(batch_size, 3, 256, 256)
# Генерируем fake изображение
with torch.no_grad():
fake_google_img = model.generator(yandex_img)
# Тестируем discriminator_step
total_loss, real_loss, fake_loss = model.discriminator_step(
yandex_img, google_img, fake_google_img
)
print(f"Total loss: {total_loss.item():.6f}")
print(f"Real loss: {real_loss.item():.6f}")
print(f"Fake loss: {fake_loss.item():.6f}")
assert total_loss.item() > 0, "Потери должны быть положительными"
print("✓ Шаг дискриминатора работает корректно")
def test_trainer_initialization():
"""Тестирование инициализации тренера."""
print("\nТестирование инициализации тренера...")
# Создаем модель
model = create_image_gan(use_cuda=False)
# Создаем даталоадеры
train_dataset = SimpleDataset(num_samples=50)
val_dataset = SimpleDataset(num_samples=10)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
"early_stopping_patience": 10,
}
# Создаем тренер
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
print(f"Тренер создан успешно")
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
print(f"Выходная директория: {trainer.output_dir}")
assert trainer.output_dir.exists(), "Выходная директория не создана"
print("✓ Тренер инициализирован корректно")
return trainer, train_loader, val_loader
def test_train_epoch():
"""Тестирование одной эпохи обучения."""
print("\nТестирование одной эпохи обучения...")
model = create_image_gan(use_cuda=False)
train_dataset = SimpleDataset(num_samples=20)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Запускаем одну эпоху обучения
avg_g_loss, avg_d_loss = trainer.train_epoch()
print(f"Средние потери за эпоху:")
print(f" Генератор: {avg_g_loss:.6f}")
print(f" Дискриминатор: {avg_d_loss:.6f}")
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
print("✓ Эпоха обучения завершена успешно")
def test_validation():
"""Тестирование валидации."""
print("\nТестирование валидации...")
model = create_image_gan(use_cuda=False)
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Запускаем валидацию
val_g_loss, val_d_loss = trainer.validate()
print(f"Потери на валидации:")
print(f" Генератор: {val_g_loss:.6f}")
print(f" Дискриминатор: {val_d_loss:.6f}")
assert val_g_loss > 0, "Потери генератора должны быть положительными"
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
print("✓ Валидация завершена успешно")
def test_checkpoint_saving():
"""Тестирование сохранения чекпоинтов."""
print("\nТестирование сохранения чекпоинтов...")
model = create_image_gan(use_cuda=False)
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"output_dir": "test_runs/gan_checkpoint",
}
device = torch.device("cpu")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Сохраняем чекпоинт
trainer.save_checkpoint(is_best=True)
# Проверяем, что файлы созданы
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
for file in checkpoint_files:
print(f" - {file.name}")
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
print("✓ Чекпоинты сохранены успешно")
# Тестируем загрузку чекпоинта
checkpoint_path = checkpoint_files[0]
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
# Создаем новую модель и тренер
new_model = create_image_gan(use_cuda=False)
new_trainer = GANTrainer(
model=new_model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Загружаем чекпоинт
new_trainer.load_checkpoint(str(checkpoint_path))
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
print("✓ Чекпоинт загружен успешно")
def main():
"""Основная функция тестирования."""
print("=" * 60)
print("Начало тестирования GAN trainer")
print("=" * 60)
try:
# Запускаем все тесты
test_gan_model()
test_generator_step()
test_discriminator_step()
test_trainer_initialization()
test_train_epoch()
test_validation()
test_checkpoint_saving()
print("\n" + "=" * 60)
print("Все тесты пройдены успешно! ✓")
print("=" * 60)
except Exception as e:
print(f"\nОшибка при тестировании: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
# Очищаем тестовые директории
import shutil
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
for dir_path in test_dirs:
if Path(dir_path).exists():
shutil.rmtree(dir_path)
# Запускаем тесты
exit_code = main()
# Очищаем после тестов
for dir_path in test_dirs:
if Path(dir_path).exists():
shutil.rmtree(dir_path)
exit(exit_code)

View File

@@ -1,347 +0,0 @@
"""
Пример обучения GAN модели для преобразования Yandex → Google карт.
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
"""
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader
# Добавляем путь к модулям
sys.path.append(str(Path(__file__).parent.parent.parent))
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
def create_simple_config():
"""Создает простую конфигурацию для обучения."""
config = {
# Параметры оптимизатора
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Параметры обучения
"batch_size": 4,
"epochs": 100,
# Параметры GAN
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0,
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan_training",
# Логирование
"log_interval": 10, # Логировать каждые N батчей
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
}
return config
def create_advanced_config():
"""Создает расширенную конфигурацию для обучения."""
config = {
# Параметры оптимизатора
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
# Планировщик learning rate
"use_scheduler": True,
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
"scheduler_start_epoch": 50,
"scheduler_end_epoch": 100,
# Параметры обучения
"batch_size": 8,
"epochs": 200,
# Параметры GAN
"gan_mode": "lsgan", # LSGAN обычно более стабилен
"lambda_L1": 100.0,
# Аугментация данных
"augmentation": {
"random_crop": True,
"crop_size": 256,
"random_flip": True,
"color_jitter": True,
"brightness": 0.2,
"contrast": 0.2,
"saturation": 0.2,
"hue": 0.1,
},
# Регуляризация
"grad_clip": 1.0,
"weight_decay": 1e-4,
# Ранняя остановка
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
# Выходные данные
"output_dir": "runs/gan_advanced",
# Логирование
"log_interval": 20,
"save_interval": 10,
"save_best_only": True, # Сохранять только лучшую модель
# Визуализация
"visualize_samples": True,
"num_visualize": 4,
"visualize_interval": 5, # Визуализировать каждые N эпох
}
return config
def print_config_summary(config):
"""Печатает сводку конфигурации."""
print("=" * 60)
print("Конфигурация обучения GAN")
print("=" * 60)
print(f"\nПараметры модели:")
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
print(f"\nПараметры обучения:")
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
print(f" Batch size: {config.get('batch_size', 4)}")
print(f" Эпох: {config.get('epochs', 100)}")
print(f" Beta1: {config.get('beta1', 0.5)}")
print(f" Beta2: {config.get('beta2', 0.999)}")
if config.get("use_scheduler", False):
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
print(f"\nРегуляризация:")
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
if "weight_decay" in config:
print(f" Weight decay: {config['weight_decay']}")
print(f"\nРанняя остановка:")
if config.get("early_stopping_patience", 0) > 0:
print(f" Patience: {config['early_stopping_patience']} эпох")
if "early_stopping_min_delta" in config:
print(f" Min delta: {config['early_stopping_min_delta']}")
print(f"\nВыходные данные:")
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
print(f"\nЛогирование:")
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
print("=" * 60)
def setup_training():
"""Настраивает обучение."""
print("Настройка обучения GAN...")
# Выбираем конфигурацию
use_advanced = False # Измените на True для расширенной конфигурации
if use_advanced:
config = create_advanced_config()
else:
config = create_simple_config()
# Печатаем сводку конфигурации
print_config_summary(config)
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nИспользуемое устройство: {device}")
if device.type == "cuda":
print(f" GPU: {torch.cuda.get_device_name(0)}")
print(
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
)
# Создаем модель
print("\nСоздание модели...")
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config.get("gan_mode", "vanilla"),
lambda_L1=config.get("lambda_L1", 100.0),
use_cuda=(device.type == "cuda"),
)
# Создаем даталоадеры
print("\nСоздание даталоадеров...")
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
# Пример:
# from your_dataset_module import create_data_loaders
# train_loader, val_loader = create_data_loaders(
# data_dir="ваш/путь/к/данным",
# batch_size=config["batch_size"],
# image_size=(256, 256),
# augment=config.get("augmentation", None),
# )
# Для примера создаем фиктивные даталоадеры
# ВАЖНО: Замените это на реальные данные!
print(" ВНИМАНИЕ: Используются фиктивные данные!")
print(" Замените на реальные даталоадеры!")
import numpy as np
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self, num_samples=100):
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Фиктивные данные для примера
yandex_img = torch.randn(3, 256, 256)
google_img = torch.randn(3, 256, 256)
return {"yandex_img": yandex_img, "google_img": google_img}
train_dataset = DummyDataset(num_samples=100)
val_dataset = DummyDataset(num_samples=20)
train_loader = DataLoader(
train_dataset,
batch_size=config.get("batch_size", 4),
shuffle=True,
num_workers=0,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.get("batch_size", 4),
shuffle=False,
num_workers=0,
)
print(f" Размер обучающего набора: {len(train_dataset)}")
print(f" Размер валидационного набора: {len(val_dataset)}")
print(f" Батчей в эпохе: {len(train_loader)}")
# Создаем тренер
print("\nСоздание тренера...")
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
return trainer, config
def train_model(trainer, config):
"""Запускает обучение модели."""
print("\n" + "=" * 60)
print("Начало обучения")
print("=" * 60)
epochs = config.get("epochs", 100)
try:
trainer.train(num_epochs=epochs)
print("\n" + "=" * 60)
print("Обучение завершено успешно!")
print("=" * 60)
except KeyboardInterrupt:
print("\n\nОбучение прервано пользователем.")
print("Сохранение текущего состояния...")
trainer.save_checkpoint(is_best=False)
except Exception as e:
print(f"\n\nОшибка при обучении: {e}")
import traceback
traceback.print_exc()
# Пытаемся сохранить чекпоинт при ошибке
try:
trainer.save_checkpoint(is_best=False)
print("Текущее состояние сохранено.")
except:
print("Не удалось сохранить состояние.")
def evaluate_model(trainer, test_loader=None):
"""Оценивает обученную модель."""
print("\n" + "=" * 60)
print("Оценка модели")
print("=" * 60)
if test_loader is None:
print("Тестовый даталоадер не предоставлен.")
print("Используется валидационный даталоадер для оценки.")
test_loader = trainer.val_loader
metrics = trainer.evaluate(test_loader)
print("\nМетрики оценки:")
for key, value in metrics.items():
print(f" {key}: {value:.6f}")
return metrics
def generate_examples(model, device, num_examples=4):
"""Генерирует примеры преобразования."""
print("\n" + "=" * 60)
print("Генерация примеров")
print("=" * 60)
model.eval()
# Создаем фиктивные входные данные
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
with torch.no_grad():
google_output = model(yandex_input)
print(f"Сгенерировано {num_examples} примеров")
print(f"Размер входных данных: {yandex_input.shape}")
print(f"Размер выходных данных: {google_output.shape}")
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
print("\nПримеры сгенерированы.")
print("В реальном коде сохраняйте их как изображения для визуализации.")
return yandex_input, google_output
def main():
"""Основная функция."""
print("=" * 60)
print("Пример обучения GAN для преобразования Yandex → Google")
print("=" * 60)
# Настройка
trainer, config = setup_training()
# Обучение
train_model(trainer, config)
# Оценка (требует реальных тестовых данных)
# evaluate_model(trainer)
# Генерация примеров
# generate_examples(trainer.model, trainer.device)
print("\n" + "=" * 60)
print("Скрипт завершен.")
print("=" * 60)
print("\nСледующие шаги:")
print("1. Замените фиктивные даталоадеры на реальные данные")
print("2. Настройте параметры в create_simple_config()")
print("3. Запустите обучение с реальными данными")
print("4. Визуализируйте результаты")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -1,415 +1,191 @@
import json
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# Type aliases
ModuleType = nn.Module
TensorType = torch.Tensor
class GANTrainer:
"""Trainer class for GAN model."""
def __init__(
self,
model: ModuleType,
train_loader: DataLoader,
val_loader: DataLoader,
device: torch.device,
config: Dict[str, Any],
):
"""
Initialize the GAN trainer.
Args:
model: GAN model (ImageGAN)
train_loader: Training data loader
val_loader: Validation data loader
device: Device to run training on
config: Training configuration dictionary
"""
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.config = config
# Optimizers
lr = config.get("learning_rate", 2e-4)
beta1 = config.get("beta1", 0.5)
beta2 = config.get("beta2", 0.999)
# Separate optimizers for generator and discriminator
# Note: self.model is expected to have .generator and .discriminator attributes
self.optimizer_G = optim.Adam(
self.model.generator.parameters(), lr=lr, betas=(beta1, beta2)
)
self.optimizer_D = optim.Adam(
self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)
)
# Training state
self.current_epoch = 0
self.best_val_loss = float("inf")
self.g_losses: List[float] = []
self.d_losses: List[float] = []
self.val_g_losses: List[float] = []
self.val_d_losses: List[float] = []
# Create output directory
self.output_dir = Path(config.get("output_dir", "runs/gan"))
self.output_dir.mkdir(parents=True, exist_ok=True)
# TensorBoard writer
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
# Save configuration
config_path = self.output_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config, f, indent=2)
print(f"Training configuration saved to {config_path}")
# Access parameters through the model's generator and discriminator
generator_params = sum(p.numel() for p in self.model.generator.parameters())
discriminator_params = sum(
p.numel() for p in self.model.discriminator.parameters()
)
print(f"Generator has {generator_params:,} parameters")
print(f"Discriminator has {discriminator_params:,} parameters")
def train_epoch(self) -> Tuple[float, float]:
"""
Train for one epoch.
Returns:
Tuple of (average generator loss, average discriminator loss)
"""
self.model.train()
total_g_loss = 0.0
total_d_loss = 0.0
num_batches = len(self.train_loader)
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
for batch_idx, batch in enumerate(progress_bar):
# Move data to device
yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
# ========== Train Discriminator ==========
self.optimizer_D.zero_grad()
# Generate fake image
with torch.no_grad():
fake_google_img = self.model.generator(yandex_img)
# Discriminator loss - returns tuple of tensors
d_loss_tuple = self.model.discriminator_step(
yandex_img, google_img, fake_google_img
)
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
# Backward and optimize discriminator
d_loss.backward()
self.optimizer_D.step()
# ========== Train Generator ==========
self.optimizer_G.zero_grad()
# Generate fake image
fake_google_img = self.model.generator(yandex_img)
# Generator loss - returns tuple of tensors
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
# Backward and optimize generator
g_loss.backward()
self.optimizer_G.step()
# Update statistics
total_g_loss += g_loss.item()
total_d_loss += d_loss.item()
# Update progress bar
progress_bar.set_postfix(
{
"g_loss": g_loss.item(),
"d_loss": d_loss.item(),
"g_l1": g_l1_loss.item(),
"d_real": d_real_loss.item(),
"d_fake": d_fake_loss.item(),
}
)
# Log batch losses to TensorBoard
global_step = self.current_epoch * num_batches + batch_idx
self.writer.add_scalar("train/batch_g_loss", g_loss.item(), global_step)
self.writer.add_scalar("train/batch_d_loss", d_loss.item(), global_step)
self.writer.add_scalar(
"train/batch_g_l1_loss", g_l1_loss.item(), global_step
)
self.writer.add_scalar(
"train/batch_d_real_loss", d_real_loss.item(), global_step
)
self.writer.add_scalar(
"train/batch_d_fake_loss", d_fake_loss.item(), global_step
)
avg_g_loss = total_g_loss / num_batches
avg_d_loss = total_d_loss / num_batches
self.g_losses.append(avg_g_loss)
self.d_losses.append(avg_d_loss)
return avg_g_loss, avg_d_loss
def validate(self) -> Tuple[float, float]:
"""
Validate the model.
Returns:
Tuple of (average generator validation loss, average discriminator validation loss)
"""
self.model.eval()
total_g_loss = 0.0
total_d_loss = 0.0
progress_bar = tqdm(self.val_loader, desc="Validation")
for batch in progress_bar:
# Move data to device
yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
with torch.no_grad():
# Generate fake image
fake_google_img = self.model.generator(yandex_img)
# Generator loss - returns tuple of tensors
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
# Discriminator loss - returns tuple of tensors
d_loss_tuple = self.model.discriminator_step(
yandex_img, google_img, fake_google_img
)
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
# Update statistics
total_g_loss += g_loss.item()
total_d_loss += d_loss.item()
# Update progress bar
progress_bar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
avg_g_loss = total_g_loss / len(self.val_loader)
avg_d_loss = total_d_loss / len(self.val_loader)
self.val_g_losses.append(avg_g_loss)
self.val_d_losses.append(avg_d_loss)
return avg_g_loss, avg_d_loss
def save_checkpoint(self, is_best: bool = False):
"""
Save training checkpoint.
Args:
is_best: Whether this is the best model so far
"""
checkpoint = {
"epoch": self.current_epoch,
"model_state_dict": self.model.state_dict(),
"optimizer_G_state_dict": self.optimizer_G.state_dict(),
"optimizer_D_state_dict": self.optimizer_D.state_dict(),
"g_losses": self.g_losses,
"d_losses": self.d_losses,
"val_g_losses": self.val_g_losses,
"val_d_losses": self.val_d_losses,
"best_val_loss": self.best_val_loss,
"config": self.config,
}
# Save regular checkpoint
checkpoint_path = (
self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
)
torch.save(checkpoint, checkpoint_path)
# Save best model
if is_best:
best_path = self.output_dir / "model_best.pth"
torch.save(checkpoint, best_path)
print(f"Best model saved to {best_path}")
def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False):
"""
Load training checkpoint.
Args:
checkpoint_path: Path to checkpoint file
resume_training: Если True, продолжить обучение с сохраненной эпохи
"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])
self.current_epoch = checkpoint["epoch"]
self.g_losses = checkpoint["g_losses"]
self.d_losses = checkpoint["d_losses"]
self.val_g_losses = checkpoint["val_g_losses"]
self.val_d_losses = checkpoint["val_d_losses"]
self.best_val_loss = checkpoint["best_val_loss"]
if resume_training:
print(f"Resuming training from epoch {self.current_epoch + 1}")
else:
print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")
def train(self, num_epochs: int, start_epoch: int = 0):
"""
Train the model for specified number of epochs.
Args:
num_epochs: Number of epochs to train
start_epoch: Starting epoch (useful when resuming training)
"""
print(
f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
)
start_time = time.time()
for epoch in range(start_epoch, start_epoch + num_epochs):
self.current_epoch = epoch
# Train for one epoch
train_g_loss, train_d_loss = self.train_epoch()
# Validate
val_g_loss, val_d_loss = self.validate()
# Log to TensorBoard
self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)
# Print epoch summary
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
print(" Generator:")
print(f" Train Loss: {train_g_loss:.6f}")
print(f" Val Loss: {val_g_loss:.6f}")
print(" Discriminator:")
print(f" Train Loss: {train_d_loss:.6f}")
print(f" Val Loss: {val_d_loss:.6f}")
# Save checkpoint
val_total_loss = val_g_loss + val_d_loss
is_best = val_total_loss < self.best_val_loss
if is_best:
self.best_val_loss = val_total_loss
self.save_checkpoint(is_best=is_best)
# Early stopping
if self.config.get("early_stopping_patience", 0) > 0:
val_losses = [
g + d for g, d in zip(self.val_g_losses, self.val_d_losses)
]
if (
epoch - np.argmin(val_losses)
>= self.config["early_stopping_patience"]
):
print(f"Early stopping at epoch {epoch + 1}")
break
# Training completed
training_time = time.time() - start_time
print(f"\nTraining completed in {training_time:.2f} seconds")
print(f"Best validation total loss: {self.best_val_loss:.6f}")
# Save final model
final_model_path = self.output_dir / "model_final.pth"
torch.save(self.model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")
# Save training history
history_path = self.output_dir / "training_history.json"
history = {
"g_losses": self.g_losses,
"d_losses": self.d_losses,
"val_g_losses": self.val_g_losses,
"val_d_losses": self.val_d_losses,
"best_val_loss": self.best_val_loss,
"total_epochs": self.current_epoch + 1,
}
with open(history_path, "w") as f:
json.dump(history, f, indent=2)
print(f"Training history saved to {history_path}")
# Close TensorBoard writer
self.writer.close()
def evaluate(self, test_loader: DataLoader) -> Dict:
"""
Evaluate the model on test data.
Args:
test_loader: Test data loader
Returns:
Dictionary with evaluation metrics
"""
self.model.eval()
total_g_loss = 0.0
total_d_loss = 0.0
print("Evaluating model on test set...")
for batch in tqdm(test_loader, desc="Evaluation"):
# Move data to device
yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
with torch.no_grad():
# Generate fake image
fake_google_img = self.model.generator(yandex_img)
# Generator loss - returns tuple of tensors
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
# Discriminator loss - returns tuple of tensors
d_loss_tuple = self.model.discriminator_step(
yandex_img, google_img, fake_google_img
)
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
# Update statistics
total_g_loss += g_loss.item()
total_d_loss += d_loss.item()
avg_g_loss = total_g_loss / len(test_loader)
avg_d_loss = total_d_loss / len(test_loader)
metrics = {
"generator_loss": avg_g_loss,
"discriminator_loss": avg_d_loss,
"total_loss": avg_g_loss + avg_d_loss,
}
print("\nTest Results:")
print(f" Generator Loss: {avg_g_loss:.6f}")
print(f" Discriminator Loss: {avg_d_loss:.6f}")
print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")
return metrics
"""Trainer for GAN model."""
import json
import time
from pathlib import Path
from typing import Any, Dict, Tuple
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
class GANTrainer:
"""Simple GAN trainer."""
def __init__(
self,
model: torch.nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: Dict[str, Any],
):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.device = model.device
# Optimizers
lr = config.get("learning_rate", 2e-4)
beta1 = config.get("beta1", 0.5)
beta2 = config.get("beta2", 0.999)
self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))
self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2))
# Training state
self.current_epoch = 0
self.best_val_loss = float("inf")
self.g_losses = []
self.d_losses = []
self.val_g_losses = []
self.val_d_losses = []
# Output dir
self.output_dir = Path(config.get("output_dir", "runs/gan"))
self.output_dir.mkdir(parents=True, exist_ok=True)
(self.output_dir / "checkpoints").mkdir(exist_ok=True)
# Save config
with open(self.output_dir / "config.json", "w") as f:
json.dump(config, f, indent=2)
def train_epoch(self) -> Tuple[float, float]:
"""Train for one epoch."""
self.model.train()
total_g = total_d = 0.0
num_batches = len(self.train_loader)
pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
for batch in pbar:
yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
# Train D
self.opt_D.zero_grad()
with torch.no_grad():
fake_img = self.model.generator(yandex_img)
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
d_loss.backward()
self.opt_D.step()
# Train G
self.opt_G.zero_grad()
g_loss = self.model.generator_step(yandex_img, google_img)[0]
g_loss.backward()
self.opt_G.step()
total_g += g_loss.item()
total_d += d_loss.item()
pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
avg_g = total_g / num_batches
avg_d = total_d / num_batches
self.g_losses.append(avg_g)
self.d_losses.append(avg_d)
return avg_g, avg_d
@torch.no_grad()
def validate(self) -> Tuple[float, float]:
"""Validate the model."""
self.model.eval()
total_g = total_d = 0.0
for batch in tqdm(self.val_loader, desc="Val"):
yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
fake_img = self.model.generator(yandex_img)
g_loss = self.model.generator_step(yandex_img, google_img)[0]
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
total_g += g_loss.item()
total_d += d_loss.item()
avg_g = total_g / len(self.val_loader)
avg_d = total_d / len(self.val_loader)
self.val_g_losses.append(avg_g)
self.val_d_losses.append(avg_d)
return avg_g, avg_d
def train(self, num_epochs: int):
"""Train the model."""
print(f"Training for {num_epochs} epochs...")
for epoch in range(num_epochs):
self.current_epoch = epoch
# Train & validate
train_g, train_d = self.train_epoch()
val_g, val_d = self.validate()
# Save best checkpoint
val_total = val_g + val_d
if val_total < self.best_val_loss:
self.best_val_loss = val_total
self.save_checkpoint("best")
# Periodic checkpoint
if (epoch + 1) % self.config.get("save_interval", 5) == 0:
self.save_checkpoint(f"epoch_{epoch + 1}")
print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")
# Early stopping
patience = self.config.get("early_stopping_patience", 0)
if patience > 0 and len(self.val_g_losses) > patience:
recent = self.val_g_losses[-patience:]
if all(l >= min(self.val_g_losses[:-patience]) for l in recent):
print(f"Early stopping at epoch {epoch + 1}")
break
# Save final
self.save_checkpoint("final")
print(f"Training finished. Best val loss: {self.best_val_loss:.4f}")
def save_checkpoint(self, name: str):
"""Save model checkpoint."""
path = self.output_dir / "checkpoints" / f"{name}.pth"
torch.save({
"epoch": self.current_epoch,
"generator": self.model.generator.state_dict(),
"discriminator": self.model.discriminator.state_dict(),
"opt_G": self.opt_G.state_dict(),
"opt_D": self.opt_D.state_dict(),
}, path)
def create_trainer(
model: torch.nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: Dict[str, Any],
) -> GANTrainer:
"""Create a trainer instance."""
return GANTrainer(model, train_loader, val_loader, config)
if __name__ == "__main__":
# Quick test
from config import create_config
from dataloader import create_data_loaders
from model import create_gan
config = create_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_gan(use_cuda=False)
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=4,
image_size=tuple(config["image_size"]),
num_workers=0,
)
trainer = create_trainer(model, train_loader, val_loader, config)
# Test one training step (just to verify no errors)
print("Testing one training step...")
try:
g_loss, d_loss = trainer.train_epoch()
print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")
except Exception as e:
print(f"Training step failed: {e}")