feat: add models
This commit is contained in:
254
models/GAN/README.md
Normal file
254
models/GAN/README.md
Normal file
@@ -0,0 +1,254 @@
|
||||
# GAN Trainer для преобразования изображений Yandex → Google
|
||||
|
||||
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
|
||||
|
||||
## Структура проекта
|
||||
|
||||
```
|
||||
autopilot/models/GAN/
|
||||
├── gan.py # Основная реализация GAN модели
|
||||
├── trainer.py # Тренер для обучения GAN
|
||||
├── test_trainer.py # Тесты для тренера
|
||||
├── train_example.py # Пример использования тренера
|
||||
└── README.md # Этот файл
|
||||
```
|
||||
|
||||
## Модель GAN
|
||||
|
||||
Модель состоит из двух основных компонентов:
|
||||
|
||||
### 1. Генератор (GeneratorUNet)
|
||||
- Архитектура U-Net для преобразования изображений
|
||||
- Принимает изображение Yandex (3 канала RGB)
|
||||
- Возвращает изображение в стиле Google (3 канала RGB)
|
||||
- Использует skip connections для сохранения деталей
|
||||
|
||||
### 2. Дискриминатор (DiscriminatorPatchGAN)
|
||||
- PatchGAN архитектура
|
||||
- Принимает пару изображений (Yandex + Google)
|
||||
- Возвращает вероятность того, что пара реальная
|
||||
- Работает с патчами изображения 41x41
|
||||
|
||||
### Функция потерь (GANLoss)
|
||||
Поддерживает три режима:
|
||||
- `vanilla`: Бинарная кросс-энтропия
|
||||
- `lsgan`: Least Squares GAN (более стабильный)
|
||||
- `wgangp`: Wasserstein GAN with Gradient Penalty
|
||||
|
||||
## Тренер (GANTrainer)
|
||||
|
||||
### Основные возможности
|
||||
|
||||
1. **Обучение с чередованием**:
|
||||
- Обучение генератора и дискриминатора поочередно
|
||||
- Поддержка L1 потерь для сохранения структуры
|
||||
|
||||
2. **Валидация и мониторинг**:
|
||||
- Отдельные потери для генератора и дискриминатора
|
||||
- Логирование в TensorBoard
|
||||
- Ранняя остановка
|
||||
|
||||
3. **Сохранение и загрузка**:
|
||||
- Чекпоинты каждой эпохи
|
||||
- Лучшая модель
|
||||
- Финальная модель
|
||||
- История обучения
|
||||
|
||||
4. **Оценка модели**:
|
||||
- Метрики на тестовом наборе
|
||||
- Генерация примеров
|
||||
|
||||
### Быстрый старт
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
# Конфигурация
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"batch_size": 4,
|
||||
"output_dir": "runs/gan_training",
|
||||
"gan_mode": "vanilla",
|
||||
"lambda_L1": 100.0,
|
||||
"early_stopping_patience": 20,
|
||||
}
|
||||
|
||||
# Устройство
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Создание модели
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode=config["gan_mode"],
|
||||
lambda_L1=config["lambda_L1"],
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
# Создание даталоадеров (замените на свои данные)
|
||||
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
|
||||
|
||||
# Создание тренера
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Обучение
|
||||
trainer.train(num_epochs=100)
|
||||
|
||||
# Оценка
|
||||
metrics = trainer.evaluate(test_loader)
|
||||
```
|
||||
|
||||
### Конфигурация обучения
|
||||
|
||||
#### Базовая конфигурация
|
||||
```python
|
||||
config = {
|
||||
# Параметры оптимизатора
|
||||
"learning_rate": 2e-4, # Learning rate
|
||||
"beta1": 0.5, # Adam beta1
|
||||
"beta2": 0.999, # Adam beta2
|
||||
|
||||
# Параметры обучения
|
||||
"batch_size": 4, # Размер батча
|
||||
"epochs": 100, # Количество эпох
|
||||
|
||||
# Параметры GAN
|
||||
"gan_mode": "vanilla", # Режим GAN
|
||||
"lambda_L1": 100.0, # Вес L1 потерь
|
||||
|
||||
# Регуляризация
|
||||
"grad_clip": 1.0, # Gradient clipping
|
||||
|
||||
# Ранняя остановка
|
||||
"early_stopping_patience": 20,
|
||||
|
||||
# Выходные данные
|
||||
"output_dir": "runs/gan",
|
||||
}
|
||||
```
|
||||
|
||||
#### Расширенная конфигурация
|
||||
```python
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"batch_size": 8,
|
||||
"epochs": 200,
|
||||
"gan_mode": "lsgan", # Более стабильный LSGAN
|
||||
"lambda_L1": 100.0,
|
||||
"grad_clip": 1.0,
|
||||
"weight_decay": 1e-4, # Weight decay
|
||||
"early_stopping_patience": 30,
|
||||
"early_stopping_min_delta": 1e-4,
|
||||
"output_dir": "runs/gan_advanced",
|
||||
}
|
||||
```
|
||||
|
||||
### Методы тренера
|
||||
|
||||
#### Основные методы
|
||||
- `train_epoch()`: Обучение на одной эпохе
|
||||
- `validate()`: Валидация модели
|
||||
- `train(num_epochs)`: Полное обучение
|
||||
- `evaluate(test_loader)`: Оценка на тестовых данных
|
||||
|
||||
#### Управление чекпоинтами
|
||||
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
|
||||
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
|
||||
|
||||
### Выходные файлы
|
||||
|
||||
После обучения создаются следующие файлы:
|
||||
|
||||
```
|
||||
runs/gan_training/
|
||||
├── config.json # Конфигурация обучения
|
||||
├── training_history.json # История потерь
|
||||
├── model_best.pth # Лучшая модель
|
||||
├── model_final.pth # Финальная модель
|
||||
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
|
||||
├── checkpoint_epoch_2.pth
|
||||
├── ...
|
||||
└── tensorboard/ # Логи TensorBoard
|
||||
├── events.out.tfevents...
|
||||
└── ...
|
||||
```
|
||||
|
||||
### TensorBoard
|
||||
|
||||
Для визуализации обучения используйте TensorBoard:
|
||||
|
||||
```bash
|
||||
tensorboard --logdir runs/gan_training/tensorboard
|
||||
```
|
||||
|
||||
Доступные метрики:
|
||||
- `train/batch_g_loss`: Потери генератора на батче
|
||||
- `train/batch_d_loss`: Потери дискриминатора на батче
|
||||
- `train/batch_g_l1_loss`: L1 потери генератора
|
||||
- `train/epoch_g_loss`: Потери генератора на эпохе
|
||||
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
|
||||
- `val/epoch_g_loss`: Валидационные потери генератора
|
||||
- `val/epoch_d_loss`: Валидационные потери дискриминатора
|
||||
|
||||
### Тестирование
|
||||
|
||||
Запустите тесты для проверки работоспособности:
|
||||
|
||||
```bash
|
||||
python models/GAN/test_trainer.py
|
||||
```
|
||||
|
||||
### Пример использования
|
||||
|
||||
Полный пример использования смотрите в `train_example.py`.
|
||||
|
||||
### Советы по обучению
|
||||
|
||||
1. **Начальные значения**:
|
||||
- Используйте `gan_mode="lsgan"` для более стабильного обучения
|
||||
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
|
||||
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
|
||||
|
||||
2. **Мониторинг**:
|
||||
- Следите за балансом потерь генератора и дискриминатора
|
||||
- Если потери дискриминатора близки к 0, генератор не обучается
|
||||
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
|
||||
|
||||
3. **Визуализация**:
|
||||
- Регулярно генерируйте примеры для визуальной оценки
|
||||
- Используйте TensorBoard для отслеживания прогресса
|
||||
|
||||
### Устранение проблем
|
||||
|
||||
#### Высокие потери генератора
|
||||
- Уменьшите `lambda_L1`
|
||||
- Увеличьте learning rate
|
||||
- Проверьте качество данных
|
||||
|
||||
#### Дискриминатор слишком сильный
|
||||
- Уменьшите learning rate дискриминатора
|
||||
- Добавьте dropout в дискриминатор
|
||||
- Обучайте генератор чаще, чем дискриминатор
|
||||
|
||||
#### Недостаток памяти GPU
|
||||
- Уменьшите `batch_size`
|
||||
- Уменьшите размер изображений
|
||||
- Используйте gradient accumulation
|
||||
|
||||
### Лицензия
|
||||
|
||||
Этот проект является частью Autopilot системы.
|
||||
Reference in New Issue
Block a user