ref: simplify and modularize GAN codebase
This commit is contained in:
36
models/GAN/config.py
Normal file
36
models/GAN/config.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Configuration for GAN training."""
|
||||
|
||||
|
||||
def create_config(**overrides):
    """Create the default GAN training configuration.

    Args:
        **overrides: Optional replacements for default entries, e.g.
            ``create_config(data_dir="data/images", batch_size=8)``.
            Only keys that already exist in the defaults are accepted.

    Returns:
        A new dictionary on every call, so callers may mutate it freely.

    Raises:
        KeyError: If an override names a key that is not a known setting.
    """
    config = {
        # Optimizer params
        "learning_rate": 2e-4,
        "beta1": 0.5,
        "beta2": 0.999,
        # Training params
        "batch_size": 32,
        "epochs": 100,
        # GAN params
        "gan_mode": "vanilla",
        "lambda_L1": 100.0,
        # Regularization
        "grad_clip": 1.0,
        # Early stopping
        "early_stopping_patience": 20,
        # Output
        "output_dir": "runs/gan_training",
        # Logging
        "log_interval": 10,
        "save_interval": 5,
        # Data
        # NOTE: machine-specific default; pass data_dir=... to use another location.
        "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        "image_size": [256, 256],
        "train_split": 0.8,
        "num_workers": 0,
    }

    # Reject unknown names so a typo does not silently create a dead setting.
    unknown = sorted(set(overrides) - set(config))
    if unknown:
        raise KeyError(f"Unknown config keys: {', '.join(unknown)}")

    config.update(overrides)
    return config
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual check: show the default settings when the module is run as a script.
    print("Default config:", create_config())
|
||||
162
models/GAN/dataloader.py
Normal file
162
models/GAN/dataloader.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Data loader for Yandex-to-Google image translation."""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class YaGoDataset(Dataset):
    """Dataset loading pairs of Yandex and Google map images.

    Each sample is a dict with ``"google_img"`` and ``"yandex_img"`` tensors
    (converted with ``ToTensor`` and normalized with mean 0.5 / std 0.5 per
    channel) plus the pair's integer ``"idx"`` as a ``torch.long`` tensor.
    """

    def __init__(
        self,
        root_dir: str,
        image_size: Tuple[int, int] = (256, 256),
        augment: bool = False,
    ):
        """
        Args:
            root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png
            image_size: Target image size (height, width)
            augment: Stored on the instance; no augmentation is applied by this class.
        """
        self.root_dir = root_dir
        self.image_size = image_size
        self.augment = augment

        # Discover image pairs
        self.pairs = self._find_pairs()

        # Transform to tensor + normalization
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def _find_pairs(self) -> List[Dict]:
        """Find all matching Google-Yandex image pairs.

        Returns:
            One dict per pair with keys ``"idx"``, ``"google_path"`` and
            ``"yandex_path"``, ordered by sorted Google file name. Google files
            whose prefix is not an integer, or that have no Yandex counterpart,
            are skipped.
        """
        pairs = []
        google_files = sorted(
            name for name in os.listdir(self.root_dir) if name.endswith("_google.png")
        )

        for google_file in google_files:
            prefix = google_file.split("_")[0]
            try:
                idx = int(prefix)
            except ValueError:
                # Not an indexed file (e.g. "foo_google.png"); ignore it.
                continue

            yandex_path = os.path.join(self.root_dir, f"{idx:04d}_yandex.png")
            if os.path.exists(yandex_path):
                pairs.append(
                    {
                        "idx": idx,
                        "google_path": os.path.join(self.root_dir, google_file),
                        "yandex_path": yandex_path,
                    }
                )

        return pairs

    def _load_image(self, path: str) -> torch.Tensor:
        """Read one image as a resized, normalized RGB tensor and close the file."""
        # PIL's resize takes (width, height); image_size is (height, width).
        target_size = (self.image_size[1], self.image_size[0])
        # The context manager releases the file handle right away instead of
        # leaving it open until garbage collection.
        with Image.open(path) as img:
            resized = img.convert("RGB").resize(target_size)
        return self.transform(resized)

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        pair = self.pairs[idx]
        return {
            "google_img": self._load_image(pair["google_path"]),
            "yandex_img": self._load_image(pair["yandex_path"]),
            "idx": torch.tensor(pair["idx"], dtype=torch.long),
        }
|
||||
|
||||
|
||||
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 0,
    image_size: Tuple[int, int] = (256, 256),
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders.

    The dataset is split by a random permutation of its indices, so the split
    differs between runs unless the global torch RNG is seeded beforehand.

    Args:
        root_dir: Directory with image pairs
        batch_size: Batch size
        train_split: Fraction for training (0.0-1.0)
        num_workers: DataLoader workers
        image_size: Target image size

    Returns:
        (train_loader, val_loader)

    Raises:
        ValueError: If no image pairs are found in ``root_dir``.
    """
    from torch.utils.data import Subset

    # Full dataset
    dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)
    dataset_size = len(dataset)
    if dataset_size == 0:
        # Fail here with a clear message rather than later inside the sampler.
        raise ValueError(f"No Google/Yandex image pairs found in {root_dir!r}")

    # Split
    train_size = int(train_split * dataset_size)
    indices = torch.randperm(dataset_size).tolist()
    train_dataset = Subset(dataset, indices[:train_size])
    val_dataset = Subset(dataset, indices[train_size:])

    # Pinned memory only helps host-to-GPU copies; requesting it without CUDA
    # just produces a warning.
    pin_memory = torch.cuda.is_available()

    # DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Quick manual check: build loaders from the default config and inspect one batch.
    from config import create_config

    config = create_config()
    loaders = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=4,
        image_size=tuple(config["image_size"]),
    )
    train_loader, val_loader = loaders

    batch = next(iter(train_loader))
    print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 1,
|
||||
"id": "bb583d80",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -37,7 +37,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 238,
|
||||
"execution_count": 2,
|
||||
"id": "92144cc0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
||||
@@ -1,393 +0,0 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UNetDownBlock(nn.Module):
    """Downsampling block for the U-Net generator.

    A stride-2 4x4 convolution, optionally followed by batch normalization,
    then LeakyReLU(0.2) and, when ``dropout > 0``, 2D dropout.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        normalize: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
        )
        stages = [conv]
        if normalize:
            stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.LeakyReLU(0.2, inplace=True))
        if dropout > 0:
            stages.append(nn.Dropout2d(dropout))
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
|
||||
|
||||
|
||||
class UNetUpBlock(nn.Module):
    """Upsampling block for the U-Net generator.

    A stride-2 4x4 transposed convolution with batch norm, ReLU and optional
    2D dropout; the result is concatenated with the skip tensor on the
    channel axis.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout > 0:
            stages.append(nn.Dropout2d(dropout))
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        # When the shapes differ, pad x (not the skip tensor) up to the skip's
        # height and width, splitting the padding between both sides.
        if x.shape != skip_input.shape:
            pad_h = skip_input.size(2) - x.size(2)
            pad_w = skip_input.size(3) - x.size(3)
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
        return torch.cat([x, skip_input], dim=1)
|
||||
|
||||
|
||||
class GeneratorUNet(nn.Module):
    """U-Net generator for Yandex -> Google image translation.

    Seven stride-2 down blocks and a stride-2 bottleneck convolution are
    mirrored by seven up blocks with skip connections and a final stride-2
    transposed convolution followed by ``Tanh``.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()

        # Downsampling path
        self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
        self.down2 = UNetDownBlock(64, 128)
        self.down3 = UNetDownBlock(128, 256)
        self.down4 = UNetDownBlock(256, 512)
        self.down5 = UNetDownBlock(512, 512)
        self.down6 = UNetDownBlock(512, 512)
        self.down7 = UNetDownBlock(512, 512)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Upsampling path; from up2 on, each input width is the previous
        # block's output plus the concatenated skip channels.
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up4 = UNetUpBlock(1024, 512)
        self.up5 = UNetUpBlock(1024, 256)
        self.up6 = UNetUpBlock(512, 128)
        self.up7 = UNetUpBlock(256, 64)

        # Final layer
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Downsampling: keep every intermediate activation for the skip connections.
        down_blocks = (
            self.down1,
            self.down2,
            self.down3,
            self.down4,
            self.down5,
            self.down6,
            self.down7,
        )
        skips = []
        for down in down_blocks:
            x = down(x)
            skips.append(x)

        # Bottleneck
        u = self.bottleneck(x)

        # Upsampling: pair each up block with the matching skip, deepest first.
        up_blocks = (
            self.up1,
            self.up2,
            self.up3,
            self.up4,
            self.up5,
            self.up6,
            self.up7,
        )
        for up, skip in zip(up_blocks, reversed(skips)):
            u = up(u, skip)

        # Final output
        return self.final(u)
