Files
autopilot/models/GAN/README.md
2026-02-20 16:52:02 +03:00

254 lines
9.4 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# GAN Trainer для преобразования изображений Yandex → Google
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
## Структура проекта
```
autopilot/models/GAN/
├── gan.py # Основная реализация GAN модели
├── trainer.py # Тренер для обучения GAN
├── test_trainer.py # Тесты для тренера
├── train_example.py # Пример использования тренера
└── README.md # Этот файл
```
## Модель GAN
Модель состоит из двух основных компонентов:
### 1. Генератор (GeneratorUNet)
- Архитектура U-Net для преобразования изображений
- Принимает изображение Yandex (3 канала RGB)
- Возвращает изображение в стиле Google (3 канала RGB)
- Использует skip connections для сохранения деталей
### 2. Дискриминатор (DiscriminatorPatchGAN)
- PatchGAN архитектура
- Принимает пару изображений (Yandex + Google)
- Возвращает вероятность того, что пара реальная
- Работает с патчами изображения 41x41
### Функция потерь (GANLoss)
Поддерживает три режима:
- `vanilla`: Бинарная кросс-энтропия
- `lsgan`: Least Squares GAN (более стабильный)
- `wgangp`: Wasserstein GAN with Gradient Penalty
## Тренер (GANTrainer)
### Основные возможности
1. **Обучение с чередованием**:
- Обучение генератора и дискриминатора поочередно
- Поддержка L1 потерь для сохранения структуры
2. **Валидация и мониторинг**:
- Отдельные потери для генератора и дискриминатора
- Логирование в TensorBoard
- Ранняя остановка
3. **Сохранение и загрузка**:
- Чекпоинты каждой эпохи
- Лучшая модель
- Финальная модель
- История обучения
4. **Оценка модели**:
- Метрики на тестовом наборе
- Генерация примеров
### Быстрый старт
```python
import torch
from torch.utils.data import DataLoader
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 4,
"output_dir": "runs/gan_training",
"gan_mode": "vanilla",
"lambda_L1": 100.0,
"early_stopping_patience": 20,
}
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Создание модели
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config["gan_mode"],
lambda_L1=config["lambda_L1"],
use_cuda=(device.type == "cuda"),
)
# Создание даталоадеров (замените на свои данные)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
# Создание тренера
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Обучение
trainer.train(num_epochs=100)
# Оценка
metrics = trainer.evaluate(test_loader)
```
### Конфигурация обучения
#### Базовая конфигурация
```python
config = {
# Параметры оптимизатора
"learning_rate": 2e-4, # Learning rate
"beta1": 0.5, # Adam beta1
"beta2": 0.999, # Adam beta2
# Параметры обучения
"batch_size": 4, # Размер батча
"epochs": 100, # Количество эпох
# Параметры GAN
"gan_mode": "vanilla", # Режим GAN
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0, # Gradient clipping
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan",
}
```
#### Расширенная конфигурация
```python
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 8,
"epochs": 200,
"gan_mode": "lsgan", # Более стабильный LSGAN
"lambda_L1": 100.0,
"grad_clip": 1.0,
"weight_decay": 1e-4, # Weight decay
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
"output_dir": "runs/gan_advanced",
}
```
### Методы тренера
#### Основные методы
- `train_epoch()`: Обучение на одной эпохе
- `validate()`: Валидация модели
- `train(num_epochs)`: Полное обучение
- `evaluate(test_loader)`: Оценка на тестовых данных
#### Управление чекпоинтами
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
### Выходные файлы
После обучения создаются следующие файлы:
```
runs/gan_training/
├── config.json # Конфигурация обучения
├── training_history.json # История потерь
├── model_best.pth # Лучшая модель
├── model_final.pth # Финальная модель
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
├── checkpoint_epoch_2.pth
├── ...
└── tensorboard/ # Логи TensorBoard
├── events.out.tfevents...
└── ...
```
### TensorBoard
Для визуализации обучения используйте TensorBoard:
```bash
tensorboard --logdir runs/gan_training/tensorboard
```
Доступные метрики:
- `train/batch_g_loss`: Потери генератора на батче
- `train/batch_d_loss`: Потери дискриминатора на батче
- `train/batch_g_l1_loss`: L1 потери генератора
- `train/epoch_g_loss`: Потери генератора на эпохе
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
- `val/epoch_g_loss`: Валидационные потери генератора
- `val/epoch_d_loss`: Валидационные потери дискриминатора
### Тестирование
Запустите тесты для проверки работоспособности:
```bash
python models/GAN/test_trainer.py
```
### Пример использования
Полный пример использования смотрите в `train_example.py`.
### Советы по обучению
1. **Начальные значения**:
- Используйте `gan_mode="lsgan"` для более стабильного обучения
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
2. **Мониторинг**:
- Следите за балансом потерь генератора и дискриминатора
- Если потери дискриминатора близки к 0, генератор не обучается
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
3. **Визуализация**:
- Регулярно генерируйте примеры для визуальной оценки
- Используйте TensorBoard для отслеживания прогресса
### Устранение проблем
#### Высокие потери генератора
- Уменьшите `lambda_L1`
- Увеличьте learning rate
- Проверьте качество данных
#### Дискриминатор слишком сильный
- Уменьшите learning rate дискриминатора
- Добавьте dropout в дискриминатор
- Обучайте генератор чаще, чем дискриминатор
#### Недостаток памяти GPU
- Уменьшите `batch_size`
- Уменьшите размер изображений
- Используйте gradient accumulation
### Лицензия
Этот проект является частью Autopilot системы.