From c6df3edab81d3dea35721650e00e30e16c13df77 Mon Sep 17 00:00:00 2001 From: russian_proger Date: Sun, 22 Mar 2026 21:10:05 +0300 Subject: [PATCH] ref: simplify and modularize GAN codebase --- models/GAN/config.py | 36 ++ models/GAN/dataloader.py | 162 +++++++++ models/GAN/gan.ipynb | 4 +- models/GAN/gan.py | 393 ---------------------- models/GAN/main.py | 30 ++ models/GAN/minimal_example.py | 136 -------- models/GAN/model.py | 256 ++++++++++++++ models/GAN/test_gan.py | 349 -------------------- models/GAN/test_trainer.py | 342 ------------------- models/GAN/train_example.py | 347 ------------------- models/GAN/trainer.py | 606 +++++++++++----------------------- 11 files changed, 677 insertions(+), 1984 deletions(-) create mode 100644 models/GAN/config.py create mode 100644 models/GAN/dataloader.py delete mode 100644 models/GAN/gan.py create mode 100644 models/GAN/main.py delete mode 100644 models/GAN/minimal_example.py create mode 100644 models/GAN/model.py delete mode 100644 models/GAN/test_gan.py delete mode 100644 models/GAN/test_trainer.py delete mode 100644 models/GAN/train_example.py diff --git a/models/GAN/config.py b/models/GAN/config.py new file mode 100644 index 0000000..b97e1b5 --- /dev/null +++ b/models/GAN/config.py @@ -0,0 +1,36 @@ +"""Configuration for GAN training.""" + + +def create_config(): + """Create default configuration dictionary.""" + return { + # Optimizer params + "learning_rate": 2e-4, + "beta1": 0.5, + "beta2": 0.999, + # Training params + "batch_size": 32, + "epochs": 100, + # GAN params + "gan_mode": "vanilla", + "lambda_L1": 100.0, + # Regularization + "grad_clip": 1.0, + # Early stopping + "early_stopping_patience": 20, + # Output + "output_dir": "runs/gan_training", + # Logging + "log_interval": 10, + "save_interval": 5, + # Data + "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + "image_size": [256, 256], + "train_split": 0.8, + "num_workers": 0, + } + + +if __name__ == "__main__": + config = create_config() + print("Default config:", config) \ No newline at end of file diff --git a/models/GAN/dataloader.py b/models/GAN/dataloader.py new file mode 100644 index 0000000..3472db1 --- /dev/null +++ b/models/GAN/dataloader.py @@ -0,0 +1,162 @@ +"""Data loader for Yandex-to-Google image translation.""" + +import os +from typing import Dict, List, Tuple + +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class YaGoDataset(Dataset): + """Dataset loading pairs of Yandex and Google map images.""" + + def __init__( + self, + root_dir: str, + image_size: Tuple[int, int] = (256, 256), + augment: bool = False, + ): + """ + Args: + root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png + image_size: Target image size (height, width) + augment: Whether to apply augmentation (not implemented for simplicity) + """ + self.root_dir = root_dir + self.image_size = image_size + self.augment = augment + + # Discover image pairs + self.pairs = self._find_pairs() + + # Transform to tensor + normalization + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + + def _find_pairs(self) -> List[Dict]: + """Find all matching Google-Yandex image pairs.""" + pairs = [] + google_files = [f for f in os.listdir(self.root_dir) if f.endswith("_google.png")] + + for google_file in sorted(google_files): + idx_str = google_file.split("_")[0] + try: + idx = int(idx_str) + except ValueError: + continue + + yandex_file = f"{idx:04d}_yandex.png" + yandex_path = os.path.join(self.root_dir, yandex_file) + + if os.path.exists(yandex_path): + pairs.append( + { + "idx": idx, + "google_path": os.path.join(self.root_dir, google_file), + "yandex_path": yandex_path, + } + ) + + return pairs + + def __len__(self) -> int: + return len(self.pairs) + + def __getitem__(self, idx: int) -> dict: + pair = self.pairs[idx] + + # Load images + google_img = Image.open(pair["google_path"]).convert("RGB") + yandex_img = Image.open(pair["yandex_path"]).convert("RGB") + + # Resize + google_img = google_img.resize((self.image_size[1], self.image_size[0])) + yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0])) + + # Apply transforms + google_tensor = self.transform(google_img) + yandex_tensor = self.transform(yandex_img) + + return { + "google_img": google_tensor, + "yandex_img": yandex_tensor, + "idx": torch.tensor(pair["idx"], dtype=torch.long), + } + + +def create_data_loaders( + root_dir: str, + batch_size: int = 32, + train_split: float = 0.8, + num_workers: int = 0, + image_size: Tuple[int, int] = (256, 256), +) -> Tuple[DataLoader, DataLoader]: + """ + Create train and validation data loaders. + + Args: + root_dir: Directory with image pairs + batch_size: Batch size + train_split: Fraction for training (0.0-1.0) + num_workers: DataLoader workers + image_size: Target image size + + Returns: + (train_loader, val_loader) + """ + # Full dataset + dataset = YaGoDataset(root_dir=root_dir, image_size=image_size) + + # Split + dataset_size = len(dataset) + train_size = int(train_split * dataset_size) + indices = torch.randperm(dataset_size).tolist() + train_indices = indices[:train_size] + val_indices = indices[train_size:] + + # Subsets + from torch.utils.data import Subset + + train_dataset = Subset(dataset, train_indices) + val_dataset = Subset(dataset, val_indices) + + # DataLoaders + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + return train_loader, val_loader + + +if __name__ == "__main__": + # Quick test + from config import create_config + + config = create_config() + train_loader, val_loader = create_data_loaders( + root_dir=config["data_dir"], + batch_size=4, + image_size=tuple(config["image_size"]), + ) + + batch = next(iter(train_loader)) + print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}") + print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}") \ No newline at end of file diff --git a/models/GAN/gan.ipynb b/models/GAN/gan.ipynb index eaf007f..06305b3 100644 --- a/models/GAN/gan.ipynb +++ b/models/GAN/gan.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 23, + "execution_count": 1, "id": "bb583d80", "metadata": {}, "outputs": [], @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 238, + "execution_count": 2, "id": "92144cc0", "metadata": {}, "outputs": [ diff --git a/models/GAN/gan.py b/models/GAN/gan.py deleted file mode 100644 index af34d9b..0000000 --- a/models/GAN/gan.py +++ /dev/null @@ -1,393 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class UNetDownBlock(nn.Module): - """Блок downsampling для U-Net генератора""" - - def __init__( - self, - in_channels: int, - out_channels: int, - normalize: bool = True, - dropout: float = 0.0, - ): - super().__init__() - layers = [ - nn.Conv2d( - in_channels, - out_channels, - kernel_size=4, - stride=2, - padding=1, - bias=False, - ) - ] - if normalize: - layers.append(nn.BatchNorm2d(out_channels)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) - if dropout > 0: - layers.append(nn.Dropout2d(dropout)) - self.model = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.model(x) - - -class UNetUpBlock(nn.Module): - """Блок upsampling для U-Net генератора""" - - def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0): - super().__init__() - layers = [ - nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size=4, - stride=2, - padding=1, - bias=False, - ), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ] - if dropout > 0: - layers.append(nn.Dropout2d(dropout)) - self.model = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor: - x = self.model(x) - # Обрезаем skip connection до размера x, если необходимо - if x.shape != skip_input.shape: - diffY = skip_input.size(2) - x.size(2) - diffX = skip_input.size(3) - x.size(3) - x = F.pad( - x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2] - ) - x = torch.cat([x, skip_input], dim=1) - return x - - -class GeneratorUNet(nn.Module): - """Генератор на основе U-Net архитектуры для преобразования Yandex → Google""" - - def __init__(self, in_channels: int = 3, out_channels: int = 3): - super().__init__() - - # Downsampling path - self.down1 = UNetDownBlock(in_channels, 64, normalize=False) - self.down2 = UNetDownBlock(64, 128) - self.down3 = UNetDownBlock(128, 256) - self.down4 = UNetDownBlock(256, 512) - self.down5 = UNetDownBlock(512, 512) - self.down6 = UNetDownBlock(512, 512) - self.down7 = UNetDownBlock(512, 512) - - # Bottleneck - self.bottleneck = nn.Sequential( - nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False), - nn.ReLU(inplace=True), - ) - - # Upsampling path - self.up1 = UNetUpBlock(512, 512, dropout=0.5) - self.up2 = UNetUpBlock(1024, 512, dropout=0.5) - self.up3 = UNetUpBlock(1024, 512, dropout=0.5) - self.up4 = UNetUpBlock(1024, 512) - self.up5 = UNetUpBlock(1024, 256) - self.up6 = UNetUpBlock(512, 128) - self.up7 = UNetUpBlock(256, 64) - - # Final layer - self.final = nn.Sequential( - nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), - nn.Tanh(), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Downsampling - d1 = self.down1(x) # 350x350 - d2 = self.down2(d1) # 175x175 - d3 = self.down3(d2) # 88x88 - d4 = self.down4(d3) # 44x44 - d5 = self.down5(d4) # 22x22 - d6 = self.down6(d5) # 11x11 - d7 = self.down7(d6) # 6x6 - - # Bottleneck - u = self.bottleneck(d7) # 3x3 - - # Upsampling with skip connections - u = self.up1(u, d7) # 6x6 - u = self.up2(u, d6) # 11x11 - u = self.up3(u, d5) # 22x22 - u = self.up4(u, d4) # 44x44 - u = self.up5(u, d3) # 88x88 - u = self.up6(u, d2) # 175x175 - u = self.up7(u, d1) # 350x350 - - # Final output - return self.final(u) # 700x700 - - -class DiscriminatorPatchGAN(nn.Module): - """Дискриминатор PatchGAN для изображений 700x700""" - - def __init__( - self, in_channels: int = 6 - ): # 3 для реального + 3 для сгенерированного - super().__init__() - - def discriminator_block( - in_filters: int, out_filters: int, normalization: bool = True - ): - """Блок дискриминатора""" - layers = [ - nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1) - ] - if normalization: - layers.append(nn.BatchNorm2d(out_filters)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) - return layers - - self.model = nn.Sequential( - *discriminator_block(in_channels, 64, normalization=False), # 350x350 - *discriminator_block(64, 128), # 175x175 - *discriminator_block(128, 256), # 88x88 - *discriminator_block(256, 512), # 44x44 - nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), # 41x41 - nn.Sigmoid(), - ) - - def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor: - """ - Принимает пару изображений (реальное и сгенерированное) - и возвращает вероятность того, что пара реальная - """ - # Объединяем два изображения по каналам - img_input = torch.cat((img_A, img_B), 1) - return self.model(img_input) - - -class GANLoss(nn.Module): - """Функция потерь для GAN""" - - def __init__( - self, - gan_mode: str = "vanilla", - target_real_label: float = 1.0, - target_fake_label: float = 0.0, - ): - super().__init__() - self.register_buffer("real_label", torch.tensor(target_real_label)) - self.register_buffer("fake_label", torch.tensor(target_fake_label)) - self.gan_mode = gan_mode - - if gan_mode == "vanilla": - self.loss = nn.BCEWithLogitsLoss() - elif gan_mode == "lsgan": - self.loss = nn.MSELoss() - elif gan_mode == "wgangp": - self.loss = None - else: - raise NotImplementedError(f"GAN mode {gan_mode} not implemented") - - def get_target_tensor( - self, prediction: torch.Tensor, target_is_real: bool - ) -> torch.Tensor: - """Создает тензор меток""" - if target_is_real: - target_tensor = self.real_label - else: - target_tensor = self.fake_label - return target_tensor.expand_as(prediction) - - def __call__(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor: - """Вычисляет потери""" - if self.gan_mode in ["vanilla", "lsgan"]: - target_tensor = self.get_target_tensor(prediction, target_is_real) - loss = self.loss(prediction, target_tensor) - elif self.gan_mode == "wgangp": - if target_is_real: - loss = -prediction.mean() - else: - loss = prediction.mean() - return loss - - -class ImageGAN(nn.Module): - """Основной класс GAN для преобразования изображений Yandex → Google""" - - def __init__( - self, - input_channels: int = 3, - output_channels: int = 3, - gan_mode: str = "vanilla", - lambda_L1: float = 100.0, - use_cuda: bool = True, - ): - super().__init__() - - self.generator = GeneratorUNet(input_channels, output_channels) - self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels) - self.gan_loss = GANLoss(gan_mode) - self.l1_loss = nn.L1Loss() - self.lambda_L1 = lambda_L1 - - self.device = torch.device( - "cuda" if use_cuda and torch.cuda.is_available() else "cpu" - ) - self.to(self.device) - - def forward(self, yandex_image: torch.Tensor) -> torch.Tensor: - """Генерация изображения Google из Yandex""" - return self.generator(yandex_image) - - def generator_step( - self, yandex_image: torch.Tensor, real_google_image: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Шаг обучения генератора - - Returns: - total_loss: общие потери генератора - gan_loss: потери GAN - l1_loss: потери L1 - """ - # Генерируем изображение - fake_google_image = self.generator(yandex_image) - - # Оцениваем дискриминатором - fake_pred = self.discriminator(yandex_image, fake_google_image) - - # Потери GAN (пытаемся обмануть дискриминатор) - gan_loss = self.gan_loss(fake_pred, True) - - # Потери L1 для сохранения структуры - l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1 - - # Общие потери - total_loss = gan_loss + l1_loss - - return total_loss, gan_loss, l1_loss - - def discriminator_step( - self, - yandex_image: torch.Tensor, - real_google_image: torch.Tensor, - fake_google_image: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Шаг обучения дискриминатора - - Returns: - total_loss: общие потери дискриминатора - real_loss: потери на реальных изображениях - fake_loss: потери на сгенерированных изображениях - """ - # Предсказания для реальных пар - real_pred = self.discriminator(yandex_image, real_google_image) - real_loss = self.gan_loss(real_pred, True) - - # Предсказания для сгенерированных пар - fake_pred = self.discriminator(yandex_image, fake_google_image.detach()) - fake_loss = self.gan_loss(fake_pred, False) - - # Общие потери дискриминатора - total_loss = (real_loss + fake_loss) * 0.5 - - return total_loss, real_loss, fake_loss - - def to(self, device): - """Перемещает модель на устройство""" - self.generator.to(device) - self.discriminator.to(device) - return self - - def train_mode(self): - """Переключает модель в режим обучения""" - self.generator.train() - self.discriminator.train() - - def eval_mode(self): - """Переключает модель в режим оценки""" - self.generator.eval() - self.discriminator.eval() - - def save_checkpoint(self, path: str): - """Сохраняет чекпоинт модели""" - checkpoint = { - "generator_state_dict": self.generator.state_dict(), - "discriminator_state_dict": self.discriminator.state_dict(), - "generator_optimizer_state_dict": getattr( - self.generator, "optimizer_state_dict", None - ), - "discriminator_optimizer_state_dict": getattr( - self.discriminator, "optimizer_state_dict", None - ), - } - torch.save(checkpoint, path) - - def load_checkpoint(self, path: str): - """Загружает чекпоинт модели""" - checkpoint = torch.load(path, map_location=self.device) - self.generator.load_state_dict(checkpoint["generator_state_dict"]) - self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"]) - - if checkpoint["generator_optimizer_state_dict"] is not None: - self.generator.optimizer_state_dict = checkpoint[ - "generator_optimizer_state_dict" - ] - if checkpoint["discriminator_optimizer_state_dict"] is not None: - self.discriminator.optimizer_state_dict = checkpoint[ - "discriminator_optimizer_state_dict" - ] - - -def create_image_gan( - input_channels: int = 3, - output_channels: int = 3, - gan_mode: str = "vanilla", - lambda_L1: float = 100.0, - use_cuda: bool = True, -) -> ImageGAN: - """ - Создает и возвращает модель GAN для преобразования изображений - - Args: - input_channels: количество входных каналов (обычно 3 для RGB) - output_channels: количество выходных каналов (обычно 3 для RGB) - gan_mode: режим GAN ('vanilla', 'lsgan', 'wgangp') - lambda_L1: вес L1 потерь - use_cuda: использовать ли CUDA если доступно - - Returns: - ImageGAN: модель GAN - """ - return ImageGAN( - input_channels=input_channels, - output_channels=output_channels, - gan_mode=gan_mode, - lambda_L1=lambda_L1, - use_cuda=use_cuda, - ) - - -# Вспомогательные функции для инициализации весов -def weights_init_normal(m): - """Инициализация весов с нормальным распределением""" - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.batch_norm.bias.data, 0.0) - - -def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module): - """Инициализирует веса генератора и дискриминатора""" - generator.apply(weights_init_normal) - discriminator.apply(weights_init_normal) diff --git a/models/GAN/main.py b/models/GAN/main.py new file mode 100644 index 0000000..5534ac9 --- /dev/null +++ b/models/GAN/main.py @@ -0,0 +1,30 @@ +"""Main entry point for GAN training.""" + +from config import create_config +from dataloader import create_data_loaders +from model import create_gan +from trainer import create_trainer + + +def main(): + """Run training pipeline.""" + config = create_config() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Create components + model = create_gan(use_cuda=False) # Set to True to use GPU + train_loader, val_loader = create_data_loaders( + root_dir=config["data_dir"], + batch_size=config["batch_size"], + image_size=tuple(config["image_size"]), + num_workers=config["num_workers"], + ) + trainer = create_trainer(model, train_loader, val_loader, config) + + # Train + trainer.train(config["epochs"]) + + +if __name__ == "__main__": + import torch + main() \ No newline at end of file diff --git a/models/GAN/minimal_example.py b/models/GAN/minimal_example.py deleted file mode 100644 index b58624f..0000000 --- a/models/GAN/minimal_example.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Минимальный пример использования GAN trainer для преобразования Yandex → Google карт. - -Этот пример показывает самый простой способ использования тренера. -""" - -import sys -from pathlib import Path - -import torch -from torch.utils.data import DataLoader, Dataset - -# Добавляем путь к модулям -sys.path.append(str(Path(__file__).parent.parent.parent)) - -from models.GAN.gan import create_image_gan -from models.GAN.trainer import GANTrainer - - -class SimpleMapDataset(Dataset): - """Простой датасет с фиктивными данными для примера.""" - - def __init__(self, num_samples=100, image_size=(256, 256)): - self.num_samples = num_samples - self.image_size = image_size - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - # Создаем фиктивные изображения - # В реальном коде замените на загрузку реальных изображений - yandex_img = torch.randn(3, self.image_size[0], self.image_size[1]) - google_img = torch.randn(3, self.image_size[0], self.image_size[1]) - - return {"yandex_img": yandex_img, "google_img": google_img} - - -def main(): - """Основная функция минимального примера.""" - print("Минимальный пример использования GAN trainer") - print("=" * 50) - - # 1. Конфигурация (минимальный набор параметров) - config = { - "learning_rate": 2e-4, - "batch_size": 4, - "output_dir": "runs/gan_minimal", - } - - # 2. Устройство (CPU или GPU) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Используемое устройство: {device}") - - # 3. Создание модели - print("\nСоздание GAN модели...") - model = create_image_gan( - input_channels=3, - output_channels=3, - gan_mode="vanilla", # Простейший режим - lambda_L1=100.0, # Стандартный вес L1 потерь - use_cuda=(device.type == "cuda"), - ) - - # 4. Создание даталоадеров - print("Создание даталоадеров...") - train_dataset = SimpleMapDataset(num_samples=50) - val_dataset = SimpleMapDataset(num_samples=10) - - train_loader = DataLoader( - train_dataset, - batch_size=config["batch_size"], - shuffle=True, - num_workers=0, - ) - - val_loader = DataLoader( - val_dataset, - batch_size=config["batch_size"], - shuffle=False, - num_workers=0, - ) - - print(f" Обучающих примеров: {len(train_dataset)}") - print(f" Валидационных примеров: {len(val_dataset)}") - - # 5. Создание тренера - print("\nСоздание тренера...") - trainer = GANTrainer( - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, - ) - - # 6. Обучение на небольшом количестве эпох - print("\nЗапуск обучения (3 эпохи для примера)...") - print("=" * 50) - - trainer.train(num_epochs=3) - - # 7. Генерация примеров - print("\nГенерация примеров преобразования...") - model.eval() - - # Создаем тестовые данные - test_yandex = torch.randn(2, 3, 256, 256).to(device) - - with torch.no_grad(): - generated_google = model(test_yandex) - - print(f"Входные изображения: {test_yandex.shape}") - print(f"Сгенерированные изображения: {generated_google.shape}") - print( - f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]" - ) - - # 8. Сохранение финальной модели - print("\nСохранение модели...") - model_save_path = "gan_model_minimal.pth" - torch.save(model.state_dict(), model_save_path) - print(f"Модель сохранена в: {model_save_path}") - - print("\n" + "=" * 50) - print("Минимальный пример завершен!") - print("\nДля реального использования:") - print("1. Замените SimpleMapDataset на ваш реальный датасет") - print("2. Настройте параметры в config") - print("3. Увеличьте количество эпох (например, до 100)") - print("4. Используйте реальные изображения карт") - print("=" * 50) - - -if __name__ == "__main__": - main() diff --git a/models/GAN/model.py b/models/GAN/model.py new file mode 100644 index 0000000..1632a3d --- /dev/null +++ b/models/GAN/model.py @@ -0,0 +1,256 @@ +"""GAN model for image translation Yandex -> Google.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNetDownBlock(nn.Module): + """Downsampling block for U-Net.""" + + def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0): + super().__init__() + layers = [ + nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False) + ] + if normalize: + layers.append(nn.BatchNorm2d(out_channels)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + if dropout > 0: + layers.append(nn.Dropout2d(dropout)) + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class UNetUpBlock(nn.Module): + """Upsampling block for U-Net.""" + + def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0): + super().__init__() + self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False) + self.norm = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + if dropout > 0: + self.dropout = nn.Dropout2d(dropout) + else: + self.dropout = None + + def forward(self, x, skip_input): + print(x.shape) + print(skip_input.shape) + x = self.upconv(x) + # Pad if needed to match skip connection size + if x.shape != skip_input.shape: + diff_h = skip_input.size(2) - x.size(2) + diff_w = skip_input.size(3) - x.size(3) + x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]) + x = self.norm(x) + x = self.relu(x) + if self.dropout: + x = self.dropout(x) + x = torch.cat([x, skip_input], dim=1) + return x + + +class GeneratorUNet(nn.Module): + """U-Net generator for Yandex -> Google translation.""" + + def __init__(self, in_channels: int = 3, out_channels: int = 3): + super().__init__() + + # Downsampling + self.down1 = UNetDownBlock(in_channels, 64, normalize=False) + self.down2 = UNetDownBlock(64, 128) + self.down3 = UNetDownBlock(128, 256) + self.down4 = UNetDownBlock(256, 512) + self.down5 = UNetDownBlock(512, 512) + self.down6 = UNetDownBlock(512, 512) + self.down7 = UNetDownBlock(512, 512) + + # Bottleneck + self.bottleneck = nn.Sequential( + nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False), + nn.ReLU(inplace=True), + ) + + # Upsampling + self.up1 = UNetUpBlock(1024, 512, dropout=0.5) + self.up2 = UNetUpBlock(1024, 512, dropout=0.5) + self.up3 = UNetUpBlock(1024, 512, dropout=0.5) + self.up4 = UNetUpBlock(1024, 512) + self.up5 = UNetUpBlock(1024, 256) + self.up6 = UNetUpBlock(512, 128) + self.up7 = UNetUpBlock(256, 64) + + # Final + self.final = nn.Sequential( + nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), + nn.Tanh(), + ) + + def forward(self, x): + # Down + d1 = self.down1(x) + d2 = self.down2(d1) + d3 = self.down3(d2) + d4 = self.down4(d3) + d5 = self.down5(d4) + d6 = self.down6(d5) + d7 = self.down7(d6) + + # Bottleneck + u = self.bottleneck(d7) + + # Up with skip connections + u = self.up1(u, d7) + u = self.up2(u, d6) + u = self.up3(u, d5) + u = self.up4(u, d4) + u = self.up5(u, d3) + u = self.up6(u, d2) + u = self.up7(u, d1) + + return self.final(u) + + +class DiscriminatorPatchGAN(nn.Module): + """PatchGAN discriminator.""" + + def __init__(self, in_channels: int = 6): + super().__init__() + self.model = nn.Sequential( + nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), + nn.Sigmoid(), + ) + + def forward(self, img_A, img_B): + x = torch.cat([img_A, img_B], dim=1) + return self.model(x) + + +class GANLoss(nn.Module): + """GAN loss supporting different GAN modes.""" + + def __init__(self, gan_mode: str = "vanilla", target_real: float = 1.0, target_fake: float = 0.0): + super().__init__() + self.gan_mode = gan_mode + self.register_buffer("real_label", torch.tensor(target_real)) + self.register_buffer("fake_label", torch.tensor(target_fake)) + + if gan_mode == "vanilla": + self.loss_fn = nn.BCEWithLogitsLoss() + elif gan_mode == "lsgan": + self.loss_fn = nn.MSELoss() + elif gan_mode == "wgangp": + self.loss_fn = None + else: + raise ValueError(f"Unknown GAN mode: {gan_mode}") + + def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor: + if self.gan_mode in ["vanilla", "lsgan"]: + target = self.real_label if target_is_real else self.fake_label + target = target.expand_as(prediction) + return self.loss_fn(prediction, target) + elif self.gan_mode == "wgangp": + return -prediction.mean() if target_is_real else prediction.mean() + + +class ImageGAN(nn.Module): + """Complete GAN model for image translation.""" + + def __init__( + self, + input_channels: int = 3, + output_channels: int = 3, + gan_mode: str = "vanilla", + lambda_L1: float = 100.0, + use_cuda: bool = True, + ): + super().__init__() + self.generator = GeneratorUNet(input_channels, output_channels) + self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels) + self.gan_loss = GANLoss(gan_mode) + self.l1_loss = nn.L1Loss() + self.lambda_L1 = lambda_L1 + + self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu") + self.to(self.device) + + def forward(self, yandex_image): + """Generate Google image from Yandex.""" + return self.generator(yandex_image) + + def generator_step(self, yandex_img, real_google_img): + """Compute generator losses.""" + fake_google = self.generator(yandex_img) + fake_pred = self.discriminator(yandex_img, fake_google) + gan_loss = self.gan_loss(fake_pred, True) + l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1 + total_loss = gan_loss + l1_loss + return total_loss, gan_loss, l1_loss + + def discriminator_step(self, yandex_img, real_google_img, fake_google_img): + """Compute discriminator losses.""" + real_pred = self.discriminator(yandex_img, real_google_img) + real_loss = self.gan_loss(real_pred, True) + fake_pred = self.discriminator(yandex_img, fake_google_img.detach()) + fake_loss = self.gan_loss(fake_pred, False) + total_loss = (real_loss + fake_loss) * 0.5 + return total_loss, real_loss, fake_loss + + +def create_gan( + input_channels: int = 3, + output_channels: int = 3, + gan_mode: str = "vanilla", + lambda_L1: float = 100.0, + use_cuda: bool = True, +) -> ImageGAN: + """Create a GAN model.""" + return ImageGAN( + input_channels=input_channels, + output_channels=output_channels, + gan_mode=gan_mode, + lambda_L1=lambda_L1, + use_cuda=use_cuda, + ) + + +def initialize_weights(model: nn.Module): + """Initialize model weights.""" + for m in model.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif isinstance(m, nn.BatchNorm2d): + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0.0) + + +if __name__ == "__main__": + # Quick test + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = create_gan(use_cuda=False) + print(f"Model created on {model.device}") + + # Test forward pass + test_input = torch.randn(2, 3, 256, 256).to(model.device) + output = model(test_input) + print(f"Output shape: {output.shape}") + + # Count parameters + gen_params = sum(p.numel() for p in model.generator.parameters()) + disc_params = sum(p.numel() for p in model.discriminator.parameters()) + print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params") \ No newline at end of file diff --git a/models/GAN/test_gan.py b/models/GAN/test_gan.py deleted file mode 100644 index e76469b..0000000 --- a/models/GAN/test_gan.py +++ /dev/null @@ -1,349 +0,0 @@ -import os -import sys - -import torch -import torch.nn as nn - -# Добавляем путь к модулю -sys.path.append( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) - -from gan import ( - DiscriminatorPatchGAN, - GeneratorUNet, - ImageGAN, - create_image_gan, - initialize_gan_weights, -) - - -def test_generator(): - """Тестирование генератора""" - print("=" * 60) - print("Тестирование генератора...") - print("=" * 60) - - # Создаем генератор - generator = GeneratorUNet(in_channels=3, out_channels=3) - - # Инициализируем веса - generator.apply( - lambda m: ( - nn.init.normal_(m.weight.data, 0.0, 0.02) - if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) - else None - ) - ) - - # Создаем тестовый входной тензор (Yandex изображение) - batch_size = 2 - height, width = 700, 700 - yandex_image = torch.randn(batch_size, 3, height, width) - - print(f"Размер входного изображения: {yandex_image.shape}") - print( - f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}" - ) - - # Прямой проход - with torch.no_grad(): - generated_image = generator(yandex_image) - - print(f"Размер сгенерированного изображения: {generated_image.shape}") - print( - f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]" - ) - - # Проверка размеров - assert generated_image.shape == (batch_size, 3, height, width), ( - f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}" - ) - - print("✓ Генератор работает корректно!") - return generator - - -def test_discriminator(): - """Тестирование дискриминатора""" - print("\n" + "=" * 60) - print("Тестирование дискриминатора...") - print("=" * 60) - - # Создаем дискриминатор - discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала - - # Инициализируем веса - discriminator.apply( - lambda m: ( - nn.init.normal_(m.weight.data, 0.0, 0.02) - if isinstance(m, nn.Conv2d) - else None - ) - ) - - # Создаем тестовые тензоры - batch_size = 2 - height, width = 700, 700 - - yandex_image = torch.randn(batch_size, 3, height, width) - google_image = torch.randn(batch_size, 3, height, width) - - print(f"Размер Yandex изображения: {yandex_image.shape}") - print(f"Размер Google изображения: {google_image.shape}") - print( - f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}" - ) - - # Прямой проход - with torch.no_grad(): - prediction = discriminator(yandex_image, google_image) - - print(f"Размер выхода дискриминатора: {prediction.shape}") - print( - f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]" - ) - - # Проверка размеров (PatchGAN выдает карту вероятностей) - expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков - expected_width = 43 - assert prediction.shape == (batch_size, 1, expected_height, expected_width), ( - f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}" - ) - - print( - f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}" - ) - return discriminator - - -def test_gan_model(): - """Тестирование полной GAN модели""" - print("\n" + "=" * 60) - print("Тестирование полной GAN модели...") - print("=" * 60) - - # Создаем GAN модель - gan = ImageGAN( - input_channels=3, - output_channels=3, - gan_mode="vanilla", - lambda_L1=100.0, - use_cuda=False, # Тестируем на CPU для простоты - ) - - print(f"Устройство модели: {gan.device}") - print( - f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}" - ) - print( - f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}" - ) - print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}") - - # Создаем тестовые данные - batch_size = 2 - height, width = 700, 700 - - yandex_image = torch.randn(batch_size, 3, height, width) - real_google_image = torch.randn(batch_size, 3, height, width) - - print(f"\nТестирование прямого прохода...") - with torch.no_grad(): - generated_image = gan(yandex_image) - - print(f"Размер сгенерированного изображения: {generated_image.shape}") - - print(f"\nТестирование шага генератора...") - gan.train_mode() - - # Тестируем шаг генератора - total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image) - - print(f"Общие потери генератора: {total_loss.item():.6f}") - print(f"Потери GAN: {gan_loss.item():.6f}") - print(f"Потери L1: {l1_loss.item():.6f}") - - print(f"\nТестирование шага дискриминатора...") - # Создаем сгенерированное изображение для дискриминатора - with torch.no_grad(): - fake_google_image = gan.generator(yandex_image) - - total_d_loss, real_loss, fake_loss = gan.discriminator_step( - yandex_image, real_google_image, fake_google_image - ) - - print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}") - print(f"Потери на реальных изображениях: {real_loss.item():.6f}") - print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}") - - print(f"\nТестирование режимов обучения/оценки...") - gan.eval_mode() - print(f"Генератор в режиме eval: {not gan.generator.training}") - print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}") - - gan.train_mode() - print(f"Генератор в режиме train: {gan.generator.training}") - print(f"Дискриминатор в режиме train: {gan.discriminator.training}") - - print("\n✓ Полная GAN модель работает корректно!") - return gan - - -def test_factory_function(): - """Тестирование фабричной функции""" - print("\n" + "=" * 60) - print("Тестирование фабричной функции...") - print("=" * 60) - - # Тестируем разные режимы GAN - for gan_mode in ["vanilla", "lsgan"]: - print(f"\nСоздание GAN в режиме '{gan_mode}'...") - gan = create_image_gan( - input_channels=3, - output_channels=3, - gan_mode=gan_mode, - lambda_L1=100.0, - use_cuda=False, - ) - - print(f" Режим GAN: {gan.gan_loss.gan_mode}") - print(f" Вес L1 потерь: {gan.lambda_L1}") - print(f" Устройство: {gan.device}") - - # Быстрая проверка прямого прохода - batch_size = 1 - yandex_image = torch.randn(batch_size, 3, 700, 700) - - with torch.no_grad(): - generated = gan(yandex_image) - - print(f" Размер выхода: {generated.shape}") - print(f" ✓ GAN в режиме '{gan_mode}' создан успешно") - - print("\n✓ Фабричная функция работает корректно!") - - -def test_weights_initialization(): - """Тестирование инициализации весов""" - print("\n" + "=" * 60) - print("Тестирование инициализации весов...") - print("=" * 60) - - # Создаем модели - generator = GeneratorUNet(3, 3) - discriminator = DiscriminatorPatchGAN(6) - - # Инициализируем веса - initialize_gan_weights(generator, discriminator) - - # Проверяем средние значения весов - def check_weights_mean(model, model_name): - conv_weights = [] - for name, param in model.named_parameters(): - if "weight" in name and ( - "conv" in name.lower() or "Conv" in str(param.__class__) - ): - conv_weights.append(param.data.mean().item()) - - if conv_weights: - avg_mean = sum(conv_weights) / len(conv_weights) - print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}") - # Проверяем, что веса инициализированы около 0 - assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0" - - check_weights_mean(generator, "генераторе") - check_weights_mean(discriminator, "дискриминаторе") - - print("✓ Инициализация весов работает корректно!") - - -def test_memory_usage(): - """Тестирование использования памяти""" - print("\n" + "=" * 60) - print("Тестирование использования памяти...") - print("=" * 60) - - import os - - import psutil - - # Получаем текущее использование памяти - process = psutil.Process(os.getpid()) - memory_before = process.memory_info().rss / 1024 / 1024 # в MB - - print(f"Память до создания моделей: {memory_before:.2f} MB") - - # Создаем несколько моделей - models = [] - for i in range(3): - gan = create_image_gan(use_cuda=False) - models.append(gan) - - # Делаем тестовый проход - batch_size = 1 - yandex_image = torch.randn(batch_size, 3, 700, 700) - real_google_image = torch.randn(batch_size, 3, 700, 700) - - with torch.no_grad(): - _ = gan(yandex_image) - _ = gan.generator_step(yandex_image, real_google_image) - - memory_after = process.memory_info().rss / 1024 / 1024 # в MB - memory_used = memory_after - memory_before - - print(f"Память после создания моделей: {memory_after:.2f} MB") - print(f"Использовано памяти: {memory_used:.2f} MB") - - # Очищаем модели - del models - import gc - - gc.collect() - - memory_final = process.memory_info().rss / 1024 / 1024 - print(f"Память после очистки: {memory_final:.2f} MB") - - print("✓ Тестирование памяти завершено!") - - -def main(): - """Основная функция тестирования""" - print("Начало тестирования GAN архитектуры для преобразования Yandex → Google") - print("Размер изображения: 700x700 пикселей") - print("=" * 60) - - try: - # Запускаем все тесты - test_generator() - test_discriminator() - test_gan_model() - test_factory_function() - test_weights_initialization() - test_memory_usage() - - print("\n" + "=" * 60) - print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉") - print("=" * 60) - print("\nАрхитектура GAN готова к использованию для преобразования") - print("изображений из стиля Yandex в стиль Google.") - print("\nОсновные характеристики:") - print(" • Генератор: U-Net архитектура") - print(" • Дискриминатор: PatchGAN (43x43 патчей)") - print(" • Размер входных/выходных изображений: 700x700") - print(" • Поддержка режимов: vanilla, lsgan") - print(" • L1 регуляризация для сохранения структуры") - - except Exception as e: - print(f"\n❌ Ошибка при тестировании: {e}") - import traceback - - traceback.print_exc() - return 1 - - return 0 - - -if __name__ == "__main__": - exit_code = main() - sys.exit(exit_code) diff --git a/models/GAN/test_trainer.py b/models/GAN/test_trainer.py deleted file mode 100644 index b14fac4..0000000 --- a/models/GAN/test_trainer.py +++ /dev/null @@ -1,342 +0,0 @@ -""" -Тестовый скрипт для проверки GAN trainer. -""" - -import sys -from pathlib import Path - -import numpy as np -import torch -import torch.nn as nn -from torch.utils.data import DataLoader, Dataset - -# Добавляем путь к модулям -sys.path.append(str(Path(__file__).parent.parent.parent)) - -from models.GAN.gan import create_image_gan -from models.GAN.trainer import GANTrainer - - -class SimpleDataset(Dataset): - """Простой датасет для тестирования.""" - - def __init__(self, num_samples=100, image_size=(256, 256)): - self.num_samples = num_samples - self.image_size = image_size - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - # Создаем случайные изображения для тестирования - yandex_img = torch.randn(3, self.image_size[0], self.image_size[1]) - google_img = torch.randn(3, self.image_size[0], self.image_size[1]) - - return {"yandex_img": yandex_img, "google_img": google_img} - - -def test_gan_model(): - """Тестирование GAN модели.""" - print("Тестирование GAN модели...") - - # Создаем модель - model = create_image_gan( - input_channels=3, - output_channels=3, - gan_mode="vanilla", - lambda_L1=100.0, - use_cuda=False, # Используем CPU для тестирования - ) - - # Тестируем forward pass - batch_size = 2 - image_size = (256, 256) - yandex_input = torch.randn(batch_size, 3, *image_size) - - with torch.no_grad(): - output = model(yandex_input) - - print(f"Входной размер: {yandex_input.shape}") - print(f"Выходной размер: {output.shape}") - print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]") - - # Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh) - assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]" - print("✓ Forward pass работает корректно") - - return model - - -def test_generator_step(): - """Тестирование шага генератора.""" - print("\nТестирование шага генератора...") - - model = create_image_gan(use_cuda=False) - model.train_mode() - - batch_size = 2 - yandex_img = torch.randn(batch_size, 3, 256, 256) - google_img = torch.randn(batch_size, 3, 256, 256) - - # Тестируем generator_step - total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img) - - print(f"Total loss: {total_loss.item():.6f}") - print(f"GAN loss: {gan_loss.item():.6f}") - print(f"L1 loss: {l1_loss.item():.6f}") - - assert total_loss.item() > 0, "Потери должны быть положительными" - print("✓ Шаг генератора работает корректно") - - -def test_discriminator_step(): - """Тестирование шага дискриминатора.""" - print("\nТестирование шага дискриминатора...") - - model = create_image_gan(use_cuda=False) - model.train_mode() - - batch_size = 2 - yandex_img = torch.randn(batch_size, 3, 256, 256) - google_img = torch.randn(batch_size, 3, 256, 256) - - # Генерируем fake изображение - with torch.no_grad(): - fake_google_img = model.generator(yandex_img) - - # Тестируем discriminator_step - total_loss, real_loss, fake_loss = model.discriminator_step( - yandex_img, google_img, fake_google_img - ) - - print(f"Total loss: {total_loss.item():.6f}") - print(f"Real loss: {real_loss.item():.6f}") - print(f"Fake loss: {fake_loss.item():.6f}") - - assert total_loss.item() > 0, "Потери должны быть положительными" - print("✓ Шаг дискриминатора работает корректно") - - -def test_trainer_initialization(): - """Тестирование инициализации тренера.""" - print("\nТестирование инициализации тренера...") - - # Создаем модель - model = create_image_gan(use_cuda=False) - - # Создаем даталоадеры - train_dataset = SimpleDataset(num_samples=50) - val_dataset = SimpleDataset(num_samples=10) - - train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) - - # Конфигурация - config = { - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - "output_dir": "test_runs/gan", - "early_stopping_patience": 10, - } - - # Создаем тренер - device = torch.device("cpu") - trainer = GANTrainer( - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, - ) - - print(f"Тренер создан успешно") - print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}") - print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}") - print(f"Выходная директория: {trainer.output_dir}") - - assert trainer.output_dir.exists(), "Выходная директория не создана" - print("✓ Тренер инициализирован корректно") - - return trainer, train_loader, val_loader - - -def test_train_epoch(): - """Тестирование одной эпохи обучения.""" - print("\nТестирование одной эпохи обучения...") - - model = create_image_gan(use_cuda=False) - train_dataset = SimpleDataset(num_samples=20) - train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True) - val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False) - - config = { - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - "output_dir": "test_runs/gan", - } - - device = torch.device("cpu") - trainer = GANTrainer( - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, - ) - - # Запускаем одну эпоху обучения - avg_g_loss, avg_d_loss = trainer.train_epoch() - - print(f"Средние потери за эпоху:") - print(f" Генератор: {avg_g_loss:.6f}") - print(f" Дискриминатор: {avg_d_loss:.6f}") - - assert avg_g_loss > 0, "Потери генератора должны быть положительными" - assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными" - print("✓ Эпоха обучения завершена успешно") - - -def test_validation(): - """Тестирование валидации.""" - print("\nТестирование валидации...") - - model = create_image_gan(use_cuda=False) - train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True) - val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False) - - config = { - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - "output_dir": "test_runs/gan", - } - - device = torch.device("cpu") - trainer = GANTrainer( - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, - ) - - # Запускаем валидацию - val_g_loss, val_d_loss = trainer.validate() - - print(f"Потери на валидации:") - print(f" Генератор: {val_g_loss:.6f}") - print(f" Дискриминатор: {val_d_loss:.6f}") - - assert val_g_loss > 0, "Потери генератора должны быть положительными" - assert val_d_loss > 0, "Потери дискриминатора должны быть положительными" - print("✓ Валидация завершена успешно") - - -def test_checkpoint_saving(): - """Тестирование сохранения чекпоинтов.""" - print("\nТестирование сохранения чекпоинтов...") - - model = create_image_gan(use_cuda=False) - train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True) - val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False) - - config = { - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - "output_dir": "test_runs/gan_checkpoint", - } - - device = torch.device("cpu") - trainer = GANTrainer( - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, - ) - - # Сохраняем чекпоинт - trainer.save_checkpoint(is_best=True) - - # Проверяем, что файлы созданы - checkpoint_files = list(trainer.output_dir.glob("*.pth")) - print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}") - - for file in checkpoint_files: - print(f" - {file.name}") - - assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы" - print("✓ Чекпоинты сохранены успешно") - - # Тестируем загрузку чекпоинта - checkpoint_path = checkpoint_files[0] - print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}") - - # Создаем новую модель и тренер - new_model = create_image_gan(use_cuda=False) - new_trainer = GANTrainer( - model=new_model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, - ) - - # Загружаем чекпоинт - new_trainer.load_checkpoint(str(checkpoint_path)) - - print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}") - print("✓ Чекпоинт загружен успешно") - - -def main(): - """Основная функция тестирования.""" - print("=" * 60) - print("Начало тестирования GAN trainer") - print("=" * 60) - - try: - # Запускаем все тесты - test_gan_model() - test_generator_step() - test_discriminator_step() - test_trainer_initialization() - test_train_epoch() - test_validation() - test_checkpoint_saving() - - print("\n" + "=" * 60) - print("Все тесты пройдены успешно! ✓") - print("=" * 60) - - except Exception as e: - print(f"\nОшибка при тестировании: {e}") - import traceback - - traceback.print_exc() - return 1 - - return 0 - - -if __name__ == "__main__": - # Очищаем тестовые директории - import shutil - - test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"] - for dir_path in test_dirs: - if Path(dir_path).exists(): - shutil.rmtree(dir_path) - - # Запускаем тесты - exit_code = main() - - # Очищаем после тестов - for dir_path in test_dirs: - if Path(dir_path).exists(): - shutil.rmtree(dir_path) - - exit(exit_code) diff --git a/models/GAN/train_example.py b/models/GAN/train_example.py deleted file mode 100644 index c3efdc3..0000000 --- a/models/GAN/train_example.py +++ /dev/null @@ -1,347 +0,0 @@ -""" -Пример обучения GAN модели для преобразования Yandex → Google карт. - -Этот скрипт показывает, как использовать GANTrainer для обучения модели. -""" - -import sys -from pathlib import Path - -import torch -from torch.utils.data import DataLoader - -# Добавляем путь к модулям -sys.path.append(str(Path(__file__).parent.parent.parent)) - -from models.GAN.gan import create_image_gan -from models.GAN.trainer import GANTrainer - - -def create_simple_config(): - """Создает простую конфигурацию для обучения.""" - config = { - # Параметры оптимизатора - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - # Параметры обучения - "batch_size": 4, - "epochs": 100, - # Параметры GAN - "gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp" - "lambda_L1": 100.0, # Вес L1 потерь - # Регуляризация - "grad_clip": 1.0, - # Ранняя остановка - "early_stopping_patience": 20, - # Выходные данные - "output_dir": "runs/gan_training", - # Логирование - "log_interval": 10, # Логировать каждые N батчей - "save_interval": 5, # Сохранять чекпоинт каждые N эпох - } - return config - - -def create_advanced_config(): - """Создает расширенную конфигурацию для обучения.""" - config = { - # Параметры оптимизатора - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - # Планировщик learning rate - "use_scheduler": True, - "scheduler_type": "linear", # "linear", "cosine", или "plateau" - "scheduler_start_epoch": 50, - "scheduler_end_epoch": 100, - # Параметры обучения - "batch_size": 8, - "epochs": 200, - # Параметры GAN - "gan_mode": "lsgan", # LSGAN обычно более стабилен - "lambda_L1": 100.0, - # Аугментация данных - "augmentation": { - "random_crop": True, - "crop_size": 256, - "random_flip": True, - "color_jitter": True, - "brightness": 0.2, - "contrast": 0.2, - "saturation": 0.2, - "hue": 0.1, - }, - # Регуляризация - "grad_clip": 1.0, - "weight_decay": 1e-4, - # Ранняя остановка - "early_stopping_patience": 30, - "early_stopping_min_delta": 1e-4, - # Выходные данные - "output_dir": "runs/gan_advanced", - # Логирование - "log_interval": 20, - "save_interval": 10, - "save_best_only": True, # Сохранять только лучшую модель - # Визуализация - "visualize_samples": True, - "num_visualize": 4, - "visualize_interval": 5, # Визуализировать каждые N эпох - } - return config - - -def print_config_summary(config): - """Печатает сводку конфигурации.""" - print("=" * 60) - print("Конфигурация обучения GAN") - print("=" * 60) - - print(f"\nПараметры модели:") - print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}") - print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}") - - print(f"\nПараметры обучения:") - print(f" Learning rate: {config.get('learning_rate', 2e-4)}") - print(f" Batch size: {config.get('batch_size', 4)}") - print(f" Эпох: {config.get('epochs', 100)}") - print(f" Beta1: {config.get('beta1', 0.5)}") - print(f" Beta2: {config.get('beta2', 0.999)}") - - if config.get("use_scheduler", False): - print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}") - - print(f"\nРегуляризация:") - print(f" Gradient clipping: {config.get('grad_clip', 1.0)}") - if "weight_decay" in config: - print(f" Weight decay: {config['weight_decay']}") - - print(f"\nРанняя остановка:") - if config.get("early_stopping_patience", 0) > 0: - print(f" Patience: {config['early_stopping_patience']} эпох") - if "early_stopping_min_delta" in config: - print(f" Min delta: {config['early_stopping_min_delta']}") - - print(f"\nВыходные данные:") - print(f" Директория: {config.get('output_dir', 'runs/gan')}") - print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох") - - print(f"\nЛогирование:") - print(f" Интервал логирования: {config.get('log_interval', 10)} батчей") - - print("=" * 60) - - -def setup_training(): - """Настраивает обучение.""" - print("Настройка обучения GAN...") - - # Выбираем конфигурацию - use_advanced = False # Измените на True для расширенной конфигурации - - if use_advanced: - config = create_advanced_config() - else: - config = create_simple_config() - - # Печатаем сводку конфигурации - print_config_summary(config) - - # Устройство - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"\nИспользуемое устройство: {device}") - - if device.type == "cuda": - print(f" GPU: {torch.cuda.get_device_name(0)}") - print( - f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB" - ) - - # Создаем модель - print("\nСоздание модели...") - model = create_image_gan( - input_channels=3, - output_channels=3, - gan_mode=config.get("gan_mode", "vanilla"), - lambda_L1=config.get("lambda_L1", 100.0), - use_cuda=(device.type == "cuda"), - ) - - # Создаем даталоадеры - print("\nСоздание даталоадеров...") - # ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ - # Пример: - # from your_dataset_module import create_data_loaders - # train_loader, val_loader = create_data_loaders( - # data_dir="ваш/путь/к/данным", - # batch_size=config["batch_size"], - # image_size=(256, 256), - # augment=config.get("augmentation", None), - # ) - - # Для примера создаем фиктивные даталоадеры - # ВАЖНО: Замените это на реальные данные! - print(" ВНИМАНИЕ: Используются фиктивные данные!") - print(" Замените на реальные даталоадеры!") - - import numpy as np - from torch.utils.data import Dataset - - class DummyDataset(Dataset): - def __init__(self, num_samples=100): - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - # Фиктивные данные для примера - yandex_img = torch.randn(3, 256, 256) - google_img = torch.randn(3, 256, 256) - return {"yandex_img": yandex_img, "google_img": google_img} - - train_dataset = DummyDataset(num_samples=100) - val_dataset = DummyDataset(num_samples=20) - - train_loader = DataLoader( - train_dataset, - batch_size=config.get("batch_size", 4), - shuffle=True, - num_workers=0, - ) - - val_loader = DataLoader( - val_dataset, - batch_size=config.get("batch_size", 4), - shuffle=False, - num_workers=0, - ) - - print(f" Размер обучающего набора: {len(train_dataset)}") - print(f" Размер валидационного набора: {len(val_dataset)}") - print(f" Батчей в эпохе: {len(train_loader)}") - - # Создаем тренер - print("\nСоздание тренера...") - trainer = GANTrainer( - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, - ) - - return trainer, config - - -def train_model(trainer, config): - """Запускает обучение модели.""" - print("\n" + "=" * 60) - print("Начало обучения") - print("=" * 60) - - epochs = config.get("epochs", 100) - - try: - trainer.train(num_epochs=epochs) - - print("\n" + "=" * 60) - print("Обучение завершено успешно!") - print("=" * 60) - - except KeyboardInterrupt: - print("\n\nОбучение прервано пользователем.") - print("Сохранение текущего состояния...") - trainer.save_checkpoint(is_best=False) - - except Exception as e: - print(f"\n\nОшибка при обучении: {e}") - import traceback - - traceback.print_exc() - - # Пытаемся сохранить чекпоинт при ошибке - try: - trainer.save_checkpoint(is_best=False) - print("Текущее состояние сохранено.") - except: - print("Не удалось сохранить состояние.") - - -def evaluate_model(trainer, test_loader=None): - """Оценивает обученную модель.""" - print("\n" + "=" * 60) - print("Оценка модели") - print("=" * 60) - - if test_loader is None: - print("Тестовый даталоадер не предоставлен.") - print("Используется валидационный даталоадер для оценки.") - test_loader = trainer.val_loader - - metrics = trainer.evaluate(test_loader) - - print("\nМетрики оценки:") - for key, value in metrics.items(): - print(f" {key}: {value:.6f}") - - return metrics - - -def generate_examples(model, device, num_examples=4): - """Генерирует примеры преобразования.""" - print("\n" + "=" * 60) - print("Генерация примеров") - print("=" * 60) - - model.eval() - - # Создаем фиктивные входные данные - yandex_input = torch.randn(num_examples, 3, 256, 256).to(device) - - with torch.no_grad(): - google_output = model(yandex_input) - - print(f"Сгенерировано {num_examples} примеров") - print(f"Размер входных данных: {yandex_input.shape}") - print(f"Размер выходных данных: {google_output.shape}") - - # Сохраняем примеры (в реальном коде сохраняйте как изображения) - print("\nПримеры сгенерированы.") - print("В реальном коде сохраняйте их как изображения для визуализации.") - - return yandex_input, google_output - - -def main(): - """Основная функция.""" - print("=" * 60) - print("Пример обучения GAN для преобразования Yandex → Google") - print("=" * 60) - - # Настройка - trainer, config = setup_training() - - # Обучение - train_model(trainer, config) - - # Оценка (требует реальных тестовых данных) - # evaluate_model(trainer) - - # Генерация примеров - # generate_examples(trainer.model, trainer.device) - - print("\n" + "=" * 60) - print("Скрипт завершен.") - print("=" * 60) - print("\nСледующие шаги:") - print("1. Замените фиктивные даталоадеры на реальные данные") - print("2. Настройте параметры в create_simple_config()") - print("3. Запустите обучение с реальными данными") - print("4. Визуализируйте результаты") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/models/GAN/trainer.py b/models/GAN/trainer.py index a51a5fb..77fb0b0 100644 --- a/models/GAN/trainer.py +++ b/models/GAN/trainer.py @@ -1,415 +1,191 @@ -import json -import time -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import numpy as np -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm - -# Type aliases -ModuleType = nn.Module -TensorType = torch.Tensor - - -class GANTrainer: - """Trainer class for GAN model.""" - - def __init__( - self, - model: ModuleType, - train_loader: DataLoader, - val_loader: DataLoader, - device: torch.device, - config: Dict[str, Any], - ): - """ - Initialize the GAN trainer. - - Args: - model: GAN model (ImageGAN) - train_loader: Training data loader - val_loader: Validation data loader - device: Device to run training on - config: Training configuration dictionary - """ - self.model = model.to(device) - self.train_loader = train_loader - self.val_loader = val_loader - self.device = device - self.config = config - - # Optimizers - lr = config.get("learning_rate", 2e-4) - beta1 = config.get("beta1", 0.5) - beta2 = config.get("beta2", 0.999) - - # Separate optimizers for generator and discriminator - # Note: self.model is expected to have .generator and .discriminator attributes - self.optimizer_G = optim.Adam( - self.model.generator.parameters(), lr=lr, betas=(beta1, beta2) - ) - self.optimizer_D = optim.Adam( - self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2) - ) - - # Training state - self.current_epoch = 0 - self.best_val_loss = float("inf") - self.g_losses: List[float] = [] - self.d_losses: List[float] = [] - self.val_g_losses: List[float] = [] - self.val_d_losses: List[float] = [] - - # Create output directory - self.output_dir = Path(config.get("output_dir", "runs/gan")) - self.output_dir.mkdir(parents=True, exist_ok=True) - - # TensorBoard writer - self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard") - - # Save configuration - config_path = self.output_dir / "config.json" - with open(config_path, "w") as f: - json.dump(config, f, indent=2) - - print(f"Training configuration saved to {config_path}") - # Access parameters through the model's generator and discriminator - generator_params = sum(p.numel() for p in self.model.generator.parameters()) - discriminator_params = sum( - p.numel() for p in self.model.discriminator.parameters() - ) - - print(f"Generator has {generator_params:,} parameters") - print(f"Discriminator has {discriminator_params:,} parameters") - - def train_epoch(self) -> Tuple[float, float]: - """ - Train for one epoch. - - Returns: - Tuple of (average generator loss, average discriminator loss) - """ - self.model.train() - total_g_loss = 0.0 - total_d_loss = 0.0 - num_batches = len(self.train_loader) - - progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") - for batch_idx, batch in enumerate(progress_bar): - # Move data to device - yandex_img = batch["yandex_img"].to(self.device) - google_img = batch["google_img"].to(self.device) - - # ========== Train Discriminator ========== - self.optimizer_D.zero_grad() - - # Generate fake image - with torch.no_grad(): - fake_google_img = self.model.generator(yandex_img) - - # Discriminator loss - returns tuple of tensors - d_loss_tuple = self.model.discriminator_step( - yandex_img, google_img, fake_google_img - ) - d_loss, d_real_loss, d_fake_loss = d_loss_tuple - - # Backward and optimize discriminator - d_loss.backward() - self.optimizer_D.step() - - # ========== Train Generator ========== - self.optimizer_G.zero_grad() - - # Generate fake image - fake_google_img = self.model.generator(yandex_img) - - # Generator loss - returns tuple of tensors - g_loss_tuple = self.model.generator_step(yandex_img, google_img) - g_loss, g_gan_loss, g_l1_loss = g_loss_tuple - - # Backward and optimize generator - g_loss.backward() - self.optimizer_G.step() - - # Update statistics - total_g_loss += g_loss.item() - total_d_loss += d_loss.item() - - # Update progress bar - progress_bar.set_postfix( - { - "g_loss": g_loss.item(), - "d_loss": d_loss.item(), - "g_l1": g_l1_loss.item(), - "d_real": d_real_loss.item(), - "d_fake": d_fake_loss.item(), - } - ) - - # Log batch losses to TensorBoard - global_step = self.current_epoch * num_batches + batch_idx - self.writer.add_scalar("train/batch_g_loss", g_loss.item(), global_step) - self.writer.add_scalar("train/batch_d_loss", d_loss.item(), global_step) - self.writer.add_scalar( - "train/batch_g_l1_loss", g_l1_loss.item(), global_step - ) - self.writer.add_scalar( - "train/batch_d_real_loss", d_real_loss.item(), global_step - ) - self.writer.add_scalar( - "train/batch_d_fake_loss", d_fake_loss.item(), global_step - ) - - avg_g_loss = total_g_loss / num_batches - avg_d_loss = total_d_loss / num_batches - self.g_losses.append(avg_g_loss) - self.d_losses.append(avg_d_loss) - - return avg_g_loss, avg_d_loss - - def validate(self) -> Tuple[float, float]: - """ - Validate the model. - - Returns: - Tuple of (average generator validation loss, average discriminator validation loss) - """ - self.model.eval() - total_g_loss = 0.0 - total_d_loss = 0.0 - - progress_bar = tqdm(self.val_loader, desc="Validation") - for batch in progress_bar: - # Move data to device - yandex_img = batch["yandex_img"].to(self.device) - google_img = batch["google_img"].to(self.device) - - with torch.no_grad(): - # Generate fake image - fake_google_img = self.model.generator(yandex_img) - - # Generator loss - returns tuple of tensors - g_loss_tuple = self.model.generator_step(yandex_img, google_img) - g_loss, g_gan_loss, g_l1_loss = g_loss_tuple - - # Discriminator loss - returns tuple of tensors - d_loss_tuple = self.model.discriminator_step( - yandex_img, google_img, fake_google_img - ) - d_loss, d_real_loss, d_fake_loss = d_loss_tuple - - # Update statistics - total_g_loss += g_loss.item() - total_d_loss += d_loss.item() - - # Update progress bar - progress_bar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()}) - - avg_g_loss = total_g_loss / len(self.val_loader) - avg_d_loss = total_d_loss / len(self.val_loader) - self.val_g_losses.append(avg_g_loss) - self.val_d_losses.append(avg_d_loss) - - return avg_g_loss, avg_d_loss - - def save_checkpoint(self, is_best: bool = False): - """ - Save training checkpoint. - - Args: - is_best: Whether this is the best model so far - """ - checkpoint = { - "epoch": self.current_epoch, - "model_state_dict": self.model.state_dict(), - "optimizer_G_state_dict": self.optimizer_G.state_dict(), - "optimizer_D_state_dict": self.optimizer_D.state_dict(), - "g_losses": self.g_losses, - "d_losses": self.d_losses, - "val_g_losses": self.val_g_losses, - "val_d_losses": self.val_d_losses, - "best_val_loss": self.best_val_loss, - "config": self.config, - } - - # Save regular checkpoint - checkpoint_path = ( - self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth" - ) - torch.save(checkpoint, checkpoint_path) - - # Save best model - if is_best: - best_path = self.output_dir / "model_best.pth" - torch.save(checkpoint, best_path) - print(f"Best model saved to {best_path}") - - def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False): - """ - Load training checkpoint. - - Args: - checkpoint_path: Path to checkpoint file - resume_training: Если True, продолжить обучение с сохраненной эпохи - """ - checkpoint = torch.load(checkpoint_path, map_location=self.device) - - self.model.load_state_dict(checkpoint["model_state_dict"]) - self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"]) - self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"]) - - self.current_epoch = checkpoint["epoch"] - self.g_losses = checkpoint["g_losses"] - self.d_losses = checkpoint["d_losses"] - self.val_g_losses = checkpoint["val_g_losses"] - self.val_d_losses = checkpoint["val_d_losses"] - self.best_val_loss = checkpoint["best_val_loss"] - - if resume_training: - print(f"Resuming training from epoch {self.current_epoch + 1}") - else: - print(f"Loaded checkpoint from epoch {self.current_epoch + 1}") - - def train(self, num_epochs: int, start_epoch: int = 0): - """ - Train the model for specified number of epochs. - - Args: - num_epochs: Number of epochs to train - start_epoch: Starting epoch (useful when resuming training) - """ - print( - f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..." - ) - start_time = time.time() - - for epoch in range(start_epoch, start_epoch + num_epochs): - self.current_epoch = epoch - - # Train for one epoch - train_g_loss, train_d_loss = self.train_epoch() - - # Validate - val_g_loss, val_d_loss = self.validate() - - # Log to TensorBoard - self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch) - self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch) - self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch) - self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch) - - # Print epoch summary - print(f"\nEpoch {epoch + 1}/{num_epochs}:") - print(" Generator:") - print(f" Train Loss: {train_g_loss:.6f}") - print(f" Val Loss: {val_g_loss:.6f}") - print(" Discriminator:") - print(f" Train Loss: {train_d_loss:.6f}") - print(f" Val Loss: {val_d_loss:.6f}") - - # Save checkpoint - val_total_loss = val_g_loss + val_d_loss - is_best = val_total_loss < self.best_val_loss - if is_best: - self.best_val_loss = val_total_loss - - self.save_checkpoint(is_best=is_best) - - # Early stopping - if self.config.get("early_stopping_patience", 0) > 0: - val_losses = [ - g + d for g, d in zip(self.val_g_losses, self.val_d_losses) - ] - if ( - epoch - np.argmin(val_losses) - >= self.config["early_stopping_patience"] - ): - print(f"Early stopping at epoch {epoch + 1}") - break - - # Training completed - training_time = time.time() - start_time - print(f"\nTraining completed in {training_time:.2f} seconds") - print(f"Best validation total loss: {self.best_val_loss:.6f}") - - # Save final model - final_model_path = self.output_dir / "model_final.pth" - torch.save(self.model.state_dict(), final_model_path) - print(f"Final model saved to {final_model_path}") - - # Save training history - history_path = self.output_dir / "training_history.json" - history = { - "g_losses": self.g_losses, - "d_losses": self.d_losses, - "val_g_losses": self.val_g_losses, - "val_d_losses": self.val_d_losses, - "best_val_loss": self.best_val_loss, - "total_epochs": self.current_epoch + 1, - } - with open(history_path, "w") as f: - json.dump(history, f, indent=2) - print(f"Training history saved to {history_path}") - - # Close TensorBoard writer - self.writer.close() - - def evaluate(self, test_loader: DataLoader) -> Dict: - """ - Evaluate the model on test data. - - Args: - test_loader: Test data loader - - Returns: - Dictionary with evaluation metrics - """ - self.model.eval() - total_g_loss = 0.0 - total_d_loss = 0.0 - - print("Evaluating model on test set...") - - for batch in tqdm(test_loader, desc="Evaluation"): - # Move data to device - yandex_img = batch["yandex_img"].to(self.device) - google_img = batch["google_img"].to(self.device) - - with torch.no_grad(): - # Generate fake image - fake_google_img = self.model.generator(yandex_img) - - # Generator loss - returns tuple of tensors - g_loss_tuple = self.model.generator_step(yandex_img, google_img) - g_loss, g_gan_loss, g_l1_loss = g_loss_tuple - - # Discriminator loss - returns tuple of tensors - d_loss_tuple = self.model.discriminator_step( - yandex_img, google_img, fake_google_img - ) - d_loss, d_real_loss, d_fake_loss = d_loss_tuple - - # Update statistics - total_g_loss += g_loss.item() - total_d_loss += d_loss.item() - - avg_g_loss = total_g_loss / len(test_loader) - avg_d_loss = total_d_loss / len(test_loader) - - metrics = { - "generator_loss": avg_g_loss, - "discriminator_loss": avg_d_loss, - "total_loss": avg_g_loss + avg_d_loss, - } - - print("\nTest Results:") - print(f" Generator Loss: {avg_g_loss:.6f}") - print(f" Discriminator Loss: {avg_d_loss:.6f}") - print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}") - - return metrics +"""Trainer for GAN model.""" + +import json +import time +from pathlib import Path +from typing import Any, Dict, Tuple + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + + +class GANTrainer: + """Simple GAN trainer.""" + + def __init__( + self, + model: torch.nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + config: Dict[str, Any], + ): + self.model = model + self.train_loader = train_loader + self.val_loader = val_loader + self.config = config + self.device = model.device + + # Optimizers + lr = config.get("learning_rate", 2e-4) + beta1 = config.get("beta1", 0.5) + beta2 = config.get("beta2", 0.999) + self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2)) + self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)) + + # Training state + self.current_epoch = 0 + self.best_val_loss = float("inf") + self.g_losses = [] + self.d_losses = [] + self.val_g_losses = [] + self.val_d_losses = [] + + # Output dir + self.output_dir = Path(config.get("output_dir", "runs/gan")) + self.output_dir.mkdir(parents=True, exist_ok=True) + (self.output_dir / "checkpoints").mkdir(exist_ok=True) + + # Save config + with open(self.output_dir / "config.json", "w") as f: + json.dump(config, f, indent=2) + + def train_epoch(self) -> Tuple[float, float]: + """Train for one epoch.""" + self.model.train() + total_g = total_d = 0.0 + num_batches = len(self.train_loader) + + pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") + for batch in pbar: + yandex_img = batch["yandex_img"].to(self.device) + google_img = batch["google_img"].to(self.device) + + # Train D + self.opt_D.zero_grad() + with torch.no_grad(): + fake_img = self.model.generator(yandex_img) + d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0] + d_loss.backward() + self.opt_D.step() + + # Train G + self.opt_G.zero_grad() + g_loss = self.model.generator_step(yandex_img, google_img)[0] + g_loss.backward() + self.opt_G.step() + + total_g += g_loss.item() + total_d += d_loss.item() + pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()}) + + avg_g = total_g / num_batches + avg_d = total_d / num_batches + self.g_losses.append(avg_g) + self.d_losses.append(avg_d) + return avg_g, avg_d + + @torch.no_grad() + def validate(self) -> Tuple[float, float]: + """Validate the model.""" + self.model.eval() + total_g = total_d = 0.0 + + for batch in tqdm(self.val_loader, desc="Val"): + yandex_img = batch["yandex_img"].to(self.device) + google_img = batch["google_img"].to(self.device) + fake_img = self.model.generator(yandex_img) + g_loss = self.model.generator_step(yandex_img, google_img)[0] + d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0] + total_g += g_loss.item() + total_d += d_loss.item() + + avg_g = total_g / len(self.val_loader) + avg_d = total_d / len(self.val_loader) + self.val_g_losses.append(avg_g) + self.val_d_losses.append(avg_d) + return avg_g, avg_d + + def train(self, num_epochs: int): + """Train the model.""" + print(f"Training for {num_epochs} epochs...") + + for epoch in range(num_epochs): + self.current_epoch = epoch + + # Train & validate + train_g, train_d = self.train_epoch() + val_g, val_d = self.validate() + + # Save best checkpoint + val_total = val_g + val_d + if val_total < self.best_val_loss: + self.best_val_loss = val_total + self.save_checkpoint("best") + + # Periodic checkpoint + if (epoch + 1) % self.config.get("save_interval", 5) == 0: + self.save_checkpoint(f"epoch_{epoch + 1}") + + print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}") + + # Early stopping + patience = self.config.get("early_stopping_patience", 0) + if patience > 0 and len(self.val_g_losses) > patience: + recent = self.val_g_losses[-patience:] + if all(l >= min(self.val_g_losses[:-patience]) for l in recent): + print(f"Early stopping at epoch {epoch + 1}") + break + + # Save final + self.save_checkpoint("final") + print(f"Training finished. Best val loss: {self.best_val_loss:.4f}") + + def save_checkpoint(self, name: str): + """Save model checkpoint.""" + path = self.output_dir / "checkpoints" / f"{name}.pth" + torch.save({ + "epoch": self.current_epoch, + "generator": self.model.generator.state_dict(), + "discriminator": self.model.discriminator.state_dict(), + "opt_G": self.opt_G.state_dict(), + "opt_D": self.opt_D.state_dict(), + }, path) + + +def create_trainer( + model: torch.nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + config: Dict[str, Any], +) -> GANTrainer: + """Create a trainer instance.""" + return GANTrainer(model, train_loader, val_loader, config) + + +if __name__ == "__main__": + # Quick test + from config import create_config + from dataloader import create_data_loaders + from model import create_gan + + config = create_config() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model = create_gan(use_cuda=False) + train_loader, val_loader = create_data_loaders( + root_dir=config["data_dir"], + batch_size=4, + image_size=tuple(config["image_size"]), + num_workers=0, + ) + + trainer = create_trainer(model, train_loader, val_loader, config) + + # Test one training step (just to verify no errors) + print("Testing one training step...") + try: + g_loss, d_loss = trainer.train_epoch() + print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}") + except Exception as e: + print(f"Training step failed: {e}") \ No newline at end of file