|
||||
|
||||
|
||||
class DiscriminatorPatchGAN(nn.Module):
    """PatchGAN discriminator that scores an image pair as a grid of patches.

    By default the network returns raw logits. ``GANLoss`` in "vanilla" mode
    uses ``BCEWithLogitsLoss``, which applies the sigmoid itself, so a sigmoid
    here would be applied twice.

    Args:
        in_channels: Channels of the two input images combined (3 + 3 by default).
        use_sigmoid: Append a final ``nn.Sigmoid`` (the previous behaviour) for
            callers that need values in [0, 1].
    """

    def __init__(self, in_channels: int = 6, use_sigmoid: bool = False):
        super().__init__()

        def discriminator_block(
            in_filters: int, out_filters: int, normalization: bool = True
        ):
            """Return the layers of one stride-2 discriminator stage."""
            layers = [
                nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
            ]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = [
            *discriminator_block(in_channels, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]
        if use_sigmoid:
            # Appended last, so the indices of parameterized layers do not change.
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
        """Score a pair of images.

        The two images are concatenated along the channel axis and passed
        through the network; the result has one channel with one score per patch.
        """
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
    """GAN loss for the "vanilla", "lsgan" and "wgangp" modes.

    Args:
        gan_mode: "vanilla" (``BCEWithLogitsLoss``), "lsgan" (``MSELoss``) or
            "wgangp" (signed mean of the prediction).
        target_real_label: Target value used for real samples.
        target_fake_label: Target value used for fake samples.

    Raises:
        NotImplementedError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(
        self,
        gan_mode: str = "vanilla",
        target_real_label: float = 1.0,
        target_fake_label: float = 0.0,
    ):
        super().__init__()
        # Buffers, so the labels move with the module when it changes device.
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        self.gan_mode = gan_mode

        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        elif gan_mode == "wgangp":
            self.loss = None
        else:
            raise NotImplementedError(f"GAN mode {gan_mode} not implemented")

    def get_target_tensor(
        self, prediction: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Return the real or fake label expanded to the shape of ``prediction``."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        """Compute the loss for ``prediction`` against the real or fake target.

        Defined as ``forward`` rather than ``__call__`` so that
        ``nn.Module.__call__`` and its hooks keep working; calling
        ``loss(prediction, flag)`` behaves as before.
        """
        if self.gan_mode == "wgangp":
            return -prediction.mean() if target_is_real else prediction.mean()
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)
|
||||
|
||||
|
||||
class ImageGAN(nn.Module):
    """Main GAN class for Yandex -> Google image translation.

    Bundles the U-Net generator, the PatchGAN discriminator, the adversarial
    loss and an L1 reconstruction loss weighted by ``lambda_L1``.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "vanilla",
        lambda_L1: float = 100.0,
        use_cuda: bool = True,
    ):
        super().__init__()

        self.generator = GeneratorUNet(input_channels, output_channels)
        self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1

        self.device = torch.device(
            "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
        )
        self.to(self.device)

    def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
        """Generate a Google-style image from a Yandex image."""
        return self.generator(yandex_image)

    def generator_step(
        self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the generator losses for one batch.

        Returns:
            total_loss: sum of the two losses below
            gan_loss: adversarial loss (the fake pair scored against the "real" target)
            l1_loss: L1 distance to the real image, already multiplied by ``lambda_L1``
        """
        fake_google_image = self.generator(yandex_image)

        # Score the generated pair; the generator wants it judged as real.
        fake_pred = self.discriminator(yandex_image, fake_google_image)
        gan_loss = self.gan_loss(fake_pred, True)

        # The L1 term keeps the output close to the target image.
        l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1

        total_loss = gan_loss + l1_loss
        return total_loss, gan_loss, l1_loss

    def discriminator_step(
        self,
        yandex_image: torch.Tensor,
        real_google_image: torch.Tensor,
        fake_google_image: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the discriminator losses for one batch.

        Returns:
            total_loss: mean of the real and fake losses
            real_loss: loss on the real pair
            fake_loss: loss on the generated pair
        """
        real_pred = self.discriminator(yandex_image, real_google_image)
        real_loss = self.gan_loss(real_pred, True)

        # detach(): gradients from this step must not reach the generator.
        fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
        fake_loss = self.gan_loss(fake_pred, False)

        total_loss = (real_loss + fake_loss) * 0.5
        return total_loss, real_loss, fake_loss

    def to(self, device):
        """Move the whole model to ``device`` and remember it in ``self.device``.

        Delegates to ``nn.Module.to`` so that the loss label buffers move
        together with the generator and discriminator; moving only the two
        networks left the labels behind on the old device.
        """
        super().to(device)
        self.device = torch.device(device)
        return self

    def train_mode(self):
        """Switch generator and discriminator to training mode."""
        self.generator.train()
        self.discriminator.train()

    def eval_mode(self):
        """Switch generator and discriminator to evaluation mode."""
        self.generator.eval()
        self.discriminator.eval()

    def save_checkpoint(self, path: str):
        """Save network weights and any attached optimizer state dicts to ``path``.

        Optimizer state is read from an ``optimizer_state_dict`` attribute on
        each network and stored as ``None`` when that attribute is absent.
        """
        checkpoint = {
            "generator_state_dict": self.generator.state_dict(),
            "discriminator_state_dict": self.discriminator.state_dict(),
            "generator_optimizer_state_dict": getattr(
                self.generator, "optimizer_state_dict", None
            ),
            "discriminator_optimizer_state_dict": getattr(
                self.discriminator, "optimizer_state_dict", None
            ),
        }
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str):
        """Load network weights (and optimizer state, when present) from ``path``."""
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint["generator_state_dict"])
        self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])

        # .get(): tolerate checkpoints that carry weights only.
        generator_opt = checkpoint.get("generator_optimizer_state_dict")
        if generator_opt is not None:
            self.generator.optimizer_state_dict = generator_opt
        discriminator_opt = checkpoint.get("discriminator_optimizer_state_dict")
        if discriminator_opt is not None:
            self.discriminator.optimizer_state_dict = discriminator_opt
|
||||
|
||||
|
||||
def create_image_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "vanilla",
    lambda_L1: float = 100.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """
    Create and return a GAN model for image translation.

    Args:
        input_channels: number of input channels (usually 3 for RGB)
        output_channels: number of output channels (usually 3 for RGB)
        gan_mode: GAN mode ('vanilla', 'lsgan', 'wgangp')
        lambda_L1: weight of the L1 loss
        use_cuda: whether to use CUDA when it is available

    Returns:
        ImageGAN: the GAN model
    """
    settings = {
        "input_channels": input_channels,
        "output_channels": output_channels,
        "gan_mode": gan_mode,
        "lambda_L1": lambda_L1,
        "use_cuda": use_cuda,
    }
    return ImageGAN(**settings)
|
||||
|
||||
|
||||
# Вспомогательные функции для инициализации весов
|
||||
def weights_init_normal(m):
    """Initialize one module in place; meant to be passed to ``nn.Module.apply``.

    Modules whose class name contains "Conv" get weights drawn from N(0, 0.02).
    Modules whose class name contains "BatchNorm" get weights drawn from
    N(1, 0.02) and a zero bias. All other modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        # The bias lives on the module itself; there is no `batch_norm` attribute.
        nn.init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
    """Apply ``weights_init_normal`` to every submodule of both networks."""
    for network in (generator, discriminator):
        network.apply(weights_init_normal)
|
||||
30
models/GAN/main.py
Normal file
30
models/GAN/main.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Main entry point for GAN training."""
|
||||
|
||||
from config import create_config
|
||||
from dataloader import create_data_loaders
|
||||
from model import create_gan
|
||||
from trainer import create_trainer
|
||||
|
||||
|
||||
def main():
    """Run training pipeline.

    Builds the config, model, data loaders and trainer, then trains for
    ``config["epochs"]`` epochs. The function uses only this module's imports,
    so it can also be called after importing the module.
    """
    config = create_config()

    # Create components; GAN settings come from the config so that editing the
    # config actually changes the model.
    model = create_gan(
        gan_mode=config["gan_mode"],
        lambda_L1=config["lambda_L1"],
        use_cuda=False,  # Set to True to use GPU
    )
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=tuple(config["image_size"]),
    )
    trainer = create_trainer(model, train_loader, val_loader, config)

    # Train
    trainer.train(config["epochs"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Imported here, so `torch` becomes a module global only on the script path.
    import torch

    main()
|
||||
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
|
||||
|
||||
Этот пример показывает самый простой способ использования тренера.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
class SimpleMapDataset(Dataset):
    """Toy dataset that yields random tensors in place of real map images."""

    def __init__(self, num_samples=100, image_size=(256, 256)):
        self.num_samples = num_samples
        self.image_size = image_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Placeholder data; real code would load images for `idx` here.
        height, width = self.image_size[0], self.image_size[1]
        # The Yandex tensor is drawn first, so the RNG is consumed in the same order.
        yandex_img = torch.randn(3, height, width)
        google_img = torch.randn(3, height, width)
        return {"yandex_img": yandex_img, "google_img": google_img}
|
||||
|
||||
|
||||
def main():
    """Run the minimal example: train briefly on synthetic data, then sample and save."""
    separator = "=" * 50
    print("Минимальный пример использования GAN trainer")
    print(separator)

    # 1. Configuration (minimal set of parameters)
    config = {
        "learning_rate": 2e-4,
        "batch_size": 4,
        "output_dir": "runs/gan_minimal",
    }

    # 2. Device (CPU or GPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Используемое устройство: {device}")

    # 3. Model
    print("\nСоздание GAN модели...")
    model = create_image_gan(
        input_channels=3,
        output_channels=3,
        gan_mode="vanilla",  # simplest mode
        lambda_L1=100.0,  # standard L1 loss weight
        use_cuda=(device.type == "cuda"),
    )

    # 4. Data loaders
    print("Создание даталоадеров...")
    train_dataset = SimpleMapDataset(num_samples=50)
    val_dataset = SimpleMapDataset(num_samples=10)

    batch_size = config["batch_size"]
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    print(f" Обучающих примеров: {len(train_dataset)}")
    print(f" Валидационных примеров: {len(val_dataset)}")

    # 5. Trainer
    print("\nСоздание тренера...")
    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    # 6. Train for a small number of epochs
    print("\nЗапуск обучения (3 эпохи для примера)...")
    print(separator)

    trainer.train(num_epochs=3)

    # 7. Generate sample translations
    print("\nГенерация примеров преобразования...")
    model.eval()

    # Random test input
    test_yandex = torch.randn(2, 3, 256, 256).to(device)

    with torch.no_grad():
        generated_google = model(test_yandex)

    print(f"Входные изображения: {test_yandex.shape}")
    print(f"Сгенерированные изображения: {generated_google.shape}")
    print(
        f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
    )

    # 8. Save the final model
    print("\nСохранение модели...")
    model_save_path = "gan_model_minimal.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Модель сохранена в: {model_save_path}")

    print("\n" + separator)
    print("Минимальный пример завершен!")
    print("\nДля реального использования:")
    print("1. Замените SimpleMapDataset на ваш реальный датасет")
    print("2. Настройте параметры в config")
    print("3. Увеличьте количество эпох (например, до 100)")
    print("4. Используйте реальные изображения карт")
    print(separator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point for the minimal example.
    main()
|
||||
256
models/GAN/model.py
Normal file
256
models/GAN/model.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""GAN model for image translation Yandex -> Google."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UNetDownBlock(nn.Module):
    """Downsampling block for U-Net.

    Stride-2 4x4 convolution, optional batch norm, LeakyReLU(0.2) and
    optional 2D dropout, applied in that order.
    """

    def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):
        super().__init__()
        # Optional stages are empty lists when disabled, so the final order is fixed.
        norm_stage = [nn.BatchNorm2d(out_channels)] if normalize else []
        dropout_stage = [nn.Dropout2d(dropout)] if dropout > 0 else []
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            *norm_stage,
            nn.LeakyReLU(0.2, inplace=True),
            *dropout_stage,
        )

    def forward(self, x):
        return self.model(x)
