feat: add models #2
1
models/GAN/.gitignore
vendored
Normal file
1
models/GAN/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
runs
|
||||||
254
models/GAN/README.md
Normal file
254
models/GAN/README.md
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
# GAN Trainer для преобразования изображений Yandex → Google
|
||||||
|
|
||||||
|
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
|
||||||
|
|
||||||
|
## Структура проекта
|
||||||
|
|
||||||
|
```
|
||||||
|
autopilot/models/GAN/
|
||||||
|
├── gan.py # Основная реализация GAN модели
|
||||||
|
├── trainer.py # Тренер для обучения GAN
|
||||||
|
├── test_trainer.py # Тесты для тренера
|
||||||
|
├── train_example.py # Пример использования тренера
|
||||||
|
└── README.md # Этот файл
|
||||||
|
```
|
||||||
|
|
||||||
|
## Модель GAN
|
||||||
|
|
||||||
|
Модель состоит из двух основных компонентов:
|
||||||
|
|
||||||
|
### 1. Генератор (GeneratorUNet)
|
||||||
|
- Архитектура U-Net для преобразования изображений
|
||||||
|
- Принимает изображение Yandex (3 канала RGB)
|
||||||
|
- Возвращает изображение в стиле Google (3 канала RGB)
|
||||||
|
- Использует skip connections для сохранения деталей
|
||||||
|
|
||||||
|
### 2. Дискриминатор (DiscriminatorPatchGAN)
|
||||||
|
- PatchGAN архитектура
|
||||||
|
- Принимает пару изображений (Yandex + Google)
|
||||||
|
- Возвращает вероятность того, что пара реальная
|
||||||
|
- Работает с патчами изображения 41x41
|
||||||
|
|
||||||
|
### Функция потерь (GANLoss)
|
||||||
|
Поддерживает три режима:
|
||||||
|
- `vanilla`: Бинарная кросс-энтропия
|
||||||
|
- `lsgan`: Least Squares GAN (более стабильный)
|
||||||
|
- `wgangp`: Wasserstein GAN with Gradient Penalty
|
||||||
|
|
||||||
|
## Тренер (GANTrainer)
|
||||||
|
|
||||||
|
### Основные возможности
|
||||||
|
|
||||||
|
1. **Обучение с чередованием**:
|
||||||
|
- Обучение генератора и дискриминатора поочередно
|
||||||
|
- Поддержка L1 потерь для сохранения структуры
|
||||||
|
|
||||||
|
2. **Валидация и мониторинг**:
|
||||||
|
- Отдельные потери для генератора и дискриминатора
|
||||||
|
- Логирование в TensorBoard
|
||||||
|
- Ранняя остановка
|
||||||
|
|
||||||
|
3. **Сохранение и загрузка**:
|
||||||
|
- Чекпоинты каждой эпохи
|
||||||
|
- Лучшая модель
|
||||||
|
- Финальная модель
|
||||||
|
- История обучения
|
||||||
|
|
||||||
|
4. **Оценка модели**:
|
||||||
|
- Метрики на тестовом наборе
|
||||||
|
- Генерация примеров
|
||||||
|
|
||||||
|
### Быстрый старт
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
# Конфигурация
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"batch_size": 4,
|
||||||
|
"output_dir": "runs/gan_training",
|
||||||
|
"gan_mode": "vanilla",
|
||||||
|
"lambda_L1": 100.0,
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Устройство
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Создание модели
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode=config["gan_mode"],
|
||||||
|
lambda_L1=config["lambda_L1"],
|
||||||
|
use_cuda=(device.type == "cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создание даталоадеров (замените на свои данные)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
|
||||||
|
|
||||||
|
# Создание тренера
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Обучение
|
||||||
|
trainer.train(num_epochs=100)
|
||||||
|
|
||||||
|
# Оценка
|
||||||
|
metrics = trainer.evaluate(test_loader)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Конфигурация обучения
|
||||||
|
|
||||||
|
#### Базовая конфигурация
|
||||||
|
```python
|
||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4, # Learning rate
|
||||||
|
"beta1": 0.5, # Adam beta1
|
||||||
|
"beta2": 0.999, # Adam beta2
|
||||||
|
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 4, # Размер батча
|
||||||
|
"epochs": 100, # Количество эпох
|
||||||
|
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "vanilla", # Режим GAN
|
||||||
|
"lambda_L1": 100.0, # Вес L1 потерь
|
||||||
|
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0, # Gradient clipping
|
||||||
|
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Расширенная конфигурация
|
||||||
|
```python
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"batch_size": 8,
|
||||||
|
"epochs": 200,
|
||||||
|
"gan_mode": "lsgan", # Более стабильный LSGAN
|
||||||
|
"lambda_L1": 100.0,
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
"weight_decay": 1e-4, # Weight decay
|
||||||
|
"early_stopping_patience": 30,
|
||||||
|
"early_stopping_min_delta": 1e-4,
|
||||||
|
"output_dir": "runs/gan_advanced",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Методы тренера
|
||||||
|
|
||||||
|
#### Основные методы
|
||||||
|
- `train_epoch()`: Обучение на одной эпохе
|
||||||
|
- `validate()`: Валидация модели
|
||||||
|
- `train(num_epochs)`: Полное обучение
|
||||||
|
- `evaluate(test_loader)`: Оценка на тестовых данных
|
||||||
|
|
||||||
|
#### Управление чекпоинтами
|
||||||
|
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
|
||||||
|
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
|
||||||
|
|
||||||
|
### Выходные файлы
|
||||||
|
|
||||||
|
После обучения создаются следующие файлы:
|
||||||
|
|
||||||
|
```
|
||||||
|
runs/gan_training/
|
||||||
|
├── config.json # Конфигурация обучения
|
||||||
|
├── training_history.json # История потерь
|
||||||
|
├── model_best.pth # Лучшая модель
|
||||||
|
├── model_final.pth # Финальная модель
|
||||||
|
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
|
||||||
|
├── checkpoint_epoch_2.pth
|
||||||
|
├── ...
|
||||||
|
└── tensorboard/ # Логи TensorBoard
|
||||||
|
├── events.out.tfevents...
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
### TensorBoard
|
||||||
|
|
||||||
|
Для визуализации обучения используйте TensorBoard:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tensorboard --logdir runs/gan_training/tensorboard
|
||||||
|
```
|
||||||
|
|
||||||
|
Доступные метрики:
|
||||||
|
- `train/batch_g_loss`: Потери генератора на батче
|
||||||
|
- `train/batch_d_loss`: Потери дискриминатора на батче
|
||||||
|
- `train/batch_g_l1_loss`: L1 потери генератора
|
||||||
|
- `train/epoch_g_loss`: Потери генератора на эпохе
|
||||||
|
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
|
||||||
|
- `val/epoch_g_loss`: Валидационные потери генератора
|
||||||
|
- `val/epoch_d_loss`: Валидационные потери дискриминатора
|
||||||
|
|
||||||
|
### Тестирование
|
||||||
|
|
||||||
|
Запустите тесты для проверки работоспособности:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python models/GAN/test_trainer.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Пример использования
|
||||||
|
|
||||||
|
Полный пример использования смотрите в `train_example.py`.
|
||||||
|
|
||||||
|
### Советы по обучению
|
||||||
|
|
||||||
|
1. **Начальные значения**:
|
||||||
|
- Используйте `gan_mode="lsgan"` для более стабильного обучения
|
||||||
|
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
|
||||||
|
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
|
||||||
|
|
||||||
|
2. **Мониторинг**:
|
||||||
|
- Следите за балансом потерь генератора и дискриминатора
|
||||||
|
- Если потери дискриминатора близки к 0, генератор не обучается
|
||||||
|
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
|
||||||
|
|
||||||
|
3. **Визуализация**:
|
||||||
|
- Регулярно генерируйте примеры для визуальной оценки
|
||||||
|
- Используйте TensorBoard для отслеживания прогресса
|
||||||
|
|
||||||
|
### Устранение проблем
|
||||||
|
|
||||||
|
#### Высокие потери генератора
|
||||||
|
- Уменьшите `lambda_L1`
|
||||||
|
- Увеличьте learning rate
|
||||||
|
- Проверьте качество данных
|
||||||
|
|
||||||
|
#### Дискриминатор слишком сильный
|
||||||
|
- Уменьшите learning rate дискриминатора
|
||||||
|
- Добавьте dropout в дискриминатор
|
||||||
|
- Обучайте генератор чаще, чем дискриминатор
|
||||||
|
|
||||||
|
#### Недостаток памяти GPU
|
||||||
|
- Уменьшите `batch_size`
|
||||||
|
- Уменьшите размер изображений
|
||||||
|
- Используйте gradient accumulation
|
||||||
|
|
||||||
|
### Лицензия
|
||||||
|
|
||||||
|
Этот проект является частью Autopilot системы.
|
||||||
1777
models/GAN/gan.ipynb
Normal file
1777
models/GAN/gan.ipynb
Normal file
File diff suppressed because one or more lines are too long
393
models/GAN/gan.py
Normal file
393
models/GAN/gan.py
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class UNetDownBlock(nn.Module):
|
||||||
|
"""Блок downsampling для U-Net генератора"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
normalize: bool = True,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
layers = [
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if normalize:
|
||||||
|
layers.append(nn.BatchNorm2d(out_channels))
|
||||||
|
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
||||||
|
if dropout > 0:
|
||||||
|
layers.append(nn.Dropout2d(dropout))
|
||||||
|
self.model = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model(x)
|
||||||
|
|
||||||
|
|
||||||
|
class UNetUpBlock(nn.Module):
|
||||||
|
"""Блок upsampling для U-Net генератора"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
|
||||||
|
super().__init__()
|
||||||
|
layers = [
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
]
|
||||||
|
if dropout > 0:
|
||||||
|
layers.append(nn.Dropout2d(dropout))
|
||||||
|
self.model = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.model(x)
|
||||||
|
# Обрезаем skip connection до размера x, если необходимо
|
||||||
|
if x.shape != skip_input.shape:
|
||||||
|
diffY = skip_input.size(2) - x.size(2)
|
||||||
|
diffX = skip_input.size(3) - x.size(3)
|
||||||
|
x = F.pad(
|
||||||
|
x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
|
||||||
|
)
|
||||||
|
x = torch.cat([x, skip_input], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorUNet(nn.Module):
|
||||||
|
"""Генератор на основе U-Net архитектуры для преобразования Yandex → Google"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int = 3, out_channels: int = 3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Downsampling path
|
||||||
|
self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
|
||||||
|
self.down2 = UNetDownBlock(64, 128)
|
||||||
|
self.down3 = UNetDownBlock(128, 256)
|
||||||
|
self.down4 = UNetDownBlock(256, 512)
|
||||||
|
self.down5 = UNetDownBlock(512, 512)
|
||||||
|
self.down6 = UNetDownBlock(512, 512)
|
||||||
|
self.down7 = UNetDownBlock(512, 512)
|
||||||
|
|
||||||
|
# Bottleneck
|
||||||
|
self.bottleneck = nn.Sequential(
|
||||||
|
nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upsampling path
|
||||||
|
self.up1 = UNetUpBlock(512, 512, dropout=0.5)
|
||||||
|
self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||||
|
self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||||
|
self.up4 = UNetUpBlock(1024, 512)
|
||||||
|
self.up5 = UNetUpBlock(1024, 256)
|
||||||
|
self.up6 = UNetUpBlock(512, 128)
|
||||||
|
self.up7 = UNetUpBlock(256, 64)
|
||||||
|
|
||||||
|
# Final layer
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Downsampling
|
||||||
|
d1 = self.down1(x) # 350x350
|
||||||
|
d2 = self.down2(d1) # 175x175
|
||||||
|
d3 = self.down3(d2) # 88x88
|
||||||
|
d4 = self.down4(d3) # 44x44
|
||||||
|
d5 = self.down5(d4) # 22x22
|
||||||
|
d6 = self.down6(d5) # 11x11
|
||||||
|
d7 = self.down7(d6) # 6x6
|
||||||
|
|
||||||
|
# Bottleneck
|
||||||
|
u = self.bottleneck(d7) # 3x3
|
||||||
|
|
||||||
|
# Upsampling with skip connections
|
||||||
|
u = self.up1(u, d7) # 6x6
|
||||||
|
u = self.up2(u, d6) # 11x11
|
||||||
|
u = self.up3(u, d5) # 22x22
|
||||||
|
u = self.up4(u, d4) # 44x44
|
||||||
|
u = self.up5(u, d3) # 88x88
|
||||||
|
u = self.up6(u, d2) # 175x175
|
||||||
|
u = self.up7(u, d1) # 350x350
|
||||||
|
|
||||||
|
# Final output
|
||||||
|
return self.final(u) # 700x700
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorPatchGAN(nn.Module):
|
||||||
|
"""Дискриминатор PatchGAN для изображений 700x700"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_channels: int = 6
|
||||||
|
): # 3 для реального + 3 для сгенерированного
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def discriminator_block(
|
||||||
|
in_filters: int, out_filters: int, normalization: bool = True
|
||||||
|
):
|
||||||
|
"""Блок дискриминатора"""
|
||||||
|
layers = [
|
||||||
|
nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
|
||||||
|
]
|
||||||
|
if normalization:
|
||||||
|
layers.append(nn.BatchNorm2d(out_filters))
|
||||||
|
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
||||||
|
return layers
|
||||||
|
|
||||||
|
self.model = nn.Sequential(
|
||||||
|
*discriminator_block(in_channels, 64, normalization=False), # 350x350
|
||||||
|
*discriminator_block(64, 128), # 175x175
|
||||||
|
*discriminator_block(128, 256), # 88x88
|
||||||
|
*discriminator_block(256, 512), # 44x44
|
||||||
|
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), # 41x41
|
||||||
|
nn.Sigmoid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Принимает пару изображений (реальное и сгенерированное)
|
||||||
|
и возвращает вероятность того, что пара реальная
|
||||||
|
"""
|
||||||
|
# Объединяем два изображения по каналам
|
||||||
|
img_input = torch.cat((img_A, img_B), 1)
|
||||||
|
return self.model(img_input)
|
||||||
|
|
||||||
|
|
||||||
|
class GANLoss(nn.Module):
|
||||||
|
"""Функция потерь для GAN"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gan_mode: str = "vanilla",
|
||||||
|
target_real_label: float = 1.0,
|
||||||
|
target_fake_label: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("real_label", torch.tensor(target_real_label))
|
||||||
|
self.register_buffer("fake_label", torch.tensor(target_fake_label))
|
||||||
|
self.gan_mode = gan_mode
|
||||||
|
|
||||||
|
if gan_mode == "vanilla":
|
||||||
|
self.loss = nn.BCEWithLogitsLoss()
|
||||||
|
elif gan_mode == "lsgan":
|
||||||
|
self.loss = nn.MSELoss()
|
||||||
|
elif gan_mode == "wgangp":
|
||||||
|
self.loss = None
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"GAN mode {gan_mode} not implemented")
|
||||||
|
|
||||||
|
def get_target_tensor(
|
||||||
|
self, prediction: torch.Tensor, target_is_real: bool
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Создает тензор меток"""
|
||||||
|
if target_is_real:
|
||||||
|
target_tensor = self.real_label
|
||||||
|
else:
|
||||||
|
target_tensor = self.fake_label
|
||||||
|
return target_tensor.expand_as(prediction)
|
||||||
|
|
||||||
|
def __call__(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
|
||||||
|
"""Вычисляет потери"""
|
||||||
|
if self.gan_mode in ["vanilla", "lsgan"]:
|
||||||
|
target_tensor = self.get_target_tensor(prediction, target_is_real)
|
||||||
|
loss = self.loss(prediction, target_tensor)
|
||||||
|
elif self.gan_mode == "wgangp":
|
||||||
|
if target_is_real:
|
||||||
|
loss = -prediction.mean()
|
||||||
|
else:
|
||||||
|
loss = prediction.mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGAN(nn.Module):
|
||||||
|
"""Основной класс GAN для преобразования изображений Yandex → Google"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels: int = 3,
|
||||||
|
output_channels: int = 3,
|
||||||
|
gan_mode: str = "vanilla",
|
||||||
|
lambda_L1: float = 100.0,
|
||||||
|
use_cuda: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.generator = GeneratorUNet(input_channels, output_channels)
|
||||||
|
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
|
||||||
|
self.gan_loss = GANLoss(gan_mode)
|
||||||
|
self.l1_loss = nn.L1Loss()
|
||||||
|
self.lambda_L1 = lambda_L1
|
||||||
|
|
||||||
|
self.device = torch.device(
|
||||||
|
"cuda" if use_cuda and torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
|
def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Генерация изображения Google из Yandex"""
|
||||||
|
return self.generator(yandex_image)
|
||||||
|
|
||||||
|
def generator_step(
|
||||||
|
self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Шаг обучения генератора
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
total_loss: общие потери генератора
|
||||||
|
gan_loss: потери GAN
|
||||||
|
l1_loss: потери L1
|
||||||
|
"""
|
||||||
|
# Генерируем изображение
|
||||||
|
fake_google_image = self.generator(yandex_image)
|
||||||
|
|
||||||
|
# Оцениваем дискриминатором
|
||||||
|
fake_pred = self.discriminator(yandex_image, fake_google_image)
|
||||||
|
|
||||||
|
# Потери GAN (пытаемся обмануть дискриминатор)
|
||||||
|
gan_loss = self.gan_loss(fake_pred, True)
|
||||||
|
|
||||||
|
# Потери L1 для сохранения структуры
|
||||||
|
l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1
|
||||||
|
|
||||||
|
# Общие потери
|
||||||
|
total_loss = gan_loss + l1_loss
|
||||||
|
|
||||||
|
return total_loss, gan_loss, l1_loss
|
||||||
|
|
||||||
|
def discriminator_step(
|
||||||
|
self,
|
||||||
|
yandex_image: torch.Tensor,
|
||||||
|
real_google_image: torch.Tensor,
|
||||||
|
fake_google_image: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Шаг обучения дискриминатора
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
total_loss: общие потери дискриминатора
|
||||||
|
real_loss: потери на реальных изображениях
|
||||||
|
fake_loss: потери на сгенерированных изображениях
|
||||||
|
"""
|
||||||
|
# Предсказания для реальных пар
|
||||||
|
real_pred = self.discriminator(yandex_image, real_google_image)
|
||||||
|
real_loss = self.gan_loss(real_pred, True)
|
||||||
|
|
||||||
|
# Предсказания для сгенерированных пар
|
||||||
|
fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
|
||||||
|
fake_loss = self.gan_loss(fake_pred, False)
|
||||||
|
|
||||||
|
# Общие потери дискриминатора
|
||||||
|
total_loss = (real_loss + fake_loss) * 0.5
|
||||||
|
|
||||||
|
return total_loss, real_loss, fake_loss
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
"""Перемещает модель на устройство"""
|
||||||
|
self.generator.to(device)
|
||||||
|
self.discriminator.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def train_mode(self):
|
||||||
|
"""Переключает модель в режим обучения"""
|
||||||
|
self.generator.train()
|
||||||
|
self.discriminator.train()
|
||||||
|
|
||||||
|
def eval_mode(self):
|
||||||
|
"""Переключает модель в режим оценки"""
|
||||||
|
self.generator.eval()
|
||||||
|
self.discriminator.eval()
|
||||||
|
|
||||||
|
def save_checkpoint(self, path: str):
|
||||||
|
"""Сохраняет чекпоинт модели"""
|
||||||
|
checkpoint = {
|
||||||
|
"generator_state_dict": self.generator.state_dict(),
|
||||||
|
"discriminator_state_dict": self.discriminator.state_dict(),
|
||||||
|
"generator_optimizer_state_dict": getattr(
|
||||||
|
self.generator, "optimizer_state_dict", None
|
||||||
|
),
|
||||||
|
"discriminator_optimizer_state_dict": getattr(
|
||||||
|
self.discriminator, "optimizer_state_dict", None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
torch.save(checkpoint, path)
|
||||||
|
|
||||||
|
def load_checkpoint(self, path: str):
|
||||||
|
"""Загружает чекпоинт модели"""
|
||||||
|
checkpoint = torch.load(path, map_location=self.device)
|
||||||
|
self.generator.load_state_dict(checkpoint["generator_state_dict"])
|
||||||
|
self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
|
||||||
|
|
||||||
|
if checkpoint["generator_optimizer_state_dict"] is not None:
|
||||||
|
self.generator.optimizer_state_dict = checkpoint[
|
||||||
|
"generator_optimizer_state_dict"
|
||||||
|
]
|
||||||
|
if checkpoint["discriminator_optimizer_state_dict"] is not None:
|
||||||
|
self.discriminator.optimizer_state_dict = checkpoint[
|
||||||
|
"discriminator_optimizer_state_dict"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_image_gan(
|
||||||
|
input_channels: int = 3,
|
||||||
|
output_channels: int = 3,
|
||||||
|
gan_mode: str = "vanilla",
|
||||||
|
lambda_L1: float = 100.0,
|
||||||
|
use_cuda: bool = True,
|
||||||
|
) -> ImageGAN:
|
||||||
|
"""
|
||||||
|
Создает и возвращает модель GAN для преобразования изображений
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_channels: количество входных каналов (обычно 3 для RGB)
|
||||||
|
output_channels: количество выходных каналов (обычно 3 для RGB)
|
||||||
|
gan_mode: режим GAN ('vanilla', 'lsgan', 'wgangp')
|
||||||
|
lambda_L1: вес L1 потерь
|
||||||
|
use_cuda: использовать ли CUDA если доступно
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageGAN: модель GAN
|
||||||
|
"""
|
||||||
|
return ImageGAN(
|
||||||
|
input_channels=input_channels,
|
||||||
|
output_channels=output_channels,
|
||||||
|
gan_mode=gan_mode,
|
||||||
|
lambda_L1=lambda_L1,
|
||||||
|
use_cuda=use_cuda,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Вспомогательные функции для инициализации весов
|
||||||
|
def weights_init_normal(m):
|
||||||
|
"""Инициализация весов с нормальным распределением"""
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||||
|
elif classname.find("BatchNorm") != -1:
|
||||||
|
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||||
|
nn.init.constant_(m.batch_norm.bias.data, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
|
||||||
|
"""Инициализирует веса генератора и дискриминатора"""
|
||||||
|
generator.apply(weights_init_normal)
|
||||||
|
discriminator.apply(weights_init_normal)
|
||||||
136
models/GAN/minimal_example.py
Normal file
136
models/GAN/minimal_example.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
|
||||||
|
|
||||||
|
Этот пример показывает самый простой способ использования тренера.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
# Добавляем путь к модулям
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleMapDataset(Dataset):
|
||||||
|
"""Простой датасет с фиктивными данными для примера."""
|
||||||
|
|
||||||
|
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.image_size = image_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# Создаем фиктивные изображения
|
||||||
|
# В реальном коде замените на загрузку реальных изображений
|
||||||
|
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
|
||||||
|
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция минимального примера."""
|
||||||
|
print("Минимальный пример использования GAN trainer")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 1. Конфигурация (минимальный набор параметров)
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"batch_size": 4,
|
||||||
|
"output_dir": "runs/gan_minimal",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Устройство (CPU или GPU)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Используемое устройство: {device}")
|
||||||
|
|
||||||
|
# 3. Создание модели
|
||||||
|
print("\nСоздание GAN модели...")
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode="vanilla", # Простейший режим
|
||||||
|
lambda_L1=100.0, # Стандартный вес L1 потерь
|
||||||
|
use_cuda=(device.type == "cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Создание даталоадеров
|
||||||
|
print("Создание даталоадеров...")
|
||||||
|
train_dataset = SimpleMapDataset(num_samples=50)
|
||||||
|
val_dataset = SimpleMapDataset(num_samples=10)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Обучающих примеров: {len(train_dataset)}")
|
||||||
|
print(f" Валидационных примеров: {len(val_dataset)}")
|
||||||
|
|
||||||
|
# 5. Создание тренера
|
||||||
|
print("\nСоздание тренера...")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Обучение на небольшом количестве эпох
|
||||||
|
print("\nЗапуск обучения (3 эпохи для примера)...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
trainer.train(num_epochs=3)
|
||||||
|
|
||||||
|
# 7. Генерация примеров
|
||||||
|
print("\nГенерация примеров преобразования...")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Создаем тестовые данные
|
||||||
|
test_yandex = torch.randn(2, 3, 256, 256).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_google = model(test_yandex)
|
||||||
|
|
||||||
|
print(f"Входные изображения: {test_yandex.shape}")
|
||||||
|
print(f"Сгенерированные изображения: {generated_google.shape}")
|
||||||
|
print(
|
||||||
|
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Сохранение финальной модели
|
||||||
|
print("\nСохранение модели...")
|
||||||
|
model_save_path = "gan_model_minimal.pth"
|
||||||
|
torch.save(model.state_dict(), model_save_path)
|
||||||
|
print(f"Модель сохранена в: {model_save_path}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Минимальный пример завершен!")
|
||||||
|
print("\nДля реального использования:")
|
||||||
|
print("1. Замените SimpleMapDataset на ваш реальный датасет")
|
||||||
|
print("2. Настройте параметры в config")
|
||||||
|
print("3. Увеличьте количество эпох (например, до 100)")
|
||||||
|
print("4. Используйте реальные изображения карт")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
349
models/GAN/test_gan.py
Normal file
349
models/GAN/test_gan.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Добавляем путь к модулю
|
||||||
|
sys.path.append(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
)
|
||||||
|
|
||||||
|
from gan import (
|
||||||
|
DiscriminatorPatchGAN,
|
||||||
|
GeneratorUNet,
|
||||||
|
ImageGAN,
|
||||||
|
create_image_gan,
|
||||||
|
initialize_gan_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generator():
|
||||||
|
"""Тестирование генератора"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Тестирование генератора...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем генератор
|
||||||
|
generator = GeneratorUNet(in_channels=3, out_channels=3)
|
||||||
|
|
||||||
|
# Инициализируем веса
|
||||||
|
generator.apply(
|
||||||
|
lambda m: (
|
||||||
|
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||||
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем тестовый входной тензор (Yandex изображение)
|
||||||
|
batch_size = 2
|
||||||
|
height, width = 700, 700
|
||||||
|
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
|
||||||
|
print(f"Размер входного изображения: {yandex_image.shape}")
|
||||||
|
print(
|
||||||
|
f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Прямой проход
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_image = generator(yandex_image)
|
||||||
|
|
||||||
|
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||||||
|
print(
|
||||||
|
f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверка размеров
|
||||||
|
assert generated_image.shape == (batch_size, 3, height, width), (
|
||||||
|
f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("✓ Генератор работает корректно!")
|
||||||
|
return generator
|
||||||
|
|
||||||
|
|
||||||
|
def test_discriminator():
|
||||||
|
"""Тестирование дискриминатора"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование дискриминатора...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем дискриминатор
|
||||||
|
discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала
|
||||||
|
|
||||||
|
# Инициализируем веса
|
||||||
|
discriminator.apply(
|
||||||
|
lambda m: (
|
||||||
|
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||||
|
if isinstance(m, nn.Conv2d)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем тестовые тензоры
|
||||||
|
batch_size = 2
|
||||||
|
height, width = 700, 700
|
||||||
|
|
||||||
|
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
google_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
|
||||||
|
print(f"Размер Yandex изображения: {yandex_image.shape}")
|
||||||
|
print(f"Размер Google изображения: {google_image.shape}")
|
||||||
|
print(
|
||||||
|
f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Прямой проход
|
||||||
|
with torch.no_grad():
|
||||||
|
prediction = discriminator(yandex_image, google_image)
|
||||||
|
|
||||||
|
print(f"Размер выхода дискриминатора: {prediction.shape}")
|
||||||
|
print(
|
||||||
|
f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверка размеров (PatchGAN выдает карту вероятностей)
|
||||||
|
expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков
|
||||||
|
expected_width = 43
|
||||||
|
assert prediction.shape == (batch_size, 1, expected_height, expected_width), (
|
||||||
|
f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}"
|
||||||
|
)
|
||||||
|
return discriminator
|
||||||
|
|
||||||
|
|
||||||
|
def test_gan_model():
|
||||||
|
"""Тестирование полной GAN модели"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование полной GAN модели...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем GAN модель
|
||||||
|
gan = ImageGAN(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode="vanilla",
|
||||||
|
lambda_L1=100.0,
|
||||||
|
use_cuda=False, # Тестируем на CPU для простоты
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Устройство модели: {gan.device}")
|
||||||
|
print(
|
||||||
|
f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}"
|
||||||
|
)
|
||||||
|
print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}")
|
||||||
|
|
||||||
|
# Создаем тестовые данные
|
||||||
|
batch_size = 2
|
||||||
|
height, width = 700, 700
|
||||||
|
|
||||||
|
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
real_google_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
|
||||||
|
print(f"\nТестирование прямого прохода...")
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_image = gan(yandex_image)
|
||||||
|
|
||||||
|
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||||||
|
|
||||||
|
print(f"\nТестирование шага генератора...")
|
||||||
|
gan.train_mode()
|
||||||
|
|
||||||
|
# Тестируем шаг генератора
|
||||||
|
total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image)
|
||||||
|
|
||||||
|
print(f"Общие потери генератора: {total_loss.item():.6f}")
|
||||||
|
print(f"Потери GAN: {gan_loss.item():.6f}")
|
||||||
|
print(f"Потери L1: {l1_loss.item():.6f}")
|
||||||
|
|
||||||
|
print(f"\nТестирование шага дискриминатора...")
|
||||||
|
# Создаем сгенерированное изображение для дискриминатора
|
||||||
|
with torch.no_grad():
|
||||||
|
fake_google_image = gan.generator(yandex_image)
|
||||||
|
|
||||||
|
total_d_loss, real_loss, fake_loss = gan.discriminator_step(
|
||||||
|
yandex_image, real_google_image, fake_google_image
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}")
|
||||||
|
print(f"Потери на реальных изображениях: {real_loss.item():.6f}")
|
||||||
|
print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}")
|
||||||
|
|
||||||
|
print(f"\nТестирование режимов обучения/оценки...")
|
||||||
|
gan.eval_mode()
|
||||||
|
print(f"Генератор в режиме eval: {not gan.generator.training}")
|
||||||
|
print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}")
|
||||||
|
|
||||||
|
gan.train_mode()
|
||||||
|
print(f"Генератор в режиме train: {gan.generator.training}")
|
||||||
|
print(f"Дискриминатор в режиме train: {gan.discriminator.training}")
|
||||||
|
|
||||||
|
print("\n✓ Полная GAN модель работает корректно!")
|
||||||
|
return gan
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_function():
|
||||||
|
"""Тестирование фабричной функции"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование фабричной функции...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Тестируем разные режимы GAN
|
||||||
|
for gan_mode in ["vanilla", "lsgan"]:
|
||||||
|
print(f"\nСоздание GAN в режиме '{gan_mode}'...")
|
||||||
|
gan = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode=gan_mode,
|
||||||
|
lambda_L1=100.0,
|
||||||
|
use_cuda=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Режим GAN: {gan.gan_loss.gan_mode}")
|
||||||
|
print(f" Вес L1 потерь: {gan.lambda_L1}")
|
||||||
|
print(f" Устройство: {gan.device}")
|
||||||
|
|
||||||
|
# Быстрая проверка прямого прохода
|
||||||
|
batch_size = 1
|
||||||
|
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
generated = gan(yandex_image)
|
||||||
|
|
||||||
|
print(f" Размер выхода: {generated.shape}")
|
||||||
|
print(f" ✓ GAN в режиме '{gan_mode}' создан успешно")
|
||||||
|
|
||||||
|
print("\n✓ Фабричная функция работает корректно!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_weights_initialization():
|
||||||
|
"""Тестирование инициализации весов"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование инициализации весов...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем модели
|
||||||
|
generator = GeneratorUNet(3, 3)
|
||||||
|
discriminator = DiscriminatorPatchGAN(6)
|
||||||
|
|
||||||
|
# Инициализируем веса
|
||||||
|
initialize_gan_weights(generator, discriminator)
|
||||||
|
|
||||||
|
# Проверяем средние значения весов
|
||||||
|
def check_weights_mean(model, model_name):
|
||||||
|
conv_weights = []
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if "weight" in name and (
|
||||||
|
"conv" in name.lower() or "Conv" in str(param.__class__)
|
||||||
|
):
|
||||||
|
conv_weights.append(param.data.mean().item())
|
||||||
|
|
||||||
|
if conv_weights:
|
||||||
|
avg_mean = sum(conv_weights) / len(conv_weights)
|
||||||
|
print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}")
|
||||||
|
# Проверяем, что веса инициализированы около 0
|
||||||
|
assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0"
|
||||||
|
|
||||||
|
check_weights_mean(generator, "генераторе")
|
||||||
|
check_weights_mean(discriminator, "дискриминаторе")
|
||||||
|
|
||||||
|
print("✓ Инициализация весов работает корректно!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_usage():
|
||||||
|
"""Тестирование использования памяти"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование использования памяти...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
# Получаем текущее использование памяти
|
||||||
|
process = psutil.Process(os.getpid())
|
||||||
|
memory_before = process.memory_info().rss / 1024 / 1024 # в MB
|
||||||
|
|
||||||
|
print(f"Память до создания моделей: {memory_before:.2f} MB")
|
||||||
|
|
||||||
|
# Создаем несколько моделей
|
||||||
|
models = []
|
||||||
|
for i in range(3):
|
||||||
|
gan = create_image_gan(use_cuda=False)
|
||||||
|
models.append(gan)
|
||||||
|
|
||||||
|
# Делаем тестовый проход
|
||||||
|
batch_size = 1
|
||||||
|
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||||||
|
real_google_image = torch.randn(batch_size, 3, 700, 700)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = gan(yandex_image)
|
||||||
|
_ = gan.generator_step(yandex_image, real_google_image)
|
||||||
|
|
||||||
|
memory_after = process.memory_info().rss / 1024 / 1024 # в MB
|
||||||
|
memory_used = memory_after - memory_before
|
||||||
|
|
||||||
|
print(f"Память после создания моделей: {memory_after:.2f} MB")
|
||||||
|
print(f"Использовано памяти: {memory_used:.2f} MB")
|
||||||
|
|
||||||
|
# Очищаем модели
|
||||||
|
del models
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
memory_final = process.memory_info().rss / 1024 / 1024
|
||||||
|
print(f"Память после очистки: {memory_final:.2f} MB")
|
||||||
|
|
||||||
|
print("✓ Тестирование памяти завершено!")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция тестирования"""
|
||||||
|
print("Начало тестирования GAN архитектуры для преобразования Yandex → Google")
|
||||||
|
print("Размер изображения: 700x700 пикселей")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Запускаем все тесты
|
||||||
|
test_generator()
|
||||||
|
test_discriminator()
|
||||||
|
test_gan_model()
|
||||||
|
test_factory_function()
|
||||||
|
test_weights_initialization()
|
||||||
|
test_memory_usage()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nАрхитектура GAN готова к использованию для преобразования")
|
||||||
|
print("изображений из стиля Yandex в стиль Google.")
|
||||||
|
print("\nОсновные характеристики:")
|
||||||
|
print(" • Генератор: U-Net архитектура")
|
||||||
|
print(" • Дискриминатор: PatchGAN (43x43 патчей)")
|
||||||
|
print(" • Размер входных/выходных изображений: 700x700")
|
||||||
|
print(" • Поддержка режимов: vanilla, lsgan")
|
||||||
|
print(" • L1 регуляризация для сохранения структуры")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Ошибка при тестировании: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = main()
|
||||||
|
sys.exit(exit_code)
|
||||||
342
models/GAN/test_trainer.py
Normal file
342
models/GAN/test_trainer.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
"""
|
||||||
|
Тестовый скрипт для проверки GAN trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
# Добавляем путь к модулям
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleDataset(Dataset):
|
||||||
|
"""Простой датасет для тестирования."""
|
||||||
|
|
||||||
|
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.image_size = image_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# Создаем случайные изображения для тестирования
|
||||||
|
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
|
||||||
|
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||||
|
|
||||||
|
|
||||||
|
def test_gan_model():
|
||||||
|
"""Тестирование GAN модели."""
|
||||||
|
print("Тестирование GAN модели...")
|
||||||
|
|
||||||
|
# Создаем модель
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode="vanilla",
|
||||||
|
lambda_L1=100.0,
|
||||||
|
use_cuda=False, # Используем CPU для тестирования
|
||||||
|
)
|
||||||
|
|
||||||
|
# Тестируем forward pass
|
||||||
|
batch_size = 2
|
||||||
|
image_size = (256, 256)
|
||||||
|
yandex_input = torch.randn(batch_size, 3, *image_size)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(yandex_input)
|
||||||
|
|
||||||
|
print(f"Входной размер: {yandex_input.shape}")
|
||||||
|
print(f"Выходной размер: {output.shape}")
|
||||||
|
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
|
||||||
|
|
||||||
|
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
|
||||||
|
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
|
||||||
|
print("✓ Forward pass работает корректно")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_generator_step():
|
||||||
|
"""Тестирование шага генератора."""
|
||||||
|
print("\nТестирование шага генератора...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
model.train_mode()
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
|
||||||
|
# Тестируем generator_step
|
||||||
|
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
|
||||||
|
|
||||||
|
print(f"Total loss: {total_loss.item():.6f}")
|
||||||
|
print(f"GAN loss: {gan_loss.item():.6f}")
|
||||||
|
print(f"L1 loss: {l1_loss.item():.6f}")
|
||||||
|
|
||||||
|
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||||
|
print("✓ Шаг генератора работает корректно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_discriminator_step():
|
||||||
|
"""Тестирование шага дискриминатора."""
|
||||||
|
print("\nТестирование шага дискриминатора...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
model.train_mode()
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
|
||||||
|
# Генерируем fake изображение
|
||||||
|
with torch.no_grad():
|
||||||
|
fake_google_img = model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Тестируем discriminator_step
|
||||||
|
total_loss, real_loss, fake_loss = model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Total loss: {total_loss.item():.6f}")
|
||||||
|
print(f"Real loss: {real_loss.item():.6f}")
|
||||||
|
print(f"Fake loss: {fake_loss.item():.6f}")
|
||||||
|
|
||||||
|
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||||
|
print("✓ Шаг дискриминатора работает корректно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_trainer_initialization():
|
||||||
|
"""Тестирование инициализации тренера."""
|
||||||
|
print("\nТестирование инициализации тренера...")
|
||||||
|
|
||||||
|
# Создаем модель
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
|
||||||
|
# Создаем даталоадеры
|
||||||
|
train_dataset = SimpleDataset(num_samples=50)
|
||||||
|
val_dataset = SimpleDataset(num_samples=10)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
# Конфигурация
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan",
|
||||||
|
"early_stopping_patience": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Создаем тренер
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Тренер создан успешно")
|
||||||
|
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
|
||||||
|
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
|
||||||
|
print(f"Выходная директория: {trainer.output_dir}")
|
||||||
|
|
||||||
|
assert trainer.output_dir.exists(), "Выходная директория не создана"
|
||||||
|
print("✓ Тренер инициализирован корректно")
|
||||||
|
|
||||||
|
return trainer, train_loader, val_loader
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_epoch():
|
||||||
|
"""Тестирование одной эпохи обучения."""
|
||||||
|
print("\nТестирование одной эпохи обучения...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
train_dataset = SimpleDataset(num_samples=20)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||||
|
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan",
|
||||||
|
}
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Запускаем одну эпоху обучения
|
||||||
|
avg_g_loss, avg_d_loss = trainer.train_epoch()
|
||||||
|
|
||||||
|
print(f"Средние потери за эпоху:")
|
||||||
|
print(f" Генератор: {avg_g_loss:.6f}")
|
||||||
|
print(f" Дискриминатор: {avg_d_loss:.6f}")
|
||||||
|
|
||||||
|
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
|
||||||
|
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||||
|
print("✓ Эпоха обучения завершена успешно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation():
|
||||||
|
"""Тестирование валидации."""
|
||||||
|
print("\nТестирование валидации...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||||
|
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan",
|
||||||
|
}
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Запускаем валидацию
|
||||||
|
val_g_loss, val_d_loss = trainer.validate()
|
||||||
|
|
||||||
|
print(f"Потери на валидации:")
|
||||||
|
print(f" Генератор: {val_g_loss:.6f}")
|
||||||
|
print(f" Дискриминатор: {val_d_loss:.6f}")
|
||||||
|
|
||||||
|
assert val_g_loss > 0, "Потери генератора должны быть положительными"
|
||||||
|
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||||
|
print("✓ Валидация завершена успешно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_saving():
|
||||||
|
"""Тестирование сохранения чекпоинтов."""
|
||||||
|
print("\nТестирование сохранения чекпоинтов...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||||
|
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan_checkpoint",
|
||||||
|
}
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Сохраняем чекпоинт
|
||||||
|
trainer.save_checkpoint(is_best=True)
|
||||||
|
|
||||||
|
# Проверяем, что файлы созданы
|
||||||
|
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
|
||||||
|
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
|
||||||
|
|
||||||
|
for file in checkpoint_files:
|
||||||
|
print(f" - {file.name}")
|
||||||
|
|
||||||
|
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
|
||||||
|
print("✓ Чекпоинты сохранены успешно")
|
||||||
|
|
||||||
|
# Тестируем загрузку чекпоинта
|
||||||
|
checkpoint_path = checkpoint_files[0]
|
||||||
|
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
|
||||||
|
|
||||||
|
# Создаем новую модель и тренер
|
||||||
|
new_model = create_image_gan(use_cuda=False)
|
||||||
|
new_trainer = GANTrainer(
|
||||||
|
model=new_model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Загружаем чекпоинт
|
||||||
|
new_trainer.load_checkpoint(str(checkpoint_path))
|
||||||
|
|
||||||
|
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
|
||||||
|
print("✓ Чекпоинт загружен успешно")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция тестирования."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Начало тестирования GAN trainer")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Запускаем все тесты
|
||||||
|
test_gan_model()
|
||||||
|
test_generator_step()
|
||||||
|
test_discriminator_step()
|
||||||
|
test_trainer_initialization()
|
||||||
|
test_train_epoch()
|
||||||
|
test_validation()
|
||||||
|
test_checkpoint_saving()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Все тесты пройдены успешно! ✓")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nОшибка при тестировании: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Очищаем тестовые директории
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
|
||||||
|
for dir_path in test_dirs:
|
||||||
|
if Path(dir_path).exists():
|
||||||
|
shutil.rmtree(dir_path)
|
||||||
|
|
||||||
|
# Запускаем тесты
|
||||||
|
exit_code = main()
|
||||||
|
|
||||||
|
# Очищаем после тестов
|
||||||
|
for dir_path in test_dirs:
|
||||||
|
if Path(dir_path).exists():
|
||||||
|
shutil.rmtree(dir_path)
|
||||||
|
|
||||||
|
exit(exit_code)
|
||||||
347
models/GAN/train_example.py
Normal file
347
models/GAN/train_example.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""
|
||||||
|
Пример обучения GAN модели для преобразования Yandex → Google карт.
|
||||||
|
|
||||||
|
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Добавляем путь к модулям
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
|
||||||
|
def create_simple_config():
|
||||||
|
"""Создает простую конфигурацию для обучения."""
|
||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 4,
|
||||||
|
"epochs": 100,
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
|
||||||
|
"lambda_L1": 100.0, # Вес L1 потерь
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan_training",
|
||||||
|
# Логирование
|
||||||
|
"log_interval": 10, # Логировать каждые N батчей
|
||||||
|
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_advanced_config():
|
||||||
|
"""Создает расширенную конфигурацию для обучения."""
|
||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
# Планировщик learning rate
|
||||||
|
"use_scheduler": True,
|
||||||
|
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
|
||||||
|
"scheduler_start_epoch": 50,
|
||||||
|
"scheduler_end_epoch": 100,
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 8,
|
||||||
|
"epochs": 200,
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "lsgan", # LSGAN обычно более стабилен
|
||||||
|
"lambda_L1": 100.0,
|
||||||
|
# Аугментация данных
|
||||||
|
"augmentation": {
|
||||||
|
"random_crop": True,
|
||||||
|
"crop_size": 256,
|
||||||
|
"random_flip": True,
|
||||||
|
"color_jitter": True,
|
||||||
|
"brightness": 0.2,
|
||||||
|
"contrast": 0.2,
|
||||||
|
"saturation": 0.2,
|
||||||
|
"hue": 0.1,
|
||||||
|
},
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
"weight_decay": 1e-4,
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 30,
|
||||||
|
"early_stopping_min_delta": 1e-4,
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan_advanced",
|
||||||
|
# Логирование
|
||||||
|
"log_interval": 20,
|
||||||
|
"save_interval": 10,
|
||||||
|
"save_best_only": True, # Сохранять только лучшую модель
|
||||||
|
# Визуализация
|
||||||
|
"visualize_samples": True,
|
||||||
|
"num_visualize": 4,
|
||||||
|
"visualize_interval": 5, # Визуализировать каждые N эпох
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def print_config_summary(config):
|
||||||
|
"""Печатает сводку конфигурации."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Конфигурация обучения GAN")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print(f"\nПараметры модели:")
|
||||||
|
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
|
||||||
|
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
|
||||||
|
|
||||||
|
print(f"\nПараметры обучения:")
|
||||||
|
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
|
||||||
|
print(f" Batch size: {config.get('batch_size', 4)}")
|
||||||
|
print(f" Эпох: {config.get('epochs', 100)}")
|
||||||
|
print(f" Beta1: {config.get('beta1', 0.5)}")
|
||||||
|
print(f" Beta2: {config.get('beta2', 0.999)}")
|
||||||
|
|
||||||
|
if config.get("use_scheduler", False):
|
||||||
|
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
|
||||||
|
|
||||||
|
print(f"\nРегуляризация:")
|
||||||
|
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
|
||||||
|
if "weight_decay" in config:
|
||||||
|
print(f" Weight decay: {config['weight_decay']}")
|
||||||
|
|
||||||
|
print(f"\nРанняя остановка:")
|
||||||
|
if config.get("early_stopping_patience", 0) > 0:
|
||||||
|
print(f" Patience: {config['early_stopping_patience']} эпох")
|
||||||
|
if "early_stopping_min_delta" in config:
|
||||||
|
print(f" Min delta: {config['early_stopping_min_delta']}")
|
||||||
|
|
||||||
|
print(f"\nВыходные данные:")
|
||||||
|
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
|
||||||
|
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
|
||||||
|
|
||||||
|
print(f"\nЛогирование:")
|
||||||
|
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_training():
|
||||||
|
"""Настраивает обучение."""
|
||||||
|
print("Настройка обучения GAN...")
|
||||||
|
|
||||||
|
# Выбираем конфигурацию
|
||||||
|
use_advanced = False # Измените на True для расширенной конфигурации
|
||||||
|
|
||||||
|
if use_advanced:
|
||||||
|
config = create_advanced_config()
|
||||||
|
else:
|
||||||
|
config = create_simple_config()
|
||||||
|
|
||||||
|
# Печатаем сводку конфигурации
|
||||||
|
print_config_summary(config)
|
||||||
|
|
||||||
|
# Устройство
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"\nИспользуемое устройство: {device}")
|
||||||
|
|
||||||
|
if device.type == "cuda":
|
||||||
|
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(
|
||||||
|
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем модель
|
||||||
|
print("\nСоздание модели...")
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode=config.get("gan_mode", "vanilla"),
|
||||||
|
lambda_L1=config.get("lambda_L1", 100.0),
|
||||||
|
use_cuda=(device.type == "cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем даталоадеры
|
||||||
|
print("\nСоздание даталоадеров...")
|
||||||
|
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
|
||||||
|
# Пример:
|
||||||
|
# from your_dataset_module import create_data_loaders
|
||||||
|
# train_loader, val_loader = create_data_loaders(
|
||||||
|
# data_dir="ваш/путь/к/данным",
|
||||||
|
# batch_size=config["batch_size"],
|
||||||
|
# image_size=(256, 256),
|
||||||
|
# augment=config.get("augmentation", None),
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Для примера создаем фиктивные даталоадеры
|
||||||
|
# ВАЖНО: Замените это на реальные данные!
|
||||||
|
print(" ВНИМАНИЕ: Используются фиктивные данные!")
|
||||||
|
print(" Замените на реальные даталоадеры!")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
class DummyDataset(Dataset):
|
||||||
|
def __init__(self, num_samples=100):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# Фиктивные данные для примера
|
||||||
|
yandex_img = torch.randn(3, 256, 256)
|
||||||
|
google_img = torch.randn(3, 256, 256)
|
||||||
|
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||||
|
|
||||||
|
train_dataset = DummyDataset(num_samples=100)
|
||||||
|
val_dataset = DummyDataset(num_samples=20)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.get("batch_size", 4),
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config.get("batch_size", 4),
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Размер обучающего набора: {len(train_dataset)}")
|
||||||
|
print(f" Размер валидационного набора: {len(val_dataset)}")
|
||||||
|
print(f" Батчей в эпохе: {len(train_loader)}")
|
||||||
|
|
||||||
|
# Создаем тренер
|
||||||
|
print("\nСоздание тренера...")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return trainer, config
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(trainer, config):
|
||||||
|
"""Запускает обучение модели."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Начало обучения")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
epochs = config.get("epochs", 100)
|
||||||
|
|
||||||
|
try:
|
||||||
|
trainer.train(num_epochs=epochs)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Обучение завершено успешно!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nОбучение прервано пользователем.")
|
||||||
|
print("Сохранение текущего состояния...")
|
||||||
|
trainer.save_checkpoint(is_best=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nОшибка при обучении: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# Пытаемся сохранить чекпоинт при ошибке
|
||||||
|
try:
|
||||||
|
trainer.save_checkpoint(is_best=False)
|
||||||
|
print("Текущее состояние сохранено.")
|
||||||
|
except:
|
||||||
|
print("Не удалось сохранить состояние.")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model(trainer, test_loader=None):
|
||||||
|
"""Оценивает обученную модель."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Оценка модели")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if test_loader is None:
|
||||||
|
print("Тестовый даталоадер не предоставлен.")
|
||||||
|
print("Используется валидационный даталоадер для оценки.")
|
||||||
|
test_loader = trainer.val_loader
|
||||||
|
|
||||||
|
metrics = trainer.evaluate(test_loader)
|
||||||
|
|
||||||
|
print("\nМетрики оценки:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f" {key}: {value:.6f}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def generate_examples(model, device, num_examples=4):
|
||||||
|
"""Генерирует примеры преобразования."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Генерация примеров")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Создаем фиктивные входные данные
|
||||||
|
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
google_output = model(yandex_input)
|
||||||
|
|
||||||
|
print(f"Сгенерировано {num_examples} примеров")
|
||||||
|
print(f"Размер входных данных: {yandex_input.shape}")
|
||||||
|
print(f"Размер выходных данных: {google_output.shape}")
|
||||||
|
|
||||||
|
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
|
||||||
|
print("\nПримеры сгенерированы.")
|
||||||
|
print("В реальном коде сохраняйте их как изображения для визуализации.")
|
||||||
|
|
||||||
|
return yandex_input, google_output
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Пример обучения GAN для преобразования Yandex → Google")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Настройка
|
||||||
|
trainer, config = setup_training()
|
||||||
|
|
||||||
|
# Обучение
|
||||||
|
train_model(trainer, config)
|
||||||
|
|
||||||
|
# Оценка (требует реальных тестовых данных)
|
||||||
|
# evaluate_model(trainer)
|
||||||
|
|
||||||
|
# Генерация примеров
|
||||||
|
# generate_examples(trainer.model, trainer.device)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Скрипт завершен.")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nСледующие шаги:")
|
||||||
|
print("1. Замените фиктивные даталоадеры на реальные данные")
|
||||||
|
print("2. Настройте параметры в create_simple_config()")
|
||||||
|
print("3. Запустите обучение с реальными данными")
|
||||||
|
print("4. Визуализируйте результаты")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
415
models/GAN/trainer.py
Normal file
415
models/GAN/trainer.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Type aliases
|
||||||
|
ModuleType = nn.Module
|
||||||
|
TensorType = torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class GANTrainer:
|
||||||
|
"""Trainer class for GAN model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: ModuleType,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the GAN trainer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: GAN model (ImageGAN)
|
||||||
|
train_loader: Training data loader
|
||||||
|
val_loader: Validation data loader
|
||||||
|
device: Device to run training on
|
||||||
|
config: Training configuration dictionary
|
||||||
|
"""
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.device = device
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Optimizers
|
||||||
|
lr = config.get("learning_rate", 2e-4)
|
||||||
|
beta1 = config.get("beta1", 0.5)
|
||||||
|
beta2 = config.get("beta2", 0.999)
|
||||||
|
|
||||||
|
# Separate optimizers for generator and discriminator
|
||||||
|
# Note: self.model is expected to have .generator and .discriminator attributes
|
||||||
|
self.optimizer_G = optim.Adam(
|
||||||
|
self.model.generator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||||
|
)
|
||||||
|
self.optimizer_D = optim.Adam(
|
||||||
|
self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training state
|
||||||
|
self.current_epoch = 0
|
||||||
|
self.best_val_loss = float("inf")
|
||||||
|
self.g_losses: List[float] = []
|
||||||
|
self.d_losses: List[float] = []
|
||||||
|
self.val_g_losses: List[float] = []
|
||||||
|
self.val_d_losses: List[float] = []
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
self.output_dir = Path(config.get("output_dir", "runs/gan"))
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# TensorBoard writer
|
||||||
|
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||||
|
|
||||||
|
# Save configuration
|
||||||
|
config_path = self.output_dir / "config.json"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Training configuration saved to {config_path}")
|
||||||
|
# Access parameters through the model's generator and discriminator
|
||||||
|
generator_params = sum(p.numel() for p in self.model.generator.parameters())
|
||||||
|
discriminator_params = sum(
|
||||||
|
p.numel() for p in self.model.discriminator.parameters()
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Generator has {generator_params:,} parameters")
|
||||||
|
print(f"Discriminator has {discriminator_params:,} parameters")
|
||||||
|
|
||||||
|
def train_epoch(self) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Train for one epoch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (average generator loss, average discriminator loss)
|
||||||
|
"""
|
||||||
|
self.model.train()
|
||||||
|
total_g_loss = 0.0
|
||||||
|
total_d_loss = 0.0
|
||||||
|
num_batches = len(self.train_loader)
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||||
|
for batch_idx, batch in enumerate(progress_bar):
|
||||||
|
# Move data to device
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
|
||||||
|
# ========== Train Discriminator ==========
|
||||||
|
self.optimizer_D.zero_grad()
|
||||||
|
|
||||||
|
# Generate fake image
|
||||||
|
with torch.no_grad():
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Discriminator loss - returns tuple of tensors
|
||||||
|
d_loss_tuple = self.model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||||
|
|
||||||
|
# Backward and optimize discriminator
|
||||||
|
d_loss.backward()
|
||||||
|
self.optimizer_D.step()
|
||||||
|
|
||||||
|
# ========== Train Generator ==========
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
|
||||||
|
# Generate fake image
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Generator loss - returns tuple of tensors
|
||||||
|
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||||
|
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||||
|
|
||||||
|
# Backward and optimize generator
|
||||||
|
g_loss.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_g_loss += g_loss.item()
|
||||||
|
total_d_loss += d_loss.item()
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix(
|
||||||
|
{
|
||||||
|
"g_loss": g_loss.item(),
|
||||||
|
"d_loss": d_loss.item(),
|
||||||
|
"g_l1": g_l1_loss.item(),
|
||||||
|
"d_real": d_real_loss.item(),
|
||||||
|
"d_fake": d_fake_loss.item(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log batch losses to TensorBoard
|
||||||
|
global_step = self.current_epoch * num_batches + batch_idx
|
||||||
|
self.writer.add_scalar("train/batch_g_loss", g_loss.item(), global_step)
|
||||||
|
self.writer.add_scalar("train/batch_d_loss", d_loss.item(), global_step)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/batch_g_l1_loss", g_l1_loss.item(), global_step
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/batch_d_real_loss", d_real_loss.item(), global_step
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/batch_d_fake_loss", d_fake_loss.item(), global_step
|
||||||
|
)
|
||||||
|
|
||||||
|
avg_g_loss = total_g_loss / num_batches
|
||||||
|
avg_d_loss = total_d_loss / num_batches
|
||||||
|
self.g_losses.append(avg_g_loss)
|
||||||
|
self.d_losses.append(avg_d_loss)
|
||||||
|
|
||||||
|
return avg_g_loss, avg_d_loss
|
||||||
|
|
||||||
|
def validate(self) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Validate the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (average generator validation loss, average discriminator validation loss)
|
||||||
|
"""
|
||||||
|
self.model.eval()
|
||||||
|
total_g_loss = 0.0
|
||||||
|
total_d_loss = 0.0
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||||
|
for batch in progress_bar:
|
||||||
|
# Move data to device
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Generate fake image
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Generator loss - returns tuple of tensors
|
||||||
|
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||||
|
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||||
|
|
||||||
|
# Discriminator loss - returns tuple of tensors
|
||||||
|
d_loss_tuple = self.model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_g_loss += g_loss.item()
|
||||||
|
total_d_loss += d_loss.item()
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
|
||||||
|
|
||||||
|
avg_g_loss = total_g_loss / len(self.val_loader)
|
||||||
|
avg_d_loss = total_d_loss / len(self.val_loader)
|
||||||
|
self.val_g_losses.append(avg_g_loss)
|
||||||
|
self.val_d_losses.append(avg_d_loss)
|
||||||
|
|
||||||
|
return avg_g_loss, avg_d_loss
|
||||||
|
|
||||||
|
def save_checkpoint(self, is_best: bool = False):
|
||||||
|
"""
|
||||||
|
Save training checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_best: Whether this is the best model so far
|
||||||
|
"""
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": self.current_epoch,
|
||||||
|
"model_state_dict": self.model.state_dict(),
|
||||||
|
"optimizer_G_state_dict": self.optimizer_G.state_dict(),
|
||||||
|
"optimizer_D_state_dict": self.optimizer_D.state_dict(),
|
||||||
|
"g_losses": self.g_losses,
|
||||||
|
"d_losses": self.d_losses,
|
||||||
|
"val_g_losses": self.val_g_losses,
|
||||||
|
"val_d_losses": self.val_d_losses,
|
||||||
|
"best_val_loss": self.best_val_loss,
|
||||||
|
"config": self.config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save regular checkpoint
|
||||||
|
checkpoint_path = (
|
||||||
|
self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
|
||||||
|
)
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
# Save best model
|
||||||
|
if is_best:
|
||||||
|
best_path = self.output_dir / "model_best.pth"
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
print(f"Best model saved to {best_path}")
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False):
|
||||||
|
"""
|
||||||
|
Load training checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: Path to checkpoint file
|
||||||
|
resume_training: Если True, продолжить обучение с сохраненной эпохи
|
||||||
|
"""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
|
||||||
|
self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])
|
||||||
|
|
||||||
|
self.current_epoch = checkpoint["epoch"]
|
||||||
|
self.g_losses = checkpoint["g_losses"]
|
||||||
|
self.d_losses = checkpoint["d_losses"]
|
||||||
|
self.val_g_losses = checkpoint["val_g_losses"]
|
||||||
|
self.val_d_losses = checkpoint["val_d_losses"]
|
||||||
|
self.best_val_loss = checkpoint["best_val_loss"]
|
||||||
|
|
||||||
|
if resume_training:
|
||||||
|
print(f"Resuming training from epoch {self.current_epoch + 1}")
|
||||||
|
else:
|
||||||
|
print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")
|
||||||
|
|
||||||
|
def train(self, num_epochs: int, start_epoch: int = 0):
|
||||||
|
"""
|
||||||
|
Train the model for specified number of epochs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_epochs: Number of epochs to train
|
||||||
|
start_epoch: Starting epoch (useful when resuming training)
|
||||||
|
"""
|
||||||
|
print(
|
||||||
|
f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
|
||||||
|
)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, start_epoch + num_epochs):
|
||||||
|
self.current_epoch = epoch
|
||||||
|
|
||||||
|
# Train for one epoch
|
||||||
|
train_g_loss, train_d_loss = self.train_epoch()
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
val_g_loss, val_d_loss = self.validate()
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
|
||||||
|
self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
|
||||||
|
self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
|
||||||
|
self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)
|
||||||
|
|
||||||
|
# Print epoch summary
|
||||||
|
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||||
|
print(" Generator:")
|
||||||
|
print(f" Train Loss: {train_g_loss:.6f}")
|
||||||
|
print(f" Val Loss: {val_g_loss:.6f}")
|
||||||
|
print(" Discriminator:")
|
||||||
|
print(f" Train Loss: {train_d_loss:.6f}")
|
||||||
|
print(f" Val Loss: {val_d_loss:.6f}")
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
val_total_loss = val_g_loss + val_d_loss
|
||||||
|
is_best = val_total_loss < self.best_val_loss
|
||||||
|
if is_best:
|
||||||
|
self.best_val_loss = val_total_loss
|
||||||
|
|
||||||
|
self.save_checkpoint(is_best=is_best)
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if self.config.get("early_stopping_patience", 0) > 0:
|
||||||
|
val_losses = [
|
||||||
|
g + d for g, d in zip(self.val_g_losses, self.val_d_losses)
|
||||||
|
]
|
||||||
|
if (
|
||||||
|
epoch - np.argmin(val_losses)
|
||||||
|
>= self.config["early_stopping_patience"]
|
||||||
|
):
|
||||||
|
print(f"Early stopping at epoch {epoch + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Training completed
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||||
|
print(f"Best validation total loss: {self.best_val_loss:.6f}")
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
final_model_path = self.output_dir / "model_final.pth"
|
||||||
|
torch.save(self.model.state_dict(), final_model_path)
|
||||||
|
print(f"Final model saved to {final_model_path}")
|
||||||
|
|
||||||
|
# Save training history
|
||||||
|
history_path = self.output_dir / "training_history.json"
|
||||||
|
history = {
|
||||||
|
"g_losses": self.g_losses,
|
||||||
|
"d_losses": self.d_losses,
|
||||||
|
"val_g_losses": self.val_g_losses,
|
||||||
|
"val_d_losses": self.val_d_losses,
|
||||||
|
"best_val_loss": self.best_val_loss,
|
||||||
|
"total_epochs": self.current_epoch + 1,
|
||||||
|
}
|
||||||
|
with open(history_path, "w") as f:
|
||||||
|
json.dump(history, f, indent=2)
|
||||||
|
print(f"Training history saved to {history_path}")
|
||||||
|
|
||||||
|
# Close TensorBoard writer
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
def evaluate(self, test_loader: DataLoader) -> Dict:
|
||||||
|
"""
|
||||||
|
Evaluate the model on test data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_loader: Test data loader
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with evaluation metrics
|
||||||
|
"""
|
||||||
|
self.model.eval()
|
||||||
|
total_g_loss = 0.0
|
||||||
|
total_d_loss = 0.0
|
||||||
|
|
||||||
|
print("Evaluating model on test set...")
|
||||||
|
|
||||||
|
for batch in tqdm(test_loader, desc="Evaluation"):
|
||||||
|
# Move data to device
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Generate fake image
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Generator loss - returns tuple of tensors
|
||||||
|
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||||
|
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||||
|
|
||||||
|
# Discriminator loss - returns tuple of tensors
|
||||||
|
d_loss_tuple = self.model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_g_loss += g_loss.item()
|
||||||
|
total_d_loss += d_loss.item()
|
||||||
|
|
||||||
|
avg_g_loss = total_g_loss / len(test_loader)
|
||||||
|
avg_d_loss = total_d_loss / len(test_loader)
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"generator_loss": avg_g_loss,
|
||||||
|
"discriminator_loss": avg_d_loss,
|
||||||
|
"total_loss": avg_g_loss + avg_d_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("\nTest Results:")
|
||||||
|
print(f" Generator Loss: {avg_g_loss:.6f}")
|
||||||
|
print(f" Discriminator Loss: {avg_d_loss:.6f}")
|
||||||
|
print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
1
models/SiaN/.gitignore
vendored
1
models/SiaN/.gitignore
vendored
@@ -0,0 +1 @@
|
|||||||
|
runs
|
||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user