Files
autopilot/models/GAN
2026-04-04 17:50:10 +03:00
..
2026-02-20 16:52:02 +03:00
2026-04-04 17:50:10 +03:00
2026-04-04 17:50:10 +03:00
2026-02-20 16:52:02 +03:00

GAN Trainer для преобразования изображений Yandex → Google

Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.

Структура проекта

autopilot/models/GAN/
├── gan.py              # Основная реализация GAN модели
├── trainer.py          # Тренер для обучения GAN
├── test_trainer.py     # Тесты для тренера
├── train_example.py    # Пример использования тренера
└── README.md           # Этот файл

Модель GAN

Модель состоит из двух основных компонентов:

1. Генератор (GeneratorUNet)

  • Архитектура U-Net для преобразования изображений
  • Принимает изображение Yandex (3 канала RGB)
  • Возвращает изображение в стиле Google (3 канала RGB)
  • Использует skip connections для сохранения деталей

2. Дискриминатор (DiscriminatorPatchGAN)

  • PatchGAN архитектура
  • Принимает пару изображений (Yandex + Google)
  • Возвращает вероятность того, что пара реальная
  • Работает с патчами изображения 41x41

Функция потерь (GANLoss)

Поддерживает три режима:

  • vanilla: Бинарная кросс-энтропия
  • lsgan: Least Squares GAN (более стабильный)
  • wgangp: Wasserstein GAN with Gradient Penalty

Тренер (GANTrainer)

Основные возможности

  1. Обучение с чередованием:

    • Обучение генератора и дискриминатора поочередно
    • Поддержка L1 потерь для сохранения структуры
  2. Валидация и мониторинг:

    • Отдельные потери для генератора и дискриминатора
    • Логирование в TensorBoard
    • Ранняя остановка
  3. Сохранение и загрузка:

    • Чекпоинты каждой эпохи
    • Лучшая модель
    • Финальная модель
    • История обучения
  4. Оценка модели:

    • Метрики на тестовом наборе
    • Генерация примеров

Быстрый старт

import torch
from torch.utils.data import DataLoader
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer

# Конфигурация
config = {
    "learning_rate": 2e-4,
    "beta1": 0.5,
    "beta2": 0.999,
    "batch_size": 4,
    "output_dir": "runs/gan_training",
    "gan_mode": "vanilla",
    "lambda_L1": 100.0,
    "early_stopping_patience": 20,
}

# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Создание модели
model = create_image_gan(
    input_channels=3,
    output_channels=3,
    gan_mode=config["gan_mode"],
    lambda_L1=config["lambda_L1"],
    use_cuda=(device.type == "cuda"),
)

# Создание даталоадеров (замените на свои данные)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)

# Создание тренера
trainer = GANTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=config,
)

# Обучение
trainer.train(num_epochs=100)

# Оценка
metrics = trainer.evaluate(test_loader)

Конфигурация обучения

Базовая конфигурация

config = {
    # Параметры оптимизатора
    "learning_rate": 2e-4,      # Learning rate
    "beta1": 0.5,               # Adam beta1
    "beta2": 0.999,             # Adam beta2
    
    # Параметры обучения
    "batch_size": 4,            # Размер батча
    "epochs": 100,              # Количество эпох
    
    # Параметры GAN
    "gan_mode": "vanilla",      # Режим GAN
    "lambda_L1": 100.0,         # Вес L1 потерь
    
    # Регуляризация
    "grad_clip": 1.0,           # Gradient clipping
    
    # Ранняя остановка
    "early_stopping_patience": 20,
    
    # Выходные данные
    "output_dir": "runs/gan",
}

Расширенная конфигурация

config = {
    "learning_rate": 2e-4,
    "beta1": 0.5,
    "beta2": 0.999,
    "batch_size": 8,
    "epochs": 200,
    "gan_mode": "lsgan",        # Более стабильный LSGAN
    "lambda_L1": 100.0,
    "grad_clip": 1.0,
    "weight_decay": 1e-4,       # Weight decay
    "early_stopping_patience": 30,
    "early_stopping_min_delta": 1e-4,
    "output_dir": "runs/gan_advanced",
}

Методы тренера

Основные методы

  • train_epoch(): Обучение на одной эпохе
  • validate(): Валидация модели
  • train(num_epochs): Полное обучение
  • evaluate(test_loader): Оценка на тестовых данных

Управление чекпоинтами

  • save_checkpoint(is_best=False): Сохранение чекпоинта
  • load_checkpoint(path, resume_training=False): Загрузка чекпоинта

Выходные файлы

После обучения создаются следующие файлы:

runs/gan_training/
├── config.json                 # Конфигурация обучения
├── training_history.json       # История потерь
├── model_best.pth             # Лучшая модель
├── model_final.pth            # Финальная модель
├── checkpoint_epoch_1.pth     # Чекпоинты каждой эпохи
├── checkpoint_epoch_2.pth
├── ...
└── tensorboard/               # Логи TensorBoard
    ├── events.out.tfevents...
    └── ...

TensorBoard

Для визуализации обучения используйте TensorBoard:

tensorboard --logdir runs/gan_training/tensorboard

Доступные метрики:

  • train/batch_g_loss: Потери генератора на батче
  • train/batch_d_loss: Потери дискриминатора на батче
  • train/batch_g_l1_loss: L1 потери генератора
  • train/epoch_g_loss: Потери генератора на эпохе
  • train/epoch_d_loss: Потери дискриминатора на эпохе
  • val/epoch_g_loss: Валидационные потери генератора
  • val/epoch_d_loss: Валидационные потери дискриминатора

Тестирование

Запустите тесты для проверки работоспособности:

python models/GAN/test_trainer.py

Пример использования

Полный пример использования смотрите в train_example.py.

Советы по обучению

  1. Начальные значения:

    • Используйте gan_mode="lsgan" для более стабильного обучения
    • Начните с lambda_L1=100.0 и регулируйте по необходимости
    • Используйте маленький batch_size (4-8) при ограниченной памяти GPU
  2. Мониторинг:

    • Следите за балансом потерь генератора и дискриминатора
    • Если потери дискриминатора близки к 0, генератор не обучается
    • Если потери генератора слишком высоки, уменьшите lambda_L1
  3. Визуализация:

    • Регулярно генерируйте примеры для визуальной оценки
    • Используйте TensorBoard для отслеживания прогресса

Устранение проблем

Высокие потери генератора

  • Уменьшите lambda_L1
  • Увеличьте learning rate
  • Проверьте качество данных

Дискриминатор слишком сильный

  • Уменьшите learning rate дискриминатора
  • Добавьте dropout в дискриминатор
  • Обучайте генератор чаще, чем дискриминатор

Недостаток памяти GPU

  • Уменьшите batch_size
  • Уменьшите размер изображений
  • Используйте gradient accumulation

Лицензия

Этот проект является частью Autopilot системы.