GAN Trainer для преобразования изображений Yandex → Google
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
Структура проекта
autopilot/models/GAN/
├── gan.py # Основная реализация GAN модели
├── trainer.py # Тренер для обучения GAN
├── test_trainer.py # Тесты для тренера
├── train_example.py # Пример использования тренера
└── README.md # Этот файл
Модель GAN
Модель состоит из двух основных компонентов:
1. Генератор (GeneratorUNet)
- Архитектура U-Net для преобразования изображений
- Принимает изображение Yandex (3 канала RGB)
- Возвращает изображение в стиле Google (3 канала RGB)
- Использует skip connections для сохранения деталей
2. Дискриминатор (DiscriminatorPatchGAN)
- PatchGAN архитектура
- Принимает пару изображений (Yandex + Google)
- Возвращает вероятность того, что пара реальная
- Работает с патчами изображения 41x41
Функция потерь (GANLoss)
Поддерживает три режима:
vanilla: Бинарная кросс-энтропияlsgan: Least Squares GAN (более стабильный)wgangp: Wasserstein GAN with Gradient Penalty
Тренер (GANTrainer)
Основные возможности
-
Обучение с чередованием:
- Обучение генератора и дискриминатора поочередно
- Поддержка L1 потерь для сохранения структуры
-
Валидация и мониторинг:
- Отдельные потери для генератора и дискриминатора
- Логирование в TensorBoard
- Ранняя остановка
-
Сохранение и загрузка:
- Чекпоинты каждой эпохи
- Лучшая модель
- Финальная модель
- История обучения
-
Оценка модели:
- Метрики на тестовом наборе
- Генерация примеров
Быстрый старт
import torch
from torch.utils.data import DataLoader
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 4,
"output_dir": "runs/gan_training",
"gan_mode": "vanilla",
"lambda_L1": 100.0,
"early_stopping_patience": 20,
}
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Создание модели
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config["gan_mode"],
lambda_L1=config["lambda_L1"],
use_cuda=(device.type == "cuda"),
)
# Создание даталоадеров (замените на свои данные)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
# Создание тренера
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Обучение
trainer.train(num_epochs=100)
# Оценка
metrics = trainer.evaluate(test_loader)
Конфигурация обучения
Базовая конфигурация
config = {
# Параметры оптимизатора
"learning_rate": 2e-4, # Learning rate
"beta1": 0.5, # Adam beta1
"beta2": 0.999, # Adam beta2
# Параметры обучения
"batch_size": 4, # Размер батча
"epochs": 100, # Количество эпох
# Параметры GAN
"gan_mode": "vanilla", # Режим GAN
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0, # Gradient clipping
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan",
}
Расширенная конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 8,
"epochs": 200,
"gan_mode": "lsgan", # Более стабильный LSGAN
"lambda_L1": 100.0,
"grad_clip": 1.0,
"weight_decay": 1e-4, # Weight decay
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
"output_dir": "runs/gan_advanced",
}
Методы тренера
Основные методы
train_epoch(): Обучение на одной эпохеvalidate(): Валидация моделиtrain(num_epochs): Полное обучениеevaluate(test_loader): Оценка на тестовых данных
Управление чекпоинтами
save_checkpoint(is_best=False): Сохранение чекпоинтаload_checkpoint(path, resume_training=False): Загрузка чекпоинта
Выходные файлы
После обучения создаются следующие файлы:
runs/gan_training/
├── config.json # Конфигурация обучения
├── training_history.json # История потерь
├── model_best.pth # Лучшая модель
├── model_final.pth # Финальная модель
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
├── checkpoint_epoch_2.pth
├── ...
└── tensorboard/ # Логи TensorBoard
├── events.out.tfevents...
└── ...
TensorBoard
Для визуализации обучения используйте TensorBoard:
tensorboard --logdir runs/gan_training/tensorboard
Доступные метрики:
train/batch_g_loss: Потери генератора на батчеtrain/batch_d_loss: Потери дискриминатора на батчеtrain/batch_g_l1_loss: L1 потери генератораtrain/epoch_g_loss: Потери генератора на эпохеtrain/epoch_d_loss: Потери дискриминатора на эпохеval/epoch_g_loss: Валидационные потери генератораval/epoch_d_loss: Валидационные потери дискриминатора
Тестирование
Запустите тесты для проверки работоспособности:
python models/GAN/test_trainer.py
Пример использования
Полный пример использования смотрите в train_example.py.
Советы по обучению
-
Начальные значения:
- Используйте
gan_mode="lsgan"для более стабильного обучения - Начните с
lambda_L1=100.0и регулируйте по необходимости - Используйте маленький
batch_size(4-8) при ограниченной памяти GPU
- Используйте
-
Мониторинг:
- Следите за балансом потерь генератора и дискриминатора
- Если потери дискриминатора близки к 0, генератор не обучается
- Если потери генератора слишком высоки, уменьшите
lambda_L1
-
Визуализация:
- Регулярно генерируйте примеры для визуальной оценки
- Используйте TensorBoard для отслеживания прогресса
Устранение проблем
Высокие потери генератора
- Уменьшите
lambda_L1 - Увеличьте learning rate
- Проверьте качество данных
Дискриминатор слишком сильный
- Уменьшите learning rate дискриминатора
- Добавьте dropout в дискриминатор
- Обучайте генератор чаще, чем дискриминатор
Недостаток памяти GPU
- Уменьшите
batch_size - Уменьшите размер изображений
- Используйте gradient accumulation
Лицензия
Этот проект является частью Autopilot системы.