feat: add models

This commit is contained in:
2026-02-20 16:52:02 +03:00
parent 6040f3b253
commit 0cc210968f
11 changed files with 4488 additions and 48 deletions

254
models/GAN/README.md Normal file
View File

@@ -0,0 +1,254 @@
# GAN Trainer для преобразования изображений Yandex → Google
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
## Структура проекта
```
autopilot/models/GAN/
├── gan.py # Основная реализация GAN модели
├── trainer.py # Тренер для обучения GAN
├── test_trainer.py # Тесты для тренера
├── train_example.py # Пример использования тренера
└── README.md # Этот файл
```
## Модель GAN
Модель состоит из двух основных компонентов:
### 1. Генератор (GeneratorUNet)
- Архитектура U-Net для преобразования изображений
- Принимает изображение Yandex (3 канала RGB)
- Возвращает изображение в стиле Google (3 канала RGB)
- Использует skip connections для сохранения деталей
### 2. Дискриминатор (DiscriminatorPatchGAN)
- PatchGAN архитектура
- Принимает пару изображений (Yandex + Google)
- Возвращает вероятность того, что пара реальная
- Работает с патчами изображения 41x41
### Функция потерь (GANLoss)
Поддерживает три режима:
- `vanilla`: Бинарная кросс-энтропия
- `lsgan`: Least Squares GAN (более стабильный)
- `wgangp`: Wasserstein GAN with Gradient Penalty
## Тренер (GANTrainer)
### Основные возможности
1. **Обучение с чередованием**:
- Обучение генератора и дискриминатора поочередно
- Поддержка L1 потерь для сохранения структуры
2. **Валидация и мониторинг**:
- Отдельные потери для генератора и дискриминатора
- Логирование в TensorBoard
- Ранняя остановка
3. **Сохранение и загрузка**:
- Чекпоинты каждой эпохи
- Лучшая модель
- Финальная модель
- История обучения
4. **Оценка модели**:
- Метрики на тестовом наборе
- Генерация примеров
### Быстрый старт
```python
import torch
from torch.utils.data import DataLoader
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 4,
"output_dir": "runs/gan_training",
"gan_mode": "vanilla",
"lambda_L1": 100.0,
"early_stopping_patience": 20,
}
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Создание модели
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config["gan_mode"],
lambda_L1=config["lambda_L1"],
use_cuda=(device.type == "cuda"),
)
# Создание даталоадеров (замените на свои данные)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
# Создание тренера
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Обучение
trainer.train(num_epochs=100)
# Оценка
metrics = trainer.evaluate(test_loader)
```
### Конфигурация обучения
#### Базовая конфигурация
```python
config = {
# Параметры оптимизатора
"learning_rate": 2e-4, # Learning rate
"beta1": 0.5, # Adam beta1
"beta2": 0.999, # Adam beta2
# Параметры обучения
"batch_size": 4, # Размер батча
"epochs": 100, # Количество эпох
# Параметры GAN
"gan_mode": "vanilla", # Режим GAN
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0, # Gradient clipping
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan",
}
```
#### Расширенная конфигурация
```python
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 8,
"epochs": 200,
"gan_mode": "lsgan", # Более стабильный LSGAN
"lambda_L1": 100.0,
"grad_clip": 1.0,
"weight_decay": 1e-4, # Weight decay
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
"output_dir": "runs/gan_advanced",
}
```
### Методы тренера
#### Основные методы
- `train_epoch()`: Обучение на одной эпохе
- `validate()`: Валидация модели
- `train(num_epochs)`: Полное обучение
- `evaluate(test_loader)`: Оценка на тестовых данных
#### Управление чекпоинтами
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
### Выходные файлы
После обучения создаются следующие файлы:
```
runs/gan_training/
├── config.json # Конфигурация обучения
├── training_history.json # История потерь
├── model_best.pth # Лучшая модель
├── model_final.pth # Финальная модель
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
├── checkpoint_epoch_2.pth
├── ...
└── tensorboard/ # Логи TensorBoard
├── events.out.tfevents...
└── ...
```
### TensorBoard
Для визуализации обучения используйте TensorBoard:
```bash
tensorboard --logdir runs/gan_training/tensorboard
```
Доступные метрики:
- `train/batch_g_loss`: Потери генератора на батче
- `train/batch_d_loss`: Потери дискриминатора на батче
- `train/batch_g_l1_loss`: L1 потери генератора
- `train/epoch_g_loss`: Потери генератора на эпохе
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
- `val/epoch_g_loss`: Валидационные потери генератора
- `val/epoch_d_loss`: Валидационные потери дискриминатора
### Тестирование
Запустите тесты для проверки работоспособности:
```bash
python models/GAN/test_trainer.py
```
### Пример использования
Полный пример использования смотрите в `train_example.py`.
### Советы по обучению
1. **Начальные значения**:
- Используйте `gan_mode="lsgan"` для более стабильного обучения
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
2. **Мониторинг**:
- Следите за балансом потерь генератора и дискриминатора
- Если потери дискриминатора близки к 0, генератор не обучается
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
3. **Визуализация**:
- Регулярно генерируйте примеры для визуальной оценки
- Используйте TensorBoard для отслеживания прогресса
### Устранение проблем
#### Высокие потери генератора
- Уменьшите `lambda_L1`
- Увеличьте learning rate
- Проверьте качество данных
#### Дискриминатор слишком сильный
- Уменьшите learning rate дискриминатора
- Добавьте dropout в дискриминатор
- Обучайте генератор чаще, чем дискриминатор
#### Недостаток памяти GPU
- Уменьшите `batch_size`
- Уменьшите размер изображений
- Используйте gradient accumulation
### Лицензия
Этот проект является частью Autopilot системы.