254 lines
9.4 KiB
Markdown
254 lines
9.4 KiB
Markdown
# GAN Trainer для преобразования изображений Yandex → Google
|
||
|
||
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
|
||
|
||
## Структура проекта
|
||
|
||
```
|
||
autopilot/models/GAN/
|
||
├── gan.py # Основная реализация GAN модели
|
||
├── trainer.py # Тренер для обучения GAN
|
||
├── test_trainer.py # Тесты для тренера
|
||
├── train_example.py # Пример использования тренера
|
||
└── README.md # Этот файл
|
||
```
|
||
|
||
## Модель GAN
|
||
|
||
Модель состоит из двух основных компонентов:
|
||
|
||
### 1. Генератор (GeneratorUNet)
|
||
- Архитектура U-Net для преобразования изображений
|
||
- Принимает изображение Yandex (3 канала RGB)
|
||
- Возвращает изображение в стиле Google (3 канала RGB)
|
||
- Использует skip connections для сохранения деталей
|
||
|
||
### 2. Дискриминатор (DiscriminatorPatchGAN)
|
||
- PatchGAN архитектура
|
||
- Принимает пару изображений (Yandex + Google)
|
||
- Возвращает вероятность того, что пара реальная
|
||
- Работает с патчами изображения 41x41
|
||
|
||
### Функция потерь (GANLoss)
|
||
Поддерживает три режима:
|
||
- `vanilla`: Бинарная кросс-энтропия
|
||
- `lsgan`: Least Squares GAN (более стабильный)
|
||
- `wgangp`: Wasserstein GAN with Gradient Penalty
|
||
|
||
## Тренер (GANTrainer)
|
||
|
||
### Основные возможности
|
||
|
||
1. **Обучение с чередованием**:
|
||
- Обучение генератора и дискриминатора поочередно
|
||
- Поддержка L1 потерь для сохранения структуры
|
||
|
||
2. **Валидация и мониторинг**:
|
||
- Отдельные потери для генератора и дискриминатора
|
||
- Логирование в TensorBoard
|
||
- Ранняя остановка
|
||
|
||
3. **Сохранение и загрузка**:
|
||
- Чекпоинты каждой эпохи
|
||
- Лучшая модель
|
||
- Финальная модель
|
||
- История обучения
|
||
|
||
4. **Оценка модели**:
|
||
- Метрики на тестовом наборе
|
||
- Генерация примеров
|
||
|
||
### Быстрый старт
|
||
|
||
```python
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
from models.GAN.gan import create_image_gan
|
||
from models.GAN.trainer import GANTrainer
|
||
|
||
# Конфигурация
|
||
config = {
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
"batch_size": 4,
|
||
"output_dir": "runs/gan_training",
|
||
"gan_mode": "vanilla",
|
||
"lambda_L1": 100.0,
|
||
"early_stopping_patience": 20,
|
||
}
|
||
|
||
# Устройство
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
||
# Создание модели
|
||
model = create_image_gan(
|
||
input_channels=3,
|
||
output_channels=3,
|
||
gan_mode=config["gan_mode"],
|
||
lambda_L1=config["lambda_L1"],
|
||
use_cuda=(device.type == "cuda"),
|
||
)
|
||
|
||
# Создание даталоадеров (замените на свои данные)
|
||
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
|
||
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
|
||
|
||
# Создание тренера
|
||
trainer = GANTrainer(
|
||
model=model,
|
||
train_loader=train_loader,
|
||
val_loader=val_loader,
|
||
device=device,
|
||
config=config,
|
||
)
|
||
|
||
# Обучение
|
||
trainer.train(num_epochs=100)
|
||
|
||
# Оценка
|
||
metrics = trainer.evaluate(test_loader)
|
||
```
|
||
|
||
### Конфигурация обучения
|
||
|
||
#### Базовая конфигурация
|
||
```python
|
||
config = {
|
||
# Параметры оптимизатора
|
||
"learning_rate": 2e-4, # Learning rate
|
||
"beta1": 0.5, # Adam beta1
|
||
"beta2": 0.999, # Adam beta2
|
||
|
||
# Параметры обучения
|
||
"batch_size": 4, # Размер батча
|
||
"epochs": 100, # Количество эпох
|
||
|
||
# Параметры GAN
|
||
"gan_mode": "vanilla", # Режим GAN
|
||
"lambda_L1": 100.0, # Вес L1 потерь
|
||
|
||
# Регуляризация
|
||
"grad_clip": 1.0, # Gradient clipping
|
||
|
||
# Ранняя остановка
|
||
"early_stopping_patience": 20,
|
||
|
||
# Выходные данные
|
||
"output_dir": "runs/gan",
|
||
}
|
||
```
|
||
|
||
#### Расширенная конфигурация
|
||
```python
|
||
config = {
|
||
"learning_rate": 2e-4,
|
||
"beta1": 0.5,
|
||
"beta2": 0.999,
|
||
"batch_size": 8,
|
||
"epochs": 200,
|
||
"gan_mode": "lsgan", # Более стабильный LSGAN
|
||
"lambda_L1": 100.0,
|
||
"grad_clip": 1.0,
|
||
"weight_decay": 1e-4, # Weight decay
|
||
"early_stopping_patience": 30,
|
||
"early_stopping_min_delta": 1e-4,
|
||
"output_dir": "runs/gan_advanced",
|
||
}
|
||
```
|
||
|
||
### Методы тренера
|
||
|
||
#### Основные методы
|
||
- `train_epoch()`: Обучение на одной эпохе
|
||
- `validate()`: Валидация модели
|
||
- `train(num_epochs)`: Полное обучение
|
||
- `evaluate(test_loader)`: Оценка на тестовых данных
|
||
|
||
#### Управление чекпоинтами
|
||
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
|
||
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
|
||
|
||
### Выходные файлы
|
||
|
||
После обучения создаются следующие файлы:
|
||
|
||
```
|
||
runs/gan_training/
|
||
├── config.json # Конфигурация обучения
|
||
├── training_history.json # История потерь
|
||
├── model_best.pth # Лучшая модель
|
||
├── model_final.pth # Финальная модель
|
||
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
|
||
├── checkpoint_epoch_2.pth
|
||
├── ...
|
||
└── tensorboard/ # Логи TensorBoard
|
||
├── events.out.tfevents...
|
||
└── ...
|
||
```
|
||
|
||
### TensorBoard
|
||
|
||
Для визуализации обучения используйте TensorBoard:
|
||
|
||
```bash
|
||
tensorboard --logdir runs/gan_training/tensorboard
|
||
```
|
||
|
||
Доступные метрики:
|
||
- `train/batch_g_loss`: Потери генератора на батче
|
||
- `train/batch_d_loss`: Потери дискриминатора на батче
|
||
- `train/batch_g_l1_loss`: L1 потери генератора
|
||
- `train/epoch_g_loss`: Потери генератора на эпохе
|
||
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
|
||
- `val/epoch_g_loss`: Валидационные потери генератора
|
||
- `val/epoch_d_loss`: Валидационные потери дискриминатора
|
||
|
||
### Тестирование
|
||
|
||
Запустите тесты для проверки работоспособности:
|
||
|
||
```bash
|
||
python models/GAN/test_trainer.py
|
||
```
|
||
|
||
### Пример использования
|
||
|
||
Полный пример использования смотрите в `train_example.py`.
|
||
|
||
### Советы по обучению
|
||
|
||
1. **Начальные значения**:
|
||
- Используйте `gan_mode="lsgan"` для более стабильного обучения
|
||
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
|
||
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
|
||
|
||
2. **Мониторинг**:
|
||
- Следите за балансом потерь генератора и дискриминатора
|
||
- Если потери дискриминатора близки к 0, генератор не обучается
|
||
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
|
||
|
||
3. **Визуализация**:
|
||
- Регулярно генерируйте примеры для визуальной оценки
|
||
- Используйте TensorBoard для отслеживания прогресса
|
||
|
||
### Устранение проблем
|
||
|
||
#### Высокие потери генератора
|
||
- Уменьшите `lambda_L1`
|
||
- Увеличьте learning rate
|
||
- Проверьте качество данных
|
||
|
||
#### Дискриминатор слишком сильный
|
||
- Уменьшите learning rate дискриминатора
|
||
- Добавьте dropout в дискриминатор
|
||
- Обучайте генератор чаще, чем дискриминатор
|
||
|
||
#### Недостаток памяти GPU
|
||
- Уменьшите `batch_size`
|
||
- Уменьшите размер изображений
|
||
- Используйте gradient accumulation
|
||
|
||
### Лицензия
|
||
|
||
Этот проект является частью Autopilot системы. |