|
||||
|
||||
|
||||
class UNetUpBlock(nn.Module):
    """Upsampling block for U-Net.

    Applies a stride-2 4x4 transposed convolution, zero-pads the result to
    the skip tensor's height and width when they differ, then batch norm,
    ReLU and optional dropout, and finally concatenates the skip tensor on
    the channel axis.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # None (not an identity module) when dropout is disabled.
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, x, skip_input):
        x = self.upconv(x)
        # Pad if needed to match skip connection size
        if x.shape != skip_input.shape:
            diff_h = skip_input.size(2) - x.size(2)
            diff_w = skip_input.size(3) - x.size(3)
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = self.norm(x)
        x = self.relu(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return torch.cat([x, skip_input], dim=1)
|
||||
|
||||
|
||||
class GeneratorUNet(nn.Module):
    """U-Net generator for Yandex -> Google translation.

    Seven stride-2 down blocks and a stride-2 bottleneck convolution are
    mirrored by seven up blocks with skip connections and a final stride-2
    transposed convolution followed by ``Tanh``.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()

        # Downsampling
        self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
        self.down2 = UNetDownBlock(64, 128)
        self.down3 = UNetDownBlock(128, 256)
        self.down4 = UNetDownBlock(256, 512)
        self.down5 = UNetDownBlock(512, 512)
        self.down6 = UNetDownBlock(512, 512)
        self.down7 = UNetDownBlock(512, 512)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Upsampling. up1 takes the bottleneck output directly, which has 512
        # channels. From up2 on, the input is the previous block's output
        # concatenated with a skip tensor, hence the doubled widths.
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up4 = UNetUpBlock(1024, 512)
        self.up5 = UNetUpBlock(1024, 256)
        self.up6 = UNetUpBlock(512, 128)
        self.up7 = UNetUpBlock(256, 64)

        # Final
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Down
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)

        # Bottleneck
        u = self.bottleneck(d7)

        # Up with skip connections (deepest skip first)
        u = self.up1(u, d7)
        u = self.up2(u, d6)
        u = self.up3(u, d5)
        u = self.up4(u, d4)
        u = self.up5(u, d3)
        u = self.up6(u, d2)
        u = self.up7(u, d1)

        return self.final(u)
|
||||
|
||||
|
||||
class DiscriminatorPatchGAN(nn.Module):
    """PatchGAN discriminator.

    By default the network returns raw logits. ``GANLoss`` in "vanilla" mode
    uses ``BCEWithLogitsLoss``, which applies the sigmoid itself, so a sigmoid
    here would be applied twice.

    Args:
        in_channels: Channels of the two input images combined.
        use_sigmoid: Append a final ``nn.Sigmoid`` (the previous behaviour) for
            callers that need values in [0, 1].
    """

    def __init__(self, in_channels: int = 6, use_sigmoid: bool = False):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]
        if use_sigmoid:
            # Appended last, so the indices of parameterized layers do not change.
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Score the pair: concatenate on the channel axis and run the network."""
        x = torch.cat([img_A, img_B], dim=1)
        return self.model(x)
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
    """GAN loss supporting the "vanilla", "lsgan" and "wgangp" modes.

    "vanilla" uses ``BCEWithLogitsLoss``, "lsgan" uses ``MSELoss`` and
    "wgangp" returns the signed mean of the prediction. Any other mode raises
    ``ValueError`` at construction time.
    """

    def __init__(self, gan_mode: str = "vanilla", target_real: float = 1.0, target_fake: float = 0.0):
        super().__init__()
        self.gan_mode = gan_mode
        # Buffers, so the labels move with the module when it changes device.
        self.register_buffer("real_label", torch.tensor(target_real))
        self.register_buffer("fake_label", torch.tensor(target_fake))

        criteria = {"vanilla": nn.BCEWithLogitsLoss, "lsgan": nn.MSELoss}
        if gan_mode in criteria:
            self.loss_fn = criteria[gan_mode]()
        elif gan_mode == "wgangp":
            self.loss_fn = None
        else:
            raise ValueError(f"Unknown GAN mode: {gan_mode}")

    def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        if self.gan_mode == "wgangp":
            mean_score = prediction.mean()
            return -mean_score if target_is_real else mean_score
        if self.gan_mode in ["vanilla", "lsgan"]:
            label = self.real_label if target_is_real else self.fake_label
            return self.loss_fn(prediction, label.expand_as(prediction))
|
||||
|
||||
|
||||
class ImageGAN(nn.Module):
    """Complete GAN model for image translation.

    Holds the generator, the discriminator, the adversarial loss and an L1
    loss weighted by ``lambda_L1``. The module is moved to CUDA at construction
    when ``use_cuda`` is true and CUDA is available, otherwise to the CPU.
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: int = 3,
        gan_mode: str = "vanilla",
        lambda_L1: float = 100.0,
        use_cuda: bool = True,
    ):
        super().__init__()
        self.generator = GeneratorUNet(input_channels, output_channels)
        self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
        self.gan_loss = GANLoss(gan_mode)
        self.l1_loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1

        wants_gpu = use_cuda and torch.cuda.is_available()
        self.device = torch.device("cuda" if wants_gpu else "cpu")
        self.to(self.device)

    def forward(self, yandex_image):
        """Generate Google image from Yandex."""
        return self.generator(yandex_image)

    def generator_step(self, yandex_img, real_google_img):
        """Compute generator losses.

        Returns:
            ``(total, adversarial, l1)`` where ``l1`` is already multiplied by
            ``lambda_L1`` and ``total`` is the sum of the other two.
        """
        generated = self.generator(yandex_img)
        # The generator wants its output judged as real.
        adversarial = self.gan_loss(self.discriminator(yandex_img, generated), True)
        reconstruction = self.l1_loss(generated, real_google_img) * self.lambda_L1
        return adversarial + reconstruction, adversarial, reconstruction

    def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
        """Compute discriminator losses.

        Returns:
            ``(total, real_loss, fake_loss)`` where ``total`` is the mean of the
            two partial losses.
        """
        loss_on_real = self.gan_loss(self.discriminator(yandex_img, real_google_img), True)
        # detach(): gradients from this step must not reach the generator.
        detached_fake = fake_google_img.detach()
        loss_on_fake = self.gan_loss(self.discriminator(yandex_img, detached_fake), False)
        return (loss_on_real + loss_on_fake) * 0.5, loss_on_real, loss_on_fake
|
||||
|
||||
|
||||
def create_gan(
    input_channels: int = 3,
    output_channels: int = 3,
    gan_mode: str = "vanilla",
    lambda_L1: float = 100.0,
    use_cuda: bool = True,
) -> ImageGAN:
    """Factory for ``ImageGAN``; every argument is forwarded unchanged.

    Args:
        input_channels: Channel count of the source image.
        output_channels: Channel count of the generated image.
        gan_mode: Adversarial loss variant.
        lambda_L1: Weight of the L1 term in the generator loss.
        use_cuda: Prefer CUDA when it is available.

    Returns:
        A freshly constructed ``ImageGAN``.
    """
    settings = {
        "input_channels": input_channels,
        "output_channels": output_channels,
        "gan_mode": gan_mode,
        "lambda_L1": lambda_L1,
        "use_cuda": use_cuda,
    }
    return ImageGAN(**settings)
|
||||
|
||||
|
||||
def initialize_weights(model: nn.Module):
    """Initialize convolution and batch-norm parameters of ``model`` in place.

    Convolution weights (regular and transposed) are drawn from N(0, 0.02).
    Batch-norm scales are drawn from N(1, 0.02) and batch-norm shifts are
    set to 0.

    Args:
        model: Module whose sub-modules are visited recursively.
    """
    for m in model.modules():
        # Transposed convolutions were previously skipped and kept the
        # framework default initialization.
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both parameters are None; touching `.data`
            # would raise AttributeError.
            if m.weight is not None:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Quick CPU smoke test: build the model, run one forward pass and
    # report parameter counts.
    model = create_gan(use_cuda=False)
    print(f"Model created on {model.device}")

    # Forward pass only; no_grad avoids building an autograd graph that is
    # never used.
    test_input = torch.randn(2, 3, 256, 256).to(model.device)
    with torch.no_grad():
        output = model(test_input)
    print(f"Output shape: {output.shape}")

    # Count parameters
    gen_params = sum(p.numel() for p in model.generator.parameters())
    disc_params = sum(p.numel() for p in model.discriminator.parameters())
    print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")
