# GAN Trainer для преобразования изображений Yandex → Google Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps. ## Структура проекта ``` autopilot/models/GAN/ ├── gan.py # Основная реализация GAN модели ├── trainer.py # Тренер для обучения GAN ├── test_trainer.py # Тесты для тренера ├── train_example.py # Пример использования тренера └── README.md # Этот файл ``` ## Модель GAN Модель состоит из двух основных компонентов: ### 1. Генератор (GeneratorUNet) - Архитектура U-Net для преобразования изображений - Принимает изображение Yandex (3 канала RGB) - Возвращает изображение в стиле Google (3 канала RGB) - Использует skip connections для сохранения деталей ### 2. Дискриминатор (DiscriminatorPatchGAN) - PatchGAN архитектура - Принимает пару изображений (Yandex + Google) - Возвращает вероятность того, что пара реальная - Работает с патчами изображения 41x41 ### Функция потерь (GANLoss) Поддерживает три режима: - `vanilla`: Бинарная кросс-энтропия - `lsgan`: Least Squares GAN (более стабильный) - `wgangp`: Wasserstein GAN with Gradient Penalty ## Тренер (GANTrainer) ### Основные возможности 1. **Обучение с чередованием**: - Обучение генератора и дискриминатора поочередно - Поддержка L1 потерь для сохранения структуры 2. **Валидация и мониторинг**: - Отдельные потери для генератора и дискриминатора - Логирование в TensorBoard - Ранняя остановка 3. **Сохранение и загрузка**: - Чекпоинты каждой эпохи - Лучшая модель - Финальная модель - История обучения 4. **Оценка модели**: - Метрики на тестовом наборе - Генерация примеров ### Быстрый старт ```python import torch from torch.utils.data import DataLoader from models.GAN.gan import create_image_gan from models.GAN.trainer import GANTrainer # Конфигурация config = { "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, "batch_size": 4, "output_dir": "runs/gan_training", "gan_mode": "vanilla", "lambda_L1": 100.0, "early_stopping_patience": 20, } # Устройство device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Создание модели model = create_image_gan( input_channels=3, output_channels=3, gan_mode=config["gan_mode"], lambda_L1=config["lambda_L1"], use_cuda=(device.type == "cuda"), ) # Создание даталоадеров (замените на свои данные) train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True) val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False) # Создание тренера trainer = GANTrainer( model=model, train_loader=train_loader, val_loader=val_loader, device=device, config=config, ) # Обучение trainer.train(num_epochs=100) # Оценка metrics = trainer.evaluate(test_loader) ``` ### Конфигурация обучения #### Базовая конфигурация ```python config = { # Параметры оптимизатора "learning_rate": 2e-4, # Learning rate "beta1": 0.5, # Adam beta1 "beta2": 0.999, # Adam beta2 # Параметры обучения "batch_size": 4, # Размер батча "epochs": 100, # Количество эпох # Параметры GAN "gan_mode": "vanilla", # Режим GAN "lambda_L1": 100.0, # Вес L1 потерь # Регуляризация "grad_clip": 1.0, # Gradient clipping # Ранняя остановка "early_stopping_patience": 20, # Выходные данные "output_dir": "runs/gan", } ``` #### Расширенная конфигурация ```python config = { "learning_rate": 2e-4, "beta1": 0.5, "beta2": 0.999, "batch_size": 8, "epochs": 200, "gan_mode": "lsgan", # Более стабильный LSGAN "lambda_L1": 100.0, "grad_clip": 1.0, "weight_decay": 1e-4, # Weight decay "early_stopping_patience": 30, "early_stopping_min_delta": 1e-4, "output_dir": "runs/gan_advanced", } ``` ### Методы тренера #### Основные методы - `train_epoch()`: Обучение на одной эпохе - `validate()`: Валидация модели - `train(num_epochs)`: Полное обучение - `evaluate(test_loader)`: Оценка на тестовых данных #### Управление чекпоинтами - `save_checkpoint(is_best=False)`: Сохранение чекпоинта - `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта ### Выходные файлы После обучения создаются следующие файлы: ``` runs/gan_training/ ├── config.json # Конфигурация обучения ├── training_history.json # История потерь ├── model_best.pth # Лучшая модель ├── model_final.pth # Финальная модель ├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи ├── checkpoint_epoch_2.pth ├── ... └── tensorboard/ # Логи TensorBoard ├── events.out.tfevents... └── ... ``` ### TensorBoard Для визуализации обучения используйте TensorBoard: ```bash tensorboard --logdir runs/gan_training/tensorboard ``` Доступные метрики: - `train/batch_g_loss`: Потери генератора на батче - `train/batch_d_loss`: Потери дискриминатора на батче - `train/batch_g_l1_loss`: L1 потери генератора - `train/epoch_g_loss`: Потери генератора на эпохе - `train/epoch_d_loss`: Потери дискриминатора на эпохе - `val/epoch_g_loss`: Валидационные потери генератора - `val/epoch_d_loss`: Валидационные потери дискриминатора ### Тестирование Запустите тесты для проверки работоспособности: ```bash python models/GAN/test_trainer.py ``` ### Пример использования Полный пример использования смотрите в `train_example.py`. ### Советы по обучению 1. **Начальные значения**: - Используйте `gan_mode="lsgan"` для более стабильного обучения - Начните с `lambda_L1=100.0` и регулируйте по необходимости - Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU 2. **Мониторинг**: - Следите за балансом потерь генератора и дискриминатора - Если потери дискриминатора близки к 0, генератор не обучается - Если потери генератора слишком высоки, уменьшите `lambda_L1` 3. **Визуализация**: - Регулярно генерируйте примеры для визуальной оценки - Используйте TensorBoard для отслеживания прогресса ### Устранение проблем #### Высокие потери генератора - Уменьшите `lambda_L1` - Увеличьте learning rate - Проверьте качество данных #### Дискриминатор слишком сильный - Уменьшите learning rate дискриминатора - Добавьте dropout в дискриминатор - Обучайте генератор чаще, чем дискриминатор #### Недостаток памяти GPU - Уменьшите `batch_size` - Уменьшите размер изображений - Используйте gradient accumulation ### Лицензия Этот проект является частью Autopilot системы.