|
||||
@@ -1,349 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Добавляем путь к модулю
|
||||
sys.path.append(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
|
||||
from gan import (
|
||||
DiscriminatorPatchGAN,
|
||||
GeneratorUNet,
|
||||
ImageGAN,
|
||||
create_image_gan,
|
||||
initialize_gan_weights,
|
||||
)
|
||||
|
||||
|
||||
def test_generator():
    """Smoke-test the generator on a random 700x700 batch and return it."""
    rule = "=" * 60
    print(rule)
    print("Тестирование генератора...")
    print(rule)

    generator = GeneratorUNet(in_channels=3, out_channels=3)

    def _init_conv(module):
        # N(0, 0.02) for regular and transposed convolution weights only.
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(module.weight.data, 0.0, 0.02)

    generator.apply(_init_conv)

    # Random stand-in for a batch of Yandex images.
    batch_size = 2
    height, width = 700, 700
    expected_shape = (batch_size, 3, height, width)
    yandex_image = torch.randn(batch_size, 3, height, width)

    param_count = sum(p.numel() for p in generator.parameters())
    print(f"Размер входного изображения: {yandex_image.shape}")
    print(f"Количество параметров генератора: {param_count:,}")

    with torch.no_grad():
        generated_image = generator(yandex_image)

    print(f"Размер сгенерированного изображения: {generated_image.shape}")
    print(
        f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]"
    )

    # The generator must preserve the spatial size of its input.
    assert generated_image.shape == expected_shape, (
        f"Ожидался размер {expected_shape}, получен {generated_image.shape}"
    )

    print("✓ Генератор работает корректно!")
    return generator
|
||||
|
||||
|
||||
def test_discriminator():
    """Smoke-test the discriminator on random 700x700 image pairs and return it."""
    rule = "=" * 60
    print("\n" + rule)
    print("Тестирование дискриминатора...")
    print(rule)

    # 6 input channels: 3 (Yandex) + 3 (Google).
    discriminator = DiscriminatorPatchGAN(in_channels=6)

    def _init_conv(module):
        # N(0, 0.02) for convolution weights only.
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)

    discriminator.apply(_init_conv)

    batch_size = 2
    height, width = 700, 700
    yandex_image = torch.randn(batch_size, 3, height, width)
    google_image = torch.randn(batch_size, 3, height, width)

    param_count = sum(p.numel() for p in discriminator.parameters())
    print(f"Размер Yandex изображения: {yandex_image.shape}")
    print(f"Размер Google изображения: {google_image.shape}")
    print(f"Количество параметров дискриминатора: {param_count:,}")

    with torch.no_grad():
        prediction = discriminator(yandex_image, google_image)

    print(f"Размер выхода дискриминатора: {prediction.shape}")
    print(
        f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]"
    )

    # The output is a per-patch score map; the original comment states 43x43
    # for a 700x700 input after 4 downsampling blocks.
    expected_shape = (batch_size, 1, 43, 43)
    assert prediction.shape == expected_shape, (
        f"Ожидался размер {expected_shape}, получен {prediction.shape}"
    )

    out_h, out_w = prediction.shape[2], prediction.shape[3]
    print(f"✓ Дискриминатор работает корректно! Выходной размер: {out_h}x{out_w}")
    return discriminator
|
||||
|
||||
|
||||
def test_gan_model():
    """Exercise the full ImageGAN: forward pass, both loss steps, mode switching."""
    rule = "=" * 60
    print("\n" + rule)
    print("Тестирование полной GAN модели...")
    print(rule)

    gan = ImageGAN(
        input_channels=3,
        output_channels=3,
        gan_mode="vanilla",
        lambda_L1=100.0,
        use_cuda=False,  # run on CPU to keep the test simple
    )

    def _param_count(module):
        return sum(p.numel() for p in module.parameters())

    print(f"Устройство модели: {gan.device}")
    print(f"Количество параметров генератора: {_param_count(gan.generator):,}")
    print(f"Количество параметров дискриминатора: {_param_count(gan.discriminator):,}")
    print(f"Общее количество параметров: {_param_count(gan):,}")

    # Random stand-ins for a paired batch.
    batch_size = 2
    height, width = 700, 700
    yandex_image = torch.randn(batch_size, 3, height, width)
    real_google_image = torch.randn(batch_size, 3, height, width)

    print("\nТестирование прямого прохода...")
    with torch.no_grad():
        generated_image = gan(yandex_image)
    print(f"Размер сгенерированного изображения: {generated_image.shape}")

    print("\nТестирование шага генератора...")
    gan.train_mode()
    total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image)
    print(f"Общие потери генератора: {total_loss.item():.6f}")
    print(f"Потери GAN: {gan_loss.item():.6f}")
    print(f"Потери L1: {l1_loss.item():.6f}")

    print("\nТестирование шага дискриминатора...")
    # Produce a fake image for the discriminator without tracking gradients.
    with torch.no_grad():
        fake_google_image = gan.generator(yandex_image)

    total_d_loss, real_loss, fake_loss = gan.discriminator_step(
        yandex_image, real_google_image, fake_google_image
    )
    print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}")
    print(f"Потери на реальных изображениях: {real_loss.item():.6f}")
    print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}")

    print("\nТестирование режимов обучения/оценки...")
    gan.eval_mode()
    print(f"Генератор в режиме eval: {not gan.generator.training}")
    print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}")

    gan.train_mode()
    print(f"Генератор в режиме train: {gan.generator.training}")
    print(f"Дискриминатор в режиме train: {gan.discriminator.training}")

    print("\n✓ Полная GAN модель работает корректно!")
    return gan
|
||||
|
||||
|
||||
def test_factory_function():
    """Build a GAN through the factory for each tested mode and run a forward pass."""
    rule = "=" * 60
    print("\n" + rule)
    print("Тестирование фабричной функции...")
    print(rule)

    tested_modes = ["vanilla", "lsgan"]
    for gan_mode in tested_modes:
        print(f"\nСоздание GAN в режиме '{gan_mode}'...")
        gan = create_image_gan(
            input_channels=3,
            output_channels=3,
            gan_mode=gan_mode,
            lambda_L1=100.0,
            use_cuda=False,
        )

        print(f" Режим GAN: {gan.gan_loss.gan_mode}")
        print(f" Вес L1 потерь: {gan.lambda_L1}")
        print(f" Устройство: {gan.device}")

        # Quick forward-pass check on a single random image.
        batch_size = 1
        yandex_image = torch.randn(batch_size, 3, 700, 700)
        with torch.no_grad():
            generated = gan(yandex_image)

        print(f" Размер выхода: {generated.shape}")
        print(f" ✓ GAN в режиме '{gan_mode}' создан успешно")

    print("\n✓ Фабричная функция работает корректно!")
|
||||
|
||||
|
||||
def test_weights_initialization():
    """Check that convolution weights are centred near zero after initialization."""
    rule = "=" * 60
    print("\n" + rule)
    print("Тестирование инициализации весов...")
    print(rule)

    generator = GeneratorUNet(3, 3)
    discriminator = DiscriminatorPatchGAN(6)

    initialize_gan_weights(generator, discriminator)

    def check_weights_mean(model, model_name):
        """Assert that the average of per-layer conv weight means is close to 0."""
        # Select layers by module type. The previous version matched on the
        # parameter *name* containing "conv" (its second test,
        # `"Conv" in str(param.__class__)`, can never be true for an
        # nn.Parameter), so it could collect nothing and pass vacuously.
        conv_weights = [
            module.weight.data.mean().item()
            for module in model.modules()
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d))
        ]

        if conv_weights:
            avg_mean = sum(conv_weights) / len(conv_weights)
            print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}")
            # Weights are expected to be initialized around 0.
            assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0"

    check_weights_mean(generator, "генераторе")
    check_weights_mean(discriminator, "дискриминаторе")

    print("✓ Инициализация весов работает корректно!")
|
||||
|
||||
|
||||
def test_memory_usage():
    """Report process memory before/after building three GANs and after cleanup."""
    rule = "=" * 60
    print("\n" + rule)
    print("Тестирование использования памяти...")
    print(rule)

    import os

    import psutil

    def _rss_mb(proc):
        # Resident set size of the process, converted from bytes to MB.
        return proc.memory_info().rss / 1024 / 1024

    process = psutil.Process(os.getpid())
    memory_before = _rss_mb(process)
    print(f"Память до создания моделей: {memory_before:.2f} MB")

    # NOTE(review): the source's indentation was lost; the test pass is
    # assumed to run inside the loop, once per model -- confirm.
    models = []
    for _round in range(3):
        gan = create_image_gan(use_cuda=False)
        models.append(gan)

        batch_size = 1
        yandex_image = torch.randn(batch_size, 3, 700, 700)
        real_google_image = torch.randn(batch_size, 3, 700, 700)

        with torch.no_grad():
            scratch = gan(yandex_image)
            scratch = gan.generator_step(yandex_image, real_google_image)

    memory_after = _rss_mb(process)
    memory_used = memory_after - memory_before
    print(f"Память после создания моделей: {memory_after:.2f} MB")
    print(f"Использовано памяти: {memory_used:.2f} MB")

    # Drop the model list and force a collection before the final reading.
    del models
    import gc

    gc.collect()

    memory_final = _rss_mb(process)
    print(f"Память после очистки: {memory_final:.2f} MB")

    print("✓ Тестирование памяти завершено!")
|
||||
|
||||
|
||||
def main():
    """Run every test in order; return 0 on success, 1 if any test raised."""
    rule = "=" * 60
    print("Начало тестирования GAN архитектуры для преобразования Yandex → Google")
    print("Размер изображения: 700x700 пикселей")
    print(rule)

    try:
        for run_test in (
            test_generator,
            test_discriminator,
            test_gan_model,
            test_factory_function,
            test_weights_initialization,
            test_memory_usage,
        ):
            run_test()

        print("\n" + rule)
        print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉")
        print(rule)

        # Closing summary of the tested architecture.
        for line in (
            "\nАрхитектура GAN готова к использованию для преобразования",
            "изображений из стиля Yandex в стиль Google.",
            "\nОсновные характеристики:",
            " • Генератор: U-Net архитектура",
            " • Дискриминатор: PatchGAN (43x43 патчей)",
            " • Размер входных/выходных изображений: 700x700",
            " • Поддержка режимов: vanilla, lsgan",
            " • L1 регуляризация для сохранения структуры",
        ):
            print(line)

    except Exception as e:
        print(f"\n❌ Ошибка при тестировании: {e}")
        import traceback

        traceback.print_exc()
        return 1

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate the test result as the process exit status.
    sys.exit(main())
|
||||
@@ -1,342 +0,0 @@
|
||||
"""
|
||||
Тестовый скрипт для проверки GAN trainer.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
class SimpleDataset(Dataset):
    """Synthetic paired dataset that yields random tensors, for testing only."""

    def __init__(self, num_samples=100, image_size=(256, 256)):
        """Store the reported length and the (height, width) of generated images."""
        self.image_size = image_size
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a fresh random image pair; ``idx`` is ignored."""
        height, width = self.image_size[0], self.image_size[1]
        sample = {}
        # Generation order matters for RNG reproducibility: Yandex first.
        for key in ("yandex_img", "google_img"):
            sample[key] = torch.randn(3, height, width)
        return sample
|
||||
|
||||
|
||||
def test_gan_model():
    """Build a CPU GAN, run one forward pass and check the output range."""
    print("Тестирование GAN модели...")

    model = create_image_gan(
        input_channels=3,
        output_channels=3,
        gan_mode="vanilla",
        lambda_L1=100.0,
        use_cuda=False,  # CPU is enough for testing
    )

    batch_size = 2
    image_size = (256, 256)
    yandex_input = torch.randn(batch_size, 3, *image_size)

    with torch.no_grad():
        output = model(yandex_input)

    print(f"Входной размер: {yandex_input.shape}")
    print(f"Выходной размер: {output.shape}")
    print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")

    # The original comment attributes the [-1, 1] bound to a Tanh output layer.
    lowest, highest = output.min(), output.max()
    assert lowest >= -1.0 and highest <= 1.0, "Выход не в диапазоне [-1, 1]"
    print("✓ Forward pass работает корректно")

    return model
|
||||
|
||||
|
||||
def test_generator_step():
    """Run one generator step on random data and check the loss is positive."""
    print("\nТестирование шага генератора...")

    model = create_image_gan(use_cuda=False)
    model.train_mode()

    batch_size = 2
    yandex_img = torch.randn(batch_size, 3, 256, 256)
    google_img = torch.randn(batch_size, 3, 256, 256)

    losses = model.generator_step(yandex_img, google_img)
    total_loss, gan_loss, l1_loss = losses

    print(f"Total loss: {total_loss.item():.6f}")
    print(f"GAN loss: {gan_loss.item():.6f}")
    print(f"L1 loss: {l1_loss.item():.6f}")

    assert total_loss.item() > 0, "Потери должны быть положительными"
    print("✓ Шаг генератора работает корректно")
|
||||
|
||||
|
||||
def test_discriminator_step():
    """Run one discriminator step on random data and check the loss is positive."""
    print("\nТестирование шага дискриминатора...")

    model = create_image_gan(use_cuda=False)
    model.train_mode()

    batch_size = 2
    yandex_img = torch.randn(batch_size, 3, 256, 256)
    google_img = torch.randn(batch_size, 3, 256, 256)

    # Produce the fake image without tracking gradients.
    with torch.no_grad():
        fake_google_img = model.generator(yandex_img)

    losses = model.discriminator_step(yandex_img, google_img, fake_google_img)
    total_loss, real_loss, fake_loss = losses

    print(f"Total loss: {total_loss.item():.6f}")
    print(f"Real loss: {real_loss.item():.6f}")
    print(f"Fake loss: {fake_loss.item():.6f}")

    assert total_loss.item() > 0, "Потери должны быть положительными"
    print("✓ Шаг дискриминатора работает корректно")
|
||||
|
||||
|
||||
def test_trainer_initialization():
    """Construct a GANTrainer on synthetic data and check its output directory.

    Returns:
        ``(trainer, train_loader, val_loader)``.
    """
    print("\nТестирование инициализации тренера...")

    model = create_image_gan(use_cuda=False)

    train_dataset = SimpleDataset(num_samples=50)
    val_dataset = SimpleDataset(num_samples=10)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

    config = dict(
        learning_rate=2e-4,
        beta1=0.5,
        beta2=0.999,
        output_dir="test_runs/gan",
        early_stopping_patience=10,
    )

    device = torch.device("cpu")
    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    generator_optimizer = type(trainer.optimizer_G).__name__
    discriminator_optimizer = type(trainer.optimizer_D).__name__
    print("Тренер создан успешно")
    print(f"Оптимизатор генератора: {generator_optimizer}")
    print(f"Оптимизатор дискриминатора: {discriminator_optimizer}")
    print(f"Выходная директория: {trainer.output_dir}")

    assert trainer.output_dir.exists(), "Выходная директория не создана"
    print("✓ Тренер инициализирован корректно")

    return trainer, train_loader, val_loader
|
||||
|
||||
|
||||
def test_train_epoch():
    """Run a single training epoch on synthetic data and check both average losses."""
    print("\nТестирование одной эпохи обучения...")

    model = create_image_gan(use_cuda=False)
    train_dataset = SimpleDataset(num_samples=20)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)

    config = dict(
        learning_rate=2e-4,
        beta1=0.5,
        beta2=0.999,
        output_dir="test_runs/gan",
    )

    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=torch.device("cpu"),
        config=config,
    )

    epoch_losses = trainer.train_epoch()
    avg_g_loss, avg_d_loss = epoch_losses

    print("Средние потери за эпоху:")
    print(f" Генератор: {avg_g_loss:.6f}")
    print(f" Дискриминатор: {avg_d_loss:.6f}")

    assert avg_g_loss > 0, "Потери генератора должны быть положительными"
    assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
    print("✓ Эпоха обучения завершена успешно")
|
||||
|
||||
|
||||
def test_validation():
    """Run one validation pass on synthetic data and check both losses."""
    print("\nТестирование валидации...")

    model = create_image_gan(use_cuda=False)
    train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
    val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)

    config = dict(
        learning_rate=2e-4,
        beta1=0.5,
        beta2=0.999,
        output_dir="test_runs/gan",
    )

    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=torch.device("cpu"),
        config=config,
    )

    validation_losses = trainer.validate()
    val_g_loss, val_d_loss = validation_losses

    print("Потери на валидации:")
    print(f" Генератор: {val_g_loss:.6f}")
    print(f" Дискриминатор: {val_d_loss:.6f}")

    assert val_g_loss > 0, "Потери генератора должны быть положительными"
    assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
    print("✓ Валидация завершена успешно")
|
||||
|
||||
|
||||
def test_checkpoint_saving():
    """Save a checkpoint, verify a .pth file exists, then load it into a new trainer."""
    print("\nТестирование сохранения чекпоинтов...")

    model = create_image_gan(use_cuda=False)
    train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
    val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)

    config = dict(
        learning_rate=2e-4,
        beta1=0.5,
        beta2=0.999,
        output_dir="test_runs/gan_checkpoint",
    )

    device = torch.device("cpu")
    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    trainer.save_checkpoint(is_best=True)

    # Collect whatever .pth files the trainer wrote into its output directory.
    checkpoint_files = list(trainer.output_dir.glob("*.pth"))
    saved_count = len(checkpoint_files)
    print(f"Создано файлов чекпоинтов: {saved_count}")
    for file in checkpoint_files:
        print(f" - {file.name}")

    assert saved_count > 0, "Файлы чекпоинтов не созданы"
    print("✓ Чекпоинты сохранены успешно")

    # Round-trip: load the first checkpoint into a brand-new model and trainer.
    checkpoint_path = checkpoint_files[0]
    print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")

    new_model = create_image_gan(use_cuda=False)
    new_trainer = GANTrainer(
        model=new_model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )
    new_trainer.load_checkpoint(str(checkpoint_path))

    print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
    print("✓ Чекпоинт загружен успешно")
|
||||
|
||||
|
||||
def main():
    """Run every trainer test in order; return 0 on success, 1 if any test raised."""
    rule = "=" * 60
    print(rule)
    print("Начало тестирования GAN trainer")
    print(rule)

    try:
        for run_test in (
            test_gan_model,
            test_generator_step,
            test_discriminator_step,
            test_trainer_initialization,
            test_train_epoch,
            test_validation,
            test_checkpoint_saving,
        ):
            run_test()

        print("\n" + rule)
        print("Все тесты пройдены успешно! ✓")
        print(rule)

    except Exception as e:
        print(f"\nОшибка при тестировании: {e}")
        import traceback

        traceback.print_exc()
        return 1

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import shutil

    test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]

    def _remove_test_dirs():
        """Delete the scratch output directories if they exist."""
        for dir_path in test_dirs:
            if Path(dir_path).exists():
                shutil.rmtree(dir_path)

    # Start from a clean slate.
    _remove_test_dirs()

    try:
        exit_code = main()
    finally:
        # Clean up even when main() is interrupted (e.g. Ctrl+C), which its
        # own `except Exception` does not catch.
        _remove_test_dirs()

    # sys.exit instead of the bare exit() helper, which is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(exit_code)
|
||||
@@ -1,347 +0,0 @@
|
||||
"""
|
||||
Пример обучения GAN модели для преобразования Yandex → Google карт.
|
||||
|
||||
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
def create_simple_config():
    """Return a fresh dictionary with the basic training configuration."""
    return {
        # Optimizer parameters
        "learning_rate": 2e-4,
        "beta1": 0.5,
        "beta2": 0.999,
        # Training parameters
        "batch_size": 4,
        "epochs": 100,
        # GAN parameters
        "gan_mode": "vanilla",  # "vanilla", "lsgan", or "wgangp"
        "lambda_L1": 100.0,  # weight of the L1 loss
        # Regularization
        "grad_clip": 1.0,
        # Early stopping
        "early_stopping_patience": 20,
        # Output
        "output_dir": "runs/gan_training",
        # Logging
        "log_interval": 10,  # log every N batches
        "save_interval": 5,  # save a checkpoint every N epochs
    }
|
||||
|
||||
|
||||
def create_advanced_config():
    """Return a fresh dictionary with the extended training configuration."""
    # Data augmentation settings, nested under the "augmentation" key.
    augmentation = {
        "random_crop": True,
        "crop_size": 256,
        "random_flip": True,
        "color_jitter": True,
        "brightness": 0.2,
        "contrast": 0.2,
        "saturation": 0.2,
        "hue": 0.1,
    }

    return {
        # Optimizer parameters
        "learning_rate": 2e-4,
        "beta1": 0.5,
        "beta2": 0.999,
        # Learning-rate scheduler
        "use_scheduler": True,
        "scheduler_type": "linear",  # "linear", "cosine", or "plateau"
        "scheduler_start_epoch": 50,
        "scheduler_end_epoch": 100,
        # Training parameters
        "batch_size": 8,
        "epochs": 200,
        # GAN parameters
        "gan_mode": "lsgan",  # the original author notes LSGAN is usually more stable
        "lambda_L1": 100.0,
        "augmentation": augmentation,
        # Regularization
        "grad_clip": 1.0,
        "weight_decay": 1e-4,
        # Early stopping
        "early_stopping_patience": 30,
        "early_stopping_min_delta": 1e-4,
        # Output
        "output_dir": "runs/gan_advanced",
        # Logging
        "log_interval": 20,
        "save_interval": 10,
        "save_best_only": True,  # keep only the best model
        # Visualization
        "visualize_samples": True,
        "num_visualize": 4,
        "visualize_interval": 5,  # visualize every N epochs
    }
|
||||
|
||||
|
||||
def print_config_summary(config):
    """Print a human-readable summary of a training configuration.

    Missing keys fall back to the defaults listed below; optional sections
    are printed only when their keys are present.
    """
    rule = "=" * 60
    lookup = config.get

    print(rule)
    print("Конфигурация обучения GAN")
    print(rule)

    print("\nПараметры модели:")
    for label, key, fallback in (
        ("Режим GAN", "gan_mode", "vanilla"),
        ("Вес L1 потерь", "lambda_L1", 100.0),
    ):
        print(f" {label}: {lookup(key, fallback)}")

    print("\nПараметры обучения:")
    for label, key, fallback in (
        ("Learning rate", "learning_rate", 2e-4),
        ("Batch size", "batch_size", 4),
        ("Эпох", "epochs", 100),
        ("Beta1", "beta1", 0.5),
        ("Beta2", "beta2", 0.999),
    ):
        print(f" {label}: {lookup(key, fallback)}")

    if lookup("use_scheduler", False):
        print(f" Планировщик LR: {lookup('scheduler_type', 'linear')}")

    print("\nРегуляризация:")
    print(f" Gradient clipping: {lookup('grad_clip', 1.0)}")
    if "weight_decay" in config:
        print(f" Weight decay: {config['weight_decay']}")

    print("\nРанняя остановка:")
    # NOTE(review): the source's indentation was lost; the min-delta check is
    # assumed to be nested under the patience check -- confirm.
    if lookup("early_stopping_patience", 0) > 0:
        print(f" Patience: {config['early_stopping_patience']} эпох")
        if "early_stopping_min_delta" in config:
            print(f" Min delta: {config['early_stopping_min_delta']}")

    print("\nВыходные данные:")
    print(f" Директория: {lookup('output_dir', 'runs/gan')}")
    print(f" Интервал сохранения: {lookup('save_interval', 5)} эпох")

    print("\nЛогирование:")
    print(f" Интервал логирования: {lookup('log_interval', 10)} батчей")

    print(rule)
|
||||
|
||||
|
||||
def setup_training():
    """Build the config, model, placeholder data loaders and trainer.

    Returns:
        ``(trainer, config)``.
    """
    print("Настройка обучения GAN...")

    # Flip to True to use the extended configuration.
    use_advanced = False
    config = create_advanced_config() if use_advanced else create_simple_config()

    print_config_summary(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nИспользуемое устройство: {device}")

    if device.type == "cuda":
        gpu_name = torch.cuda.get_device_name(0)
        print(f" GPU: {gpu_name}")
        print(
            f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
        )

    print("\nСоздание модели...")
    model = create_image_gan(
        input_channels=3,
        output_channels=3,
        gan_mode=config.get("gan_mode", "vanilla"),
        lambda_L1=config.get("lambda_L1", 100.0),
        use_cuda=(device.type == "cuda"),
    )

    print("\nСоздание даталоадеров...")
    # REPLACE THIS WITH YOUR REAL DATA.
    # Example:
    # from your_dataset_module import create_data_loaders
    # train_loader, val_loader = create_data_loaders(
    #     data_dir="path/to/your/data",
    #     batch_size=config["batch_size"],
    #     image_size=(256, 256),
    #     augment=config.get("augmentation", None),
    # )

    # Placeholder loaders for the example only.
    # IMPORTANT: replace them with real data!
    print(" ВНИМАНИЕ: Используются фиктивные данные!")
    print(" Замените на реальные даталоадеры!")

    # numpy is not used below; the import is kept from the original.
    import numpy as np
    from torch.utils.data import Dataset

    class DummyDataset(Dataset):
        """Yields random 3x256x256 image pairs."""

        def __init__(self, num_samples=100):
            self.num_samples = num_samples

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            # Placeholder data; Yandex is generated before Google.
            sample = {}
            for key in ("yandex_img", "google_img"):
                sample[key] = torch.randn(3, 256, 256)
            return sample

    train_dataset = DummyDataset(num_samples=100)
    val_dataset = DummyDataset(num_samples=20)

    shared_loader_args = {
        "batch_size": config.get("batch_size", 4),
        "num_workers": 0,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

    print(f" Размер обучающего набора: {len(train_dataset)}")
    print(f" Размер валидационного набора: {len(val_dataset)}")
    print(f" Батчей в эпохе: {len(train_loader)}")

    print("\nСоздание тренера...")
    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    return trainer, config
|
||||
|
||||
|
||||
def train_model(trainer, config):
    """Run the training loop and try to keep a checkpoint if it is cut short.

    Args:
        trainer: Object exposing ``train(num_epochs=...)`` and
            ``save_checkpoint(is_best=...)``.
        config: Mapping of settings; only ``"epochs"`` is read (default 100).

    Neither a ``KeyboardInterrupt`` nor an ``Exception`` raised by
    ``trainer.train`` is re-raised: both are reported on stdout and a
    checkpoint save is attempted afterwards.
    """
    import traceback

    print("\n" + "=" * 60)
    print("Начало обучения")
    print("=" * 60)

    epochs = config.get("epochs", 100)

    try:
        trainer.train(num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n\nОбучение прервано пользователем.")
        print("Сохранение текущего состояния...")
        trainer.save_checkpoint(is_best=False)
    except Exception as e:
        print(f"\n\nОшибка при обучении: {e}")
        traceback.print_exc()

        # Best-effort rescue checkpoint after a failure.
        try:
            trainer.save_checkpoint(is_best=False)
            print("Текущее состояние сохранено.")
        except Exception:
            # Only ordinary errors are absorbed here: a bare ``except`` would
            # also swallow Ctrl+C / SystemExit raised while saving.
            print("Не удалось сохранить состояние.")
            traceback.print_exc()
    else:
        print("\n" + "=" * 60)
        print("Обучение завершено успешно!")
        print("=" * 60)
|
||||
|
||||
|
||||
def evaluate_model(trainer, test_loader=None):
    """Evaluate a trained model and print the resulting metrics.

    Args:
        trainer: Object exposing ``evaluate(loader)`` and a ``val_loader``
            attribute.
        test_loader: Loader to evaluate on. When ``None`` the trainer's
            validation loader is used instead.

    Returns:
        The mapping returned by ``trainer.evaluate``; each value is printed
        with six decimal places.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Оценка модели")
    print(banner)

    loader = test_loader
    if loader is None:
        # Fall back to the validation split when no test split was supplied.
        print("Тестовый даталоадер не предоставлен.")
        print("Используется валидационный даталоадер для оценки.")
        loader = trainer.val_loader

    metrics = trainer.evaluate(loader)

    print("\nМетрики оценки:")
    for metric_name, metric_value in metrics.items():
        print(f" {metric_name}: {metric_value:.6f}")

    return metrics
|
||||
|
||||
|
||||
def generate_examples(model, device, num_examples=4):
    """Run the model on random inputs and report the tensor shapes.

    Args:
        model: Callable module; it is switched to eval mode and called with a
            single tensor argument.
        device: Device the random input batch is moved to.
        num_examples: Number of random samples in the batch.

    Returns:
        Tuple ``(yandex_input, google_output)`` where the input has shape
        ``(num_examples, 3, 256, 256)`` and the output is whatever the model
        returned for it.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("Генерация примеров")
    print(banner)

    model.eval()

    # Placeholder input: random noise instead of real images.
    noise = torch.randn(num_examples, 3, 256, 256)
    yandex_input = noise.to(device)

    with torch.no_grad():
        google_output = model(yandex_input)

    for line in (
        f"Сгенерировано {num_examples} примеров",
        f"Размер входных данных: {yandex_input.shape}",
        f"Размер выходных данных: {google_output.shape}",
    ):
        print(line)

    # Nothing is written to disk here; the messages below only remind the
    # user to save the tensors as images in real code.
    print("\nПримеры сгенерированы.")
    print("В реальном коде сохраняйте их как изображения для визуализации.")

    return yandex_input, google_output
|
||||
|
||||
|
||||
def main():
    """Entry point: build the trainer, run training, and print follow-up hints."""
    rule = "=" * 60

    print(rule)
    print("Пример обучения GAN для преобразования Yandex → Google")
    print(rule)

    # Setup
    trainer, config = setup_training()

    # Training
    train_model(trainer, config)

    # Evaluation (requires real test data) and example generation are
    # intentionally left disabled:
    # evaluate_model(trainer)
    # generate_examples(trainer.model, trainer.device)

    print("\n" + rule)
    print("Скрипт завершен.")
    print(rule)
    print("\nСледующие шаги:")
    next_steps = (
        "1. Замените фиктивные даталоадеры на реальные данные",
        "2. Настройте параметры в create_simple_config()",
        "3. Запустите обучение с реальными данными",
        "4. Визуализируйте результаты",
    )
    for step in next_steps:
        print(step)
    print(rule)


if __name__ == "__main__":
    main()
|
||||
@@ -1,415 +1,191 @@
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
# Type aliases
|
||||
ModuleType = nn.Module
|
||||
TensorType = torch.Tensor
|
||||
|
||||
|
||||
class GANTrainer:
    """Trainer for a GAN exposing ``generator`` and ``discriminator`` sub-modules.

    The model must also provide ``discriminator_step(yandex, google, fake)``
    and ``generator_step(yandex, google)``; each returns a 3-tuple of loss
    tensors whose first element is the total loss that is back-propagated.
    Batches are dicts with ``"yandex_img"`` and ``"google_img"`` tensors.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict[str, Any],
    ):
        """
        Initialize the GAN trainer.

        Creates the output directory, a TensorBoard writer under
        ``<output_dir>/tensorboard`` and dumps ``config`` to
        ``<output_dir>/config.json``.

        Args:
            model: GAN model (ImageGAN)
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary (must be JSON-serializable)
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Separate Adam optimizers for generator and discriminator.
        lr = config.get("learning_rate", 2e-4)
        betas = (config.get("beta1", 0.5), config.get("beta2", 0.999))
        self.optimizer_G = optim.Adam(
            self.model.generator.parameters(), lr=lr, betas=betas
        )
        self.optimizer_D = optim.Adam(
            self.model.discriminator.parameters(), lr=lr, betas=betas
        )

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses: List[float] = []
        self.d_losses: List[float] = []
        self.val_g_losses: List[float] = []
        self.val_d_losses: List[float] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer (log_dir is documented as a string)
        self.writer = SummaryWriter(log_dir=str(self.output_dir / "tensorboard"))

        # Save configuration
        config_path = self.output_dir / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        print(f"Training configuration saved to {config_path}")

        generator_params = sum(p.numel() for p in self.model.generator.parameters())
        discriminator_params = sum(
            p.numel() for p in self.model.discriminator.parameters()
        )
        print(f"Generator has {generator_params:,} parameters")
        print(f"Discriminator has {discriminator_params:,} parameters")

    def train_epoch(self) -> Tuple[float, float]:
        """
        Train for one epoch.

        Returns:
            Tuple of (average generator loss, average discriminator loss);
            ``(0.0, 0.0)`` when the training loader yields no batches.
        """
        self.model.train()
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_batches = len(self.train_loader)

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in enumerate(progress_bar):
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            # ========== Train Discriminator ==========
            self.optimizer_D.zero_grad()

            # The fake image is produced without gradients so that only the
            # discriminator is updated in this phase.
            with torch.no_grad():
                fake_google_img = self.model.generator(yandex_img)

            d_loss, d_real_loss, d_fake_loss = self.model.discriminator_step(
                yandex_img, google_img, fake_google_img
            )
            d_loss.backward()
            self.optimizer_D.step()

            # ========== Train Generator ==========
            self.optimizer_G.zero_grad()

            # generator_step receives only the input/target pair, so no
            # separate generator forward pass is needed here (the previous
            # extra call computed a tensor that was never used).
            g_loss, _g_gan_loss, g_l1_loss = self.model.generator_step(
                yandex_img, google_img
            )
            g_loss.backward()
            self.optimizer_G.step()

            # Read each scalar once and reuse it below.
            g_value = g_loss.item()
            d_value = d_loss.item()
            g_l1_value = g_l1_loss.item()
            d_real_value = d_real_loss.item()
            d_fake_value = d_fake_loss.item()

            total_g_loss += g_value
            total_d_loss += d_value

            progress_bar.set_postfix(
                {
                    "g_loss": g_value,
                    "d_loss": d_value,
                    "g_l1": g_l1_value,
                    "d_real": d_real_value,
                    "d_fake": d_fake_value,
                }
            )

            # Log batch losses to TensorBoard
            global_step = self.current_epoch * num_batches + batch_idx
            self.writer.add_scalar("train/batch_g_loss", g_value, global_step)
            self.writer.add_scalar("train/batch_d_loss", d_value, global_step)
            self.writer.add_scalar("train/batch_g_l1_loss", g_l1_value, global_step)
            self.writer.add_scalar("train/batch_d_real_loss", d_real_value, global_step)
            self.writer.add_scalar("train/batch_d_fake_loss", d_fake_value, global_step)

        # Guard against an empty loader instead of dividing by zero.
        avg_g_loss = total_g_loss / num_batches if num_batches else 0.0
        avg_d_loss = total_d_loss / num_batches if num_batches else 0.0
        self.g_losses.append(avg_g_loss)
        self.d_losses.append(avg_d_loss)

        return avg_g_loss, avg_d_loss

    def _average_eval_losses(
        self, loader: DataLoader, desc: str, show_postfix: bool = False
    ) -> Tuple[float, float]:
        """Return mean (generator, discriminator) loss over ``loader`` without gradients."""
        self.model.eval()
        total_g_loss = 0.0
        total_d_loss = 0.0

        progress_bar = tqdm(loader, desc=desc)
        with torch.no_grad():
            for batch in progress_bar:
                yandex_img = batch["yandex_img"].to(self.device)
                google_img = batch["google_img"].to(self.device)

                fake_google_img = self.model.generator(yandex_img)
                g_loss = self.model.generator_step(yandex_img, google_img)[0]
                d_loss = self.model.discriminator_step(
                    yandex_img, google_img, fake_google_img
                )[0]

                g_value = g_loss.item()
                d_value = d_loss.item()
                total_g_loss += g_value
                total_d_loss += d_value

                if show_postfix:
                    progress_bar.set_postfix({"g_loss": g_value, "d_loss": d_value})

        num_batches = len(loader)
        if num_batches == 0:
            # Empty loader: report zeros rather than raising ZeroDivisionError.
            return 0.0, 0.0
        return total_g_loss / num_batches, total_d_loss / num_batches

    def validate(self) -> Tuple[float, float]:
        """
        Validate the model.

        Returns:
            Tuple of (average generator validation loss, average discriminator validation loss)
        """
        avg_g_loss, avg_d_loss = self._average_eval_losses(
            self.val_loader, "Validation", show_postfix=True
        )
        self.val_g_losses.append(avg_g_loss)
        self.val_d_losses.append(avg_d_loss)
        return avg_g_loss, avg_d_loss

    def save_checkpoint(self, is_best: bool = False):
        """
        Save training checkpoint.

        Writes ``checkpoint_epoch_<current_epoch + 1>.pth`` into the output
        directory and, when ``is_best`` is true, also ``model_best.pth``.

        Args:
            is_best: Whether this is the best model so far
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_G_state_dict": self.optimizer_G.state_dict(),
            "optimizer_D_state_dict": self.optimizer_D.state_dict(),
            "g_losses": self.g_losses,
            "d_losses": self.d_losses,
            "val_g_losses": self.val_g_losses,
            "val_d_losses": self.val_d_losses,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }

        # Save regular checkpoint
        checkpoint_path = (
            self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
        )
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = self.output_dir / "model_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False):
        """
        Load training checkpoint.

        Restores model and optimizer weights as well as the loss history,
        epoch counter and best validation loss.

        Args:
            checkpoint_path: Path to checkpoint file
            resume_training: If True, report that training resumes from the
                saved epoch; the restored state is identical either way.
        """
        # NOTE(security): torch.load unpickles data; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
        self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        self.g_losses = checkpoint["g_losses"]
        self.d_losses = checkpoint["d_losses"]
        self.val_g_losses = checkpoint["val_g_losses"]
        self.val_d_losses = checkpoint["val_d_losses"]
        self.best_val_loss = checkpoint["best_val_loss"]

        if resume_training:
            print(f"Resuming training from epoch {self.current_epoch + 1}")
        else:
            print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")

    def train(self, num_epochs: int, start_epoch: int = 0):
        """
        Train the model for specified number of epochs.

        After the loop the final weights and the loss history are written to
        the output directory and the TensorBoard writer is closed.

        Args:
            num_epochs: Number of epochs to train
            start_epoch: Starting epoch (useful when resuming training)
        """
        print(
            f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
        )
        start_time = time.time()
        last_epoch = start_epoch + num_epochs

        for epoch in range(start_epoch, last_epoch):
            self.current_epoch = epoch

            train_g_loss, train_d_loss = self.train_epoch()
            val_g_loss, val_d_loss = self.validate()

            # Log to TensorBoard
            self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
            self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
            self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
            self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)

            # Print epoch summary (total is the last absolute epoch so the
            # banner stays correct when start_epoch > 0).
            print(f"\nEpoch {epoch + 1}/{last_epoch}:")
            print(" Generator:")
            print(f" Train Loss: {train_g_loss:.6f}")
            print(f" Val Loss: {val_g_loss:.6f}")
            print(" Discriminator:")
            print(f" Train Loss: {train_d_loss:.6f}")
            print(f" Val Loss: {val_d_loss:.6f}")

            # Save checkpoint
            val_total_loss = val_g_loss + val_d_loss
            is_best = val_total_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_total_loss

            self.save_checkpoint(is_best=is_best)

            # Early stopping
            if self.config.get("early_stopping_patience", 0) > 0:
                val_losses = [
                    g + d for g, d in zip(self.val_g_losses, self.val_d_losses)
                ]
                # Measure staleness by position inside the recorded history,
                # not by the absolute epoch number: the two differ whenever
                # start_epoch > 0 and the history starts empty.
                epochs_since_best = len(val_losses) - 1 - int(np.argmin(val_losses))
                if epochs_since_best >= self.config["early_stopping_patience"]:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        # Training completed
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation total loss: {self.best_val_loss:.6f}")

        # Save final model
        final_model_path = self.output_dir / "model_final.pth"
        torch.save(self.model.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")

        # Save training history
        history_path = self.output_dir / "training_history.json"
        history = {
            "g_losses": self.g_losses,
            "d_losses": self.d_losses,
            "val_g_losses": self.val_g_losses,
            "val_d_losses": self.val_d_losses,
            "best_val_loss": self.best_val_loss,
            "total_epochs": self.current_epoch + 1,
        }
        with open(history_path, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
        print(f"Training history saved to {history_path}")

        # Close TensorBoard writer
        self.writer.close()

    def evaluate(self, test_loader: DataLoader) -> Dict:
        """
        Evaluate the model on test data.

        Args:
            test_loader: Test data loader

        Returns:
            Dictionary with ``generator_loss``, ``discriminator_loss`` and
            ``total_loss`` (their sum), each averaged over the batches.
        """
        print("Evaluating model on test set...")

        avg_g_loss, avg_d_loss = self._average_eval_losses(test_loader, "Evaluation")

        metrics = {
            "generator_loss": avg_g_loss,
            "discriminator_loss": avg_d_loss,
            "total_loss": avg_g_loss + avg_d_loss,
        }

        print("\nTest Results:")
        print(f" Generator Loss: {avg_g_loss:.6f}")
        print(f" Discriminator Loss: {avg_d_loss:.6f}")
        print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")

        return metrics
|
||||
"""Trainer for GAN model."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class GANTrainer:
    """Simple GAN trainer.

    The model is expected to expose ``generator`` and ``discriminator``
    sub-modules, a ``device`` attribute (a plain ``torch.nn.Module`` does not
    define one), and the methods ``generator_step(yandex, google)`` and
    ``discriminator_step(yandex, google, fake)``; only element ``[0]`` of
    their return value is used as the loss. Batches are dicts with
    ``"yandex_img"`` and ``"google_img"`` tensors.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict[str, Any],
    ):
        """Set up optimizers and training state, create output dirs, dump config.

        Args:
            model: GAN model (see class docstring for the required attributes).
            train_loader: Training data loader.
            val_loader: Validation data loader.
            config: Training settings; must be JSON-serializable because it is
                written to ``<output_dir>/config.json``.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = model.device

        # Optimizers
        lr = config.get("learning_rate", 2e-4)
        betas = (config.get("beta1", 0.5), config.get("beta2", 0.999))
        self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=betas)
        self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=betas)

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float("inf")
        self.g_losses = []
        self.d_losses = []
        self.val_g_losses = []
        self.val_d_losses = []

        # Output dir
        self.output_dir = Path(config.get("output_dir", "runs/gan"))
        self.output_dir.mkdir(parents=True, exist_ok=True)
        (self.output_dir / "checkpoints").mkdir(exist_ok=True)

        # Save config
        with open(self.output_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch.

        Returns:
            ``(avg generator loss, avg discriminator loss)``; ``(0.0, 0.0)``
            when the training loader has no batches.
        """
        self.model.train()
        total_g = total_d = 0.0
        num_batches = len(self.train_loader)

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch in pbar:
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)

            # Train D: the fake is generated without gradients so that only
            # the discriminator is updated here.
            self.opt_D.zero_grad()
            with torch.no_grad():
                fake_img = self.model.generator(yandex_img)
            d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
            d_loss.backward()
            self.opt_D.step()

            # Train G
            self.opt_G.zero_grad()
            g_loss = self.model.generator_step(yandex_img, google_img)[0]
            g_loss.backward()
            self.opt_G.step()

            # Read each scalar once and reuse it.
            g_value = g_loss.item()
            d_value = d_loss.item()
            total_g += g_value
            total_d += d_value
            pbar.set_postfix({"g_loss": g_value, "d_loss": d_value})

        # Guard against an empty loader instead of dividing by zero.
        avg_g = total_g / num_batches if num_batches else 0.0
        avg_d = total_d / num_batches if num_batches else 0.0
        self.g_losses.append(avg_g)
        self.d_losses.append(avg_d)
        return avg_g, avg_d

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Validate the model.

        Returns:
            ``(avg generator loss, avg discriminator loss)`` over the
            validation loader; ``(0.0, 0.0)`` when it has no batches.
        """
        self.model.eval()
        total_g = total_d = 0.0

        for batch in tqdm(self.val_loader, desc="Val"):
            yandex_img = batch["yandex_img"].to(self.device)
            google_img = batch["google_img"].to(self.device)
            fake_img = self.model.generator(yandex_img)
            g_loss = self.model.generator_step(yandex_img, google_img)[0]
            d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
            total_g += g_loss.item()
            total_d += d_loss.item()

        num_batches = len(self.val_loader)
        avg_g = total_g / num_batches if num_batches else 0.0
        avg_d = total_d / num_batches if num_batches else 0.0
        self.val_g_losses.append(avg_g)
        self.val_d_losses.append(avg_d)
        return avg_g, avg_d

    def train(self, num_epochs: int):
        """Train the model.

        Saves ``best`` whenever the summed validation loss improves,
        ``epoch_<n>`` every ``save_interval`` epochs, and ``final`` at the end.

        Args:
            num_epochs: Maximum number of epochs; early stopping may end the
                loop sooner.
        """
        print(f"Training for {num_epochs} epochs...")

        for epoch in range(num_epochs):
            self.current_epoch = epoch

            # Train & validate
            train_g, train_d = self.train_epoch()
            val_g, val_d = self.validate()

            # Save best checkpoint
            val_total = val_g + val_d
            if val_total < self.best_val_loss:
                self.best_val_loss = val_total
                self.save_checkpoint("best")

            # Periodic checkpoint; a falsy interval (0 / None) disables it
            # instead of raising on the modulo.
            save_interval = self.config.get("save_interval", 5)
            if save_interval and (epoch + 1) % save_interval == 0:
                self.save_checkpoint(f"epoch_{epoch + 1}")

            print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")

            # Early stopping: stop when none of the last `patience` generator
            # validation losses beat the best value seen before that window.
            # NOTE(review): this looks only at the generator loss, while the
            # "best" checkpoint above uses generator + discriminator.
            patience = self.config.get("early_stopping_patience", 0)
            if patience > 0 and len(self.val_g_losses) > patience:
                best_before = min(self.val_g_losses[:-patience])
                if min(self.val_g_losses[-patience:]) >= best_before:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        # Save final
        self.save_checkpoint("final")
        print(f"Training finished. Best val loss: {self.best_val_loss:.4f}")

    def save_checkpoint(self, name: str):
        """Save model checkpoint to ``<output_dir>/checkpoints/<name>.pth``.

        The file holds the current epoch index plus the state dicts of the
        generator, the discriminator and both optimizers.
        """
        path = self.output_dir / "checkpoints" / f"{name}.pth"
        torch.save({
            "epoch": self.current_epoch,
            "generator": self.model.generator.state_dict(),
            "discriminator": self.model.discriminator.state_dict(),
            "opt_G": self.opt_G.state_dict(),
            "opt_D": self.opt_D.state_dict(),
        }, path)
|
||||
|
||||
|
||||
def create_trainer(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: Dict[str, Any],
) -> GANTrainer:
    """Build a ``GANTrainer`` from a model, its two data loaders and a config dict."""
    trainer = GANTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
    )
    return trainer
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build everything from the default config and run a single
    # training epoch to confirm the pieces fit together.
    from config import create_config
    from dataloader import create_data_loaders
    from model import create_gan

    config = create_config()

    # NOTE(review): `device` is computed but not passed anywhere below; the
    # model is created with use_cuda=False regardless.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model = create_gan(use_cuda=False)

    loaders = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=4,
        image_size=tuple(config["image_size"]),
        num_workers=0,
    )
    train_loader, val_loader = loaders

    trainer = create_trainer(model, train_loader, val_loader, config)

    # train_epoch() iterates over the whole training loader, so this is one
    # full epoch rather than a single optimizer step.
    print("Testing one training step...")
    try:
        losses = trainer.train_epoch()
        g_loss, d_loss = losses
        print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")
    except Exception as e:
        print(f"Training step failed: {e}")
|
||||
Reference in New Issue
Block a user