feat: add GAN

This commit is contained in:
2026-05-30 14:49:40 +03:00
parent 6477ce0776
commit 72e1950127
29 changed files with 2670 additions and 361 deletions

View File

@@ -1,254 +1,111 @@
# GAN Trainer для преобразования изображений Yandex → Google
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
## Структура проекта
```
autopilot/models/GAN/
├── gan.py # Основная реализация GAN модели
├── trainer.py # Тренер для обучения GAN
├── test_trainer.py # Тесты для тренера
├── train_example.py # Пример использования тренера
── README.md # Этот файл
```
## Модель GAN
Модель состоит из двух основных компонентов:
### 1. Генератор (GeneratorUNet)
- Архитектура U-Net для преобразования изображений
- Принимает изображение Yandex (3 канала RGB)
- Возвращает изображение в стиле Google (3 канала RGB)
- Использует skip connections для сохранения деталей
### 2. Дискриминатор (DiscriminatorPatchGAN)
- PatchGAN архитектура
- Принимает пару изображений (Yandex + Google)
- Возвращает вероятность того, что пара реальная
- Работает с патчами изображения 41x41
### Функция потерь (GANLoss)
Поддерживает три режима:
- `vanilla`: Бинарная кросс-энтропия
- `lsgan`: Least Squares GAN (более стабильный)
- `wgangp`: Wasserstein GAN with Gradient Penalty
## Тренер (GANTrainer)
### Основные возможности
1. **Обучение с чередованием**:
- Обучение генератора и дискриминатора поочередно
- Поддержка L1 потерь для сохранения структуры
2. **Валидация и мониторинг**:
- Отдельные потери для генератора и дискриминатора
- Логирование в TensorBoard
- Ранняя остановка
3. **Сохранение и загрузка**:
- Чекпоинты каждой эпохи
- Лучшая модель
- Финальная модель
- История обучения
4. **Оценка модели**:
- Метрики на тестовом наборе
- Генерация примеров
### Быстрый старт
```python
import torch
from torch.utils.data import DataLoader
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 4,
"output_dir": "runs/gan_training",
"gan_mode": "vanilla",
"lambda_L1": 100.0,
"early_stopping_patience": 20,
}
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Создание модели
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config["gan_mode"],
lambda_L1=config["lambda_L1"],
use_cuda=(device.type == "cuda"),
)
# Создание даталоадеров (замените на свои данные)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
# Создание тренера
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Обучение
trainer.train(num_epochs=100)
# Оценка
metrics = trainer.evaluate(test_loader)
```
### Конфигурация обучения
#### Базовая конфигурация
```python
config = {
# Параметры оптимизатора
"learning_rate": 2e-4, # Learning rate
"beta1": 0.5, # Adam beta1
"beta2": 0.999, # Adam beta2
# Параметры обучения
"batch_size": 4, # Размер батча
"epochs": 100, # Количество эпох
# Параметры GAN
"gan_mode": "vanilla", # Режим GAN
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0, # Gradient clipping
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan",
}
```
#### Расширенная конфигурация
```python
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 8,
"epochs": 200,
"gan_mode": "lsgan", # Более стабильный LSGAN
"lambda_L1": 100.0,
"grad_clip": 1.0,
"weight_decay": 1e-4, # Weight decay
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
"output_dir": "runs/gan_advanced",
}
```
### Методы тренера
#### Основные методы
- `train_epoch()`: Обучение на одной эпохе
- `validate()`: Валидация модели
- `train(num_epochs)`: Полное обучение
- `evaluate(test_loader)`: Оценка на тестовых данных
#### Управление чекпоинтами
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
### Выходные файлы
После обучения создаются следующие файлы:
```
runs/gan_training/
├── config.json # Конфигурация обучения
├── training_history.json # История потерь
├── model_best.pth # Лучшая модель
├── model_final.pth # Финальная модель
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
├── checkpoint_epoch_2.pth
├── ...
└── tensorboard/ # Логи TensorBoard
├── events.out.tfevents...
└── ...
```
### TensorBoard
Для визуализации обучения используйте TensorBoard:
```bash
tensorboard --logdir runs/gan_training/tensorboard
```
Доступные метрики:
- `train/batch_g_loss`: Потери генератора на батче
- `train/batch_d_loss`: Потери дискриминатора на батче
- `train/batch_g_l1_loss`: L1 потери генератора
- `train/epoch_g_loss`: Потери генератора на эпохе
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
- `val/epoch_g_loss`: Валидационные потери генератора
- `val/epoch_d_loss`: Валидационные потери дискриминатора
### Тестирование
Запустите тесты для проверки работоспособности:
```bash
python models/GAN/test_trainer.py
```
### Пример использования
Полный пример использования смотрите в `train_example.py`.
### Советы по обучению
1. **Начальные значения**:
- Используйте `gan_mode="lsgan"` для более стабильного обучения
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
2. **Мониторинг**:
- Следите за балансом потерь генератора и дискриминатора
- Если потери дискриминатора близки к 0, генератор не обучается
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
3. **Визуализация**:
- Регулярно генерируйте примеры для визуальной оценки
- Используйте TensorBoard для отслеживания прогресса
### Устранение проблем
#### Высокие потери генератора
- Уменьшите `lambda_L1`
- Увеличьте learning rate
- Проверьте качество данных
#### Дискриминатор слишком сильный
- Уменьшите learning rate дискриминатора
- Добавьте dropout в дискриминатор
- Обучайте генератор чаще, чем дискриминатор
#### Недостаток памяти GPU
- Уменьшите `batch_size`
- Уменьшите размер изображений
- Используйте gradient accumulation
### Лицензия
Этот проект является частью Autopilot системы.
# GAN для преобразования Google -> Yandex
Модуль реализует pix2pix-подобный GAN для парных изображений карт. Генератор получает изображение Google из пары и пытается сгенерировать изображение в стиле Yandex. Сгенерированная картинка сравнивается со вторым изображением пары, то есть с оригинальным `*_yandex.png`.
## Структура
```text
models/GAN/
├── build.py # генератор notebook.gen.ipynb по схеме
├── _schema.py # структура генерируемого ноутбука
├── _schema.md # описание формата схемы
├── notebook.gen.ipynb
── src/
│ ├── config.py
│ ├── dataloader.py
│ ├── model.py
│ ├── trainer.py
│ ├── analyze.py
│ └── main.py
└── README.md
```
## Архитектура
`GeneratorUNet` принимает `google_img` с 3 RGB-каналами и возвращает `fake_yandex`.
`DiscriminatorPatchGAN` получает пару `(google_img, yandex_img)` и отличает настоящую пару от `(google_img, fake_yandex)`.
`ImageGAN.generator_step()` считает:
- adversarial loss: дискриминатор должен принять `(google_img, fake_yandex)` за реальную пару;
- L1 loss: `fake_yandex` сравнивается с оригинальным `yandex_img` из той же пары.
- SSIM loss: штрафует структурные отличия, чтобы карта не расплывалась;
- Sobel edge loss: сохраняет контуры дорог и объектов для последующего поиска ключевых точек.
Итоговая функция потерь генератора:
```text
G_loss =
lambda_GAN * GAN_loss(D(google_img, fake_yandex), real)
+ lambda_L1 * L1(fake_yandex, yandex_img)
+ lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)
+ lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)
```
По умолчанию используется `lsgan`, усиленная реконструкция и более медленный
дискриминатор. Это менее “креативно”, зато обычно даёт более чистые контуры.
## Запуск
Из папки `models/GAN`:
```bash
python src/main.py
```
`src/main.py` специально написан без функции `main()`: при генерации ноутбука
все переменные остаются доступны после запуска ячейки, как в SiaN.
Чтобы пересобрать ноутбук из схемы:
```bash
python build.py
```
По умолчанию датасет берется из:
```text
C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images
```
Путь, размер изображений, batch size и число эпох меняются в `src/config.py`.
Если CUDA видна, но установленный PyTorch не поддерживает архитектуру GPU
например Tesla P100 `sm_60`, код автоматически переключится на CPU. Это
управляется параметром `prefer_cuda` в `src/config.py`.
## Ожидаемый формат данных
В директории датасета должны лежать пары:
```text
0000_google.png
0000_yandex.png
0001_google.png
0001_yandex.png
...
```
Даталоадер возвращает:
- `google_img`: вход генератора;
- `yandex_img`: целевое изображение для сравнения;
- `idx`: номер пары.
## Чекпоинты
Тренер сохраняет чекпоинты в:
```text
models/GAN/runs/checkpoints
```
Сохраняются `best.pth`, периодические `epoch_N.pth` и `final.pth`.
После обучения `src/main.py` также строит:
- графики `G/D/L1/SSIM/edge loss`;
- сетку `Google input -> Generated Yandex -> Yandex target`.
Файлы сохраняются в `models/GAN/runs/images`.

32
models/GAN/_schema.md Normal file
View File

@@ -0,0 +1,32 @@
# GAN Schema
Notebook structure definition for the Google -> Yandex GAN model.
---
## Format
```text
# === IMPORTS ===
<all imports>
# markdown
"""Description"""
# code: ./src/file.py
# # shell:
<shell commands>
```
## Directives
| Directive | Description |
|----------------|------------------------------------|
| `# code:` | Include file from `src/` |
| `# markdown` | Add markdown cell |
| `# # shell:` | Add notebook shell command cell |
---
`build.py` generates `notebook.gen.ipynb` from `_schema.py`.

121
models/GAN/_schema.py Normal file
View File

@@ -0,0 +1,121 @@
# _schema.py
# === IMPORTS ===
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from tqdm import tqdm
# markdown
"""# Configuration
Global settings for the Google -> Yandex GAN:
- Dataset path and image size
- Optimizer and training hyperparameters
- Device preference with safe CUDA compatibility checks
- GAN, L1, SSIM and edge reconstruction weights
- Output directories for checkpoints and generated samples"""
# code: ./src/config.py
# markdown
"""## Dataset
Google/Yandex paired image loader.
**Direction:**
- `google_img` is the generator input
- `yandex_img` is the target image from the same pair
**Returns:**
- Batch dict with `google_img`, `yandex_img`, `idx`"""
# code: ./src/dataloader.py
# code: ./src/test_dataloader.py
# markdown
"""## Model
Pix2pix-style GAN for Google -> Yandex map translation.
**Generator:**
- `GeneratorUNet`
- Input: Google image `(B, 3, H, W)`
- Output: generated Yandex image `(B, 3, H, W)`
**Discriminator:**
- `DiscriminatorPatchGAN`
- Input pair: `(google_img, yandex_img)`
- Learns to distinguish real pairs from `(google_img, fake_yandex)`
**Generator loss:**
- adversarial loss
- `lambda_L1 * L1(fake_yandex, yandex_img)`
- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`
- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`
The generator uses bilinear upsampling followed by convolution to avoid
checkerboard artifacts from transposed convolutions."""
# code: ./src/model.py
# markdown
"""## Training
`GANTrainer` trains discriminator and generator alternately.
**Training step:**
1. Generate `fake_yandex = G(google_img)`
2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`
3. Train generator against discriminator and paired Yandex target
**Checkpoint saving:**
- `best.pth`
- `epoch_N.pth`
- `final.pth`"""
# code: ./src/trainer.py
# markdown
"""## Analysis
Visualization helpers for generated samples and collected training metrics.
Training history plot contains:
1. Generator loss
2. Discriminator loss
3. L1 loss against the paired Yandex target
4. SSIM structure loss
5. Sobel edge loss
6. Best-checkpoint reconstruction score
The sample grid contains:
1. Google input
2. Generated Yandex
3. Original Yandex target"""
# code: ./src/analyze.py
# markdown
"""## Main Pipeline
Executes the full GAN workflow:
1. Create config
2. Build paired data loaders
3. Initialize Google -> Yandex GAN
4. Train with validation
5. Save checkpoints in `runs/checkpoints/`
6. Show loss plots and generated sample grid
This block is intentionally top-level, not wrapped in `main()`, so notebook
variables such as `model`, `trainer`, `train_loader`, `val_loader`, and
`training_analysis` remain available for debugging."""
# code: ./src/main.py
# # shell:
# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png

260
models/GAN/build.py Normal file
View File

@@ -0,0 +1,260 @@
import json
import os
import re
def parse_schema(schema_path):
with open(schema_path, 'r', encoding='utf-8') as f:
content = f.read()
imports = []
items = []
in_imports_section = False
lines = content.split('\n')
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped == '# === IMPORTS ===':
in_imports_section = True
i += 1
continue
elif stripped.startswith('# ==='):
in_imports_section = False
i += 1
continue
if in_imports_section and stripped and not stripped.startswith('#'):
imports.append(stripped)
i += 1
continue
if in_imports_section and stripped.startswith('#'):
in_imports_section = False
if stripped.startswith('# code:'):
match = re.search(r'# code:\s*(.+)', stripped)
if match:
items.append(('code', match.group(1).strip()))
i += 1
continue
if stripped.startswith('# inline:') or stripped.startswith('# # shell:'):
directive = 'inline' if stripped.startswith('# inline:') else 'shell'
i += 1
code_lines = []
while i < len(lines):
stripped = lines[i].strip()
if directive == 'shell':
if stripped.startswith('# ') or stripped.startswith('# # '):
line_content = lines[i].strip()
if line_content.startswith('# # '):
line_content = line_content[3:].lstrip()
else:
line_content = line_content[2:]
code_lines.append(line_content)
i += 1
continue
if stripped.startswith('#'):
if stripped.startswith('# code:') or stripped.startswith('# inline:') or stripped.startswith('# # shell:') or stripped == '# markdown' or stripped.startswith('# ==='):
break
i += 1
continue
if stripped == '':
i += 1
continue
code_lines.append(lines[i])
i += 1
if code_lines:
items.append((directive, '\n'.join(code_lines).rstrip()))
continue
if stripped == '# markdown':
i += 1
if i < len(lines):
next_line = lines[i].strip()
if next_line.startswith('"""'):
if next_line.endswith('"""') and len(next_line) > 3:
md_content = next_line[3:-3].strip()
i += 1
else:
end_idx = None
for j in range(i + 1, len(lines)):
if '"""' in lines[j]:
end_idx = j
break
if end_idx:
md_content = '\n'.join(lines[i:end_idx])
md_content = md_content.strip('"""').strip()
i = end_idx + 1
else:
md_content = ""
i += 1
items.append(('markdown', md_content))
continue
i += 1
return imports, items
def strip_all_imports(content):
lines = content.split('\n')
result_lines = []
skip_block = False
if_block_indent = 0
for line in lines:
stripped = line.strip()
if stripped.startswith('if __name__') or stripped.startswith('if __name__ =='):
skip_block = True
if_block_indent = len(line) - len(line.lstrip())
continue
if skip_block:
current_indent = len(line) - len(line.lstrip())
if line.strip() == '':
continue
if current_indent < if_block_indent:
skip_block = False
elif current_indent == if_block_indent and stripped.startswith('if '):
skip_block = True
if_block_indent = current_indent
continue
else:
continue
if stripped.startswith('import ') or stripped.startswith('from '):
continue
result_lines.append(line)
while result_lines and result_lines[-1].strip() == '':
result_lines.pop()
return '\n'.join(result_lines)
def read_src_file(ref, base_dir):
ref_path = ref.replace('./src/', '').replace('src/', '')
full_path = os.path.join(base_dir, 'src', ref_path.replace('./src/', '').lstrip('/'))
if not full_path.endswith('.py'):
full_path += '.py'
with open(full_path, 'r', encoding='utf-8') as f:
return f.read()
def build_notebook(schema_path, output_path=None):
schema_dir = os.path.dirname(os.path.abspath(schema_path))
if output_path is None:
output_path = os.path.join(schema_dir, 'notebook.gen.ipynb')
imports, items = parse_schema(schema_path)
cells = []
if imports:
imports_cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [imp + "\n" for imp in imports]
}
cells.append(imports_cell)
for item_type, item_content in items:
if item_type == 'markdown':
md_cell = {
"cell_type": "markdown",
"metadata": {},
"source": item_content + "\n"
}
cells.append(md_cell)
elif item_type == 'inline':
lines = item_content.split('\n')
while lines and lines[-1].strip() == '':
lines.pop()
if lines:
lines.append('')
cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [line + "\n" for line in lines]
}
cells.append(cell)
elif item_type == 'shell':
lines = item_content.split('\n')
while lines and lines[-1].strip() == '':
lines.pop()
if lines:
lines.append('')
cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [line + "\n" for line in lines]
}
cells.append(cell)
elif item_type == 'code':
try:
content = read_src_file(item_content, schema_dir)
content = strip_all_imports(content)
lines = content.split('\n')
while lines and lines[-1].strip() == '':
lines.pop()
if lines:
lines.append('')
cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [line + "\n" for line in lines]
}
cells.append(cell)
except FileNotFoundError:
print(f"Warning: Could not find: {item_content}")
notebook = {
"cells": cells,
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(notebook, f, indent=1, ensure_ascii=False)
print(f"Notebook generated: {output_path}")
if __name__ == "__main__":
script_dir = os.path.dirname(os.path.abspath(__file__))
schema_path = os.path.join(script_dir, '_schema.py')
if os.path.exists(schema_path):
build_notebook(schema_path)
else:
print(f"Error: _schema.py not found at {schema_path}")

View File

@@ -1,30 +0,0 @@
"""Main entry point for GAN training."""
from config import create_config
from dataloader import create_data_loaders
from model import create_gan
from trainer import create_trainer
def main():
"""Run training pipeline."""
config = create_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create components
model = create_gan(use_cuda=False) # Set to True to use GPU
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
image_size=tuple(config["image_size"]),
num_workers=config["num_workers"],
)
trainer = create_trainer(model, train_loader, val_loader, config)
# Train
trainer.train(config["epochs"])
if __name__ == "__main__":
import torch
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
"""GAN package for Google-to-Yandex map image translation."""

117
models/GAN/src/analyze.py Normal file
View File

@@ -0,0 +1,117 @@
from pathlib import Path
import matplotlib.pyplot as plt
import torch
def denormalize_image(tensor):
return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)
@torch.no_grad()
def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):
device = device or model.device
model.eval()
batch = next(iter(data_loader))
google_img = batch["google_img"][:num_samples].to(device)
yandex_img = batch["yandex_img"][:num_samples].to(device)
fake_yandex = model.generator(google_img)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "generation_samples.png"
fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))
if num_samples == 1:
axes = axes.reshape(1, 3)
titles = ["Google input", "Generated Yandex", "Yandex target"]
for row in range(num_samples):
images = [google_img[row], fake_yandex[row], yandex_img[row]]
for col, image in enumerate(images):
axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))
axes[row, col].set_title(titles[col])
axes[row, col].axis("off")
fig.tight_layout()
fig.savefig(output_path, dpi=150)
if show:
plt.show()
plt.close(fig)
return output_path
def plot_training_history(trainer, output_dir, show=True):
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "training_history.png"
epochs = range(1, len(trainer.g_losses) + 1)
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.ravel()
axes[0].plot(epochs, trainer.g_losses, label="train G")
axes[0].plot(epochs, trainer.val_g_losses, label="val G")
axes[0].set_title("Generator loss")
axes[0].set_xlabel("Epoch")
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[1].plot(epochs, trainer.d_losses, label="train D")
axes[1].plot(epochs, trainer.val_d_losses, label="val D")
axes[1].set_title("Discriminator loss")
axes[1].set_xlabel("Epoch")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[2].plot(epochs, trainer.l1_losses, label="train L1")
axes[2].plot(epochs, trainer.val_l1_losses, label="val L1")
axes[2].set_title("Paired Yandex L1 loss")
axes[2].set_xlabel("Epoch")
axes[2].legend()
axes[2].grid(True, alpha=0.3)
axes[3].plot(epochs, trainer.ssim_losses, label="train SSIM")
axes[3].plot(epochs, trainer.val_ssim_losses, label="val SSIM")
axes[3].set_title("SSIM structure loss")
axes[3].set_xlabel("Epoch")
axes[3].legend()
axes[3].grid(True, alpha=0.3)
axes[4].plot(epochs, trainer.edge_losses, label="train edge")
axes[4].plot(epochs, trainer.val_edge_losses, label="val edge")
axes[4].set_title("Sobel edge loss")
axes[4].set_xlabel("Epoch")
axes[4].legend()
axes[4].grid(True, alpha=0.3)
axes[5].plot(epochs, trainer.val_reconstruction_losses, label="val reconstruction")
axes[5].set_title("Best-checkpoint score")
axes[5].set_xlabel("Epoch")
axes[5].legend()
axes[5].grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(output_path, dpi=150)
if show:
plt.show()
plt.close(fig)
return output_path
def analyze_training(trainer):
return {
"best_val_loss": trainer.best_val_loss,
"g_losses": trainer.g_losses,
"d_losses": trainer.d_losses,
"l1_losses": trainer.l1_losses,
"ssim_losses": trainer.ssim_losses,
"edge_losses": trainer.edge_losses,
"val_g_losses": trainer.val_g_losses,
"val_d_losses": trainer.val_d_losses,
"val_l1_losses": trainer.val_l1_losses,
"val_ssim_losses": trainer.val_ssim_losses,
"val_edge_losses": trainer.val_edge_losses,
"val_reconstruction_losses": trainer.val_reconstruction_losses,
}

View File

@@ -6,23 +6,30 @@ def create_config():
return {
# Optimizer params
"learning_rate": 2e-4,
"discriminator_lr_factor": 0.5,
"beta1": 0.5,
"beta2": 0.999,
# Training params
"batch_size": 32,
"epochs": 100,
"prefer_cuda": True,
# GAN params
"gan_mode": "vanilla",
"lambda_L1": 100.0,
"gan_mode": "lsgan",
"lambda_GAN": 0.5,
"lambda_L1": 150.0,
"lambda_SSIM": 25.0,
"lambda_edge": 20.0,
"discriminator_update_interval": 1,
# Regularization
"grad_clip": 1.0,
# Early stopping
"early_stopping_patience": 20,
"early_stopping_patience": 25,
# Output
"output_dir": "runs/gan_training",
"output_dir": r"C:\Users\admin\Projects\autopilot\models\GAN\runs",
# Logging
"log_interval": 10,
"save_interval": 5,
"num_visual_samples": 4,
# Data
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
"image_size": [256, 256],
@@ -33,4 +40,4 @@ def create_config():
if __name__ == "__main__":
config = create_config()
print("Default config:", config)
print("Default config:", config)

View File

@@ -1,4 +1,4 @@
"""Data loader for Yandex-to-Google image translation."""
"""Data loader for Google-to-Yandex image translation."""
import os
from typing import Dict, List, Tuple
@@ -10,7 +10,7 @@ from torchvision import transforms
class YaGoDataset(Dataset):
"""Dataset loading pairs of Yandex and Google map images."""
"""Dataset loading paired Google and Yandex map images."""
def __init__(
self,
@@ -159,4 +159,4 @@ if __name__ == "__main__":
batch = next(iter(train_loader))
print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

60
models/GAN/src/main.py Normal file
View File

@@ -0,0 +1,60 @@
"""Executable GAN training pipeline.
The code is intentionally top-level, mirroring the SiaN notebook style:
when this file is included in the generated notebook, variables remain
available for debugging and interactive experiments.
"""
from pathlib import Path
import torch
from analyze import analyze_training, plot_training_history, visualize_generation
from config import create_config
from dataloader import create_data_loaders
from model import create_gan, get_compatible_device
from trainer import create_trainer
config = create_config()
device = get_compatible_device(prefer_cuda=config["prefer_cuda"])
print(f"Using device: {device}")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
image_size=tuple(config["image_size"]),
num_workers=config["num_workers"],
)
print(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
model = create_gan(
gan_mode=config["gan_mode"],
lambda_GAN=config["lambda_GAN"],
lambda_L1=config["lambda_L1"],
lambda_SSIM=config["lambda_SSIM"],
lambda_edge=config["lambda_edge"],
use_cuda=(device.type == "cuda"),
)
generator_params = sum(p.numel() for p in model.generator.parameters())
discriminator_params = sum(p.numel() for p in model.discriminator.parameters())
print(f"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}")
trainer = create_trainer(model, train_loader, val_loader, config)
trainer.train(config["epochs"])
training_analysis = analyze_training(trainer)
images_dir = Path(config["output_dir"]) / "images"
history_plot_path = plot_training_history(trainer, images_dir)
generation_samples_path = visualize_generation(
model=model,
data_loader=val_loader,
output_dir=images_dir,
device=device,
num_samples=config["num_visual_samples"],
)
print(f"Training history plot: {history_plot_path}")
print(f"Generation samples: {generation_samples_path}")

View File

@@ -1,10 +1,37 @@
"""GAN model for image translation Yandex -> Google."""
"""GAN model for image translation Google -> Yandex."""
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:
"""Return CUDA only when the current PyTorch build supports the GPU arch."""
if not prefer_cuda or not torch.cuda.is_available():
return torch.device("cpu")
try:
major, minor = torch.cuda.get_device_capability()
arch = f"sm_{major}{minor}"
supported_arches = set(torch.cuda.get_arch_list())
gpu_name = torch.cuda.get_device_name()
except Exception as exc:
if verbose:
print(f"CUDA is visible but cannot be inspected ({exc}); using CPU.")
return torch.device("cpu")
if supported_arches and arch not in supported_arches:
if verbose:
supported = ", ".join(sorted(supported_arches))
print(
f"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build "
f"supports only: {supported}. Using CPU."
)
return torch.device("cpu")
return torch.device("cuda")
class UNetDownBlock(nn.Module):
"""Downsampling block for U-Net."""
@@ -29,7 +56,10 @@ class UNetUpBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
super().__init__()
self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
self.upconv = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
)
self.norm = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if dropout > 0:
@@ -38,7 +68,6 @@ class UNetUpBlock(nn.Module):
self.dropout = None
def forward(self, x, skip_input):
x = self.upconv(x)
# Pad if needed to match skip connection size
if x.shape != skip_input.shape:
@@ -54,7 +83,7 @@ class UNetUpBlock(nn.Module):
class GeneratorUNet(nn.Module):
"""U-Net generator for Yandex -> Google translation."""
"""U-Net generator for Google -> Yandex translation."""
def __init__(self, in_channels: int = 3, out_channels: int = 3):
super().__init__()
@@ -85,7 +114,8 @@ class GeneratorUNet(nn.Module):
# Final
self.final = nn.Sequential(
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
nn.Conv2d(128, out_channels, kernel_size=3, padding=1),
nn.Tanh(),
)
@@ -115,7 +145,7 @@ class GeneratorUNet(nn.Module):
class DiscriminatorPatchGAN(nn.Module):
"""PatchGAN discriminator."""
"""PatchGAN discriminator for paired source/target images."""
def __init__(self, in_channels: int = 6):
super().__init__()
@@ -132,7 +162,6 @@ class DiscriminatorPatchGAN(nn.Module):
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
nn.Sigmoid(),
)
def forward(self, img_A, img_B):
@@ -167,15 +196,82 @@ class GANLoss(nn.Module):
return -prediction.mean() if target_is_real else prediction.mean()
class SSIMLoss(nn.Module):
"""Local SSIM loss for normalized image tensors in [-1, 1]."""
def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):
super().__init__()
self.window_size = window_size
self.c1 = c1
self.c2 = c2
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
pred = (pred + 1.0) * 0.5
target = (target + 1.0) * 0.5
padding = self.window_size // 2
mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding)
mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding)
mu_pred_sq = mu_pred.pow(2)
mu_target_sq = mu_target.pow(2)
mu_pred_target = mu_pred * mu_target
sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq
sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq
sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target
ssim_map = (
(2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)
) / (
(mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)
)
return (1.0 - ssim_map.clamp(0, 1)).mean()
class SobelEdgeLoss(nn.Module):
"""L1 loss between Sobel edge maps, useful for stable keypoint structure."""
def __init__(self):
super().__init__()
kernel_x = torch.tensor(
[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
dtype=torch.float32,
).view(1, 1, 3, 3)
kernel_y = torch.tensor(
[[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
dtype=torch.float32,
).view(1, 1, 3, 3)
self.register_buffer("kernel_x", kernel_x)
self.register_buffer("kernel_y", kernel_y)
@staticmethod
def _to_gray(x: torch.Tensor) -> torch.Tensor:
x = (x + 1.0) * 0.5
weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
return (x * weights).sum(dim=1, keepdim=True)
def _edges(self, x: torch.Tensor) -> torch.Tensor:
gray = self._to_gray(x)
grad_x = F.conv2d(gray, self.kernel_x, padding=1)
grad_y = F.conv2d(gray, self.kernel_y, padding=1)
return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.l1_loss(self._edges(pred), self._edges(target))
class ImageGAN(nn.Module):
"""Complete GAN model for image translation."""
"""Complete pix2pix-style GAN for Google -> Yandex image translation."""
def __init__(
self,
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
gan_mode: str = "lsgan",
lambda_L1: float = 150.0,
lambda_GAN: float = 0.5,
lambda_SSIM: float = 25.0,
lambda_edge: float = 20.0,
use_cuda: bool = True,
):
super().__init__()
@@ -183,29 +279,36 @@ class ImageGAN(nn.Module):
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
self.gan_loss = GANLoss(gan_mode)
self.l1_loss = nn.L1Loss()
self.ssim_loss = SSIMLoss()
self.edge_loss = SobelEdgeLoss()
self.lambda_L1 = lambda_L1
self.lambda_GAN = lambda_GAN
self.lambda_SSIM = lambda_SSIM
self.lambda_edge = lambda_edge
self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
self.device = get_compatible_device(prefer_cuda=use_cuda)
self.to(self.device)
def forward(self, yandex_image):
"""Generate Google image from Yandex."""
return self.generator(yandex_image)
def forward(self, google_image):
"""Generate a Yandex-style image from a Google image."""
return self.generator(google_image)
def generator_step(self, yandex_img, real_google_img):
"""Compute generator losses."""
fake_google = self.generator(yandex_img)
fake_pred = self.discriminator(yandex_img, fake_google)
gan_loss = self.gan_loss(fake_pred, True)
l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1
total_loss = gan_loss + l1_loss
return total_loss, gan_loss, l1_loss
def generator_step(self, google_img, real_yandex_img):
"""Compute generator losses against the paired original Yandex image."""
fake_yandex = self.generator(google_img)
fake_pred = self.discriminator(google_img, fake_yandex)
gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN
l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1
ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM
edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge
total_loss = gan_loss + l1_loss + ssim_loss + edge_loss
return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss
def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
"""Compute discriminator losses."""
real_pred = self.discriminator(yandex_img, real_google_img)
def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):
"""Compute discriminator losses for real and generated Yandex targets."""
real_pred = self.discriminator(google_img, real_yandex_img)
real_loss = self.gan_loss(real_pred, True)
fake_pred = self.discriminator(yandex_img, fake_google_img.detach())
fake_pred = self.discriminator(google_img, fake_yandex_img.detach())
fake_loss = self.gan_loss(fake_pred, False)
total_loss = (real_loss + fake_loss) * 0.5
return total_loss, real_loss, fake_loss
@@ -214,8 +317,11 @@ class ImageGAN(nn.Module):
def create_gan(
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
gan_mode: str = "lsgan",
lambda_L1: float = 150.0,
lambda_GAN: float = 0.5,
lambda_SSIM: float = 25.0,
lambda_edge: float = 20.0,
use_cuda: bool = True,
) -> ImageGAN:
"""Create a GAN model."""
@@ -224,6 +330,9 @@ def create_gan(
output_channels=output_channels,
gan_mode=gan_mode,
lambda_L1=lambda_L1,
lambda_GAN=lambda_GAN,
lambda_SSIM=lambda_SSIM,
lambda_edge=lambda_edge,
use_cuda=use_cuda,
)
@@ -252,4 +361,4 @@ if __name__ == "__main__":
# Count parameters
gen_params = sum(p.numel() for p in model.generator.parameters())
disc_params = sum(p.numel() for p in model.discriminator.parameters())
print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")
print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")

View File

@@ -0,0 +1,40 @@
from config import create_config
from dataloader import YaGoDataset, create_data_loaders
def get_dataset_info():
config = create_config()
dataset = YaGoDataset(
root_dir=config["data_dir"],
image_size=tuple(config["image_size"]),
)
sample = dataset[0] if len(dataset) else {}
return {
"size": len(dataset),
"sample_keys": list(sample.keys()),
"google_shape": tuple(sample["google_img"].shape) if sample else None,
"yandex_shape": tuple(sample["yandex_img"].shape) if sample else None,
}
def smoke_test_dataloader(batch_size=4):
config = create_config()
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=batch_size,
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=tuple(config["image_size"]),
)
batch = next(iter(train_loader))
return {
"train_size": len(train_loader.dataset),
"val_size": len(val_loader.dataset),
"google_batch_shape": tuple(batch["google_img"].shape),
"yandex_batch_shape": tuple(batch["yandex_img"].shape),
}
if __name__ == "__main__":
print(get_dataset_info())
print(smoke_test_dataloader())

View File

@@ -28,18 +28,26 @@ class GANTrainer:
# Optimizers
lr = config.get("learning_rate", 2e-4)
lr_d = config.get("discriminator_learning_rate", lr * config.get("discriminator_lr_factor", 0.5))
beta1 = config.get("beta1", 0.5)
beta2 = config.get("beta2", 0.999)
self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))
self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2))
self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
# Training state
self.current_epoch = 0
self.best_val_loss = float("inf")
self.g_losses = []
self.d_losses = []
self.l1_losses = []
self.ssim_losses = []
self.edge_losses = []
self.val_g_losses = []
self.val_d_losses = []
self.val_l1_losses = []
self.val_ssim_losses = []
self.val_edge_losses = []
self.val_reconstruction_losses = []
# Output dir
self.output_dir = Path(config.get("output_dir", "runs/gan"))
@@ -54,35 +62,55 @@ class GANTrainer:
"""Train for one epoch."""
self.model.train()
total_g = total_d = 0.0
total_l1 = total_ssim = total_edge = 0.0
num_batches = len(self.train_loader)
d_update_interval = max(1, self.config.get("discriminator_update_interval", 1))
pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
for batch in pbar:
yandex_img = batch["yandex_img"].to(self.device)
pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f"Epoch {self.current_epoch + 1}")
for batch_idx, batch in pbar:
google_img = batch["google_img"].to(self.device)
yandex_img = batch["yandex_img"].to(self.device)
# Train D
self.opt_D.zero_grad()
with torch.no_grad():
fake_img = self.model.generator(yandex_img)
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
d_loss.backward()
self.opt_D.step()
if batch_idx % d_update_interval == 0:
self.opt_D.zero_grad()
with torch.no_grad():
fake_img = self.model.generator(google_img)
d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
d_loss.backward()
self.opt_D.step()
else:
d_loss = google_img.new_tensor(0.0)
# Train G
self.opt_G.zero_grad()
g_loss = self.model.generator_step(yandex_img, google_img)[0]
g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
g_loss.backward()
self.opt_G.step()
total_g += g_loss.item()
total_d += d_loss.item()
pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
total_l1 += l1_loss.item()
total_ssim += ssim_loss.item()
total_edge += edge_loss.item()
pbar.set_postfix({
"g_loss": g_loss.item(),
"d_loss": d_loss.item(),
"l1": l1_loss.item(),
"ssim": ssim_loss.item(),
"edge": edge_loss.item(),
})
avg_g = total_g / num_batches
avg_d = total_d / num_batches
avg_l1 = total_l1 / num_batches
avg_ssim = total_ssim / num_batches
avg_edge = total_edge / num_batches
self.g_losses.append(avg_g)
self.d_losses.append(avg_d)
self.l1_losses.append(avg_l1)
self.ssim_losses.append(avg_ssim)
self.edge_losses.append(avg_edge)
return avg_g, avg_d
@torch.no_grad()
@@ -90,20 +118,32 @@ class GANTrainer:
"""Validate the model."""
self.model.eval()
total_g = total_d = 0.0
total_l1 = total_ssim = total_edge = 0.0
for batch in tqdm(self.val_loader, desc="Val"):
yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
fake_img = self.model.generator(yandex_img)
g_loss = self.model.generator_step(yandex_img, google_img)[0]
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
yandex_img = batch["yandex_img"].to(self.device)
fake_img = self.model.generator(google_img)
g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
total_g += g_loss.item()
total_d += d_loss.item()
total_l1 += l1_loss.item()
total_ssim += ssim_loss.item()
total_edge += edge_loss.item()
avg_g = total_g / len(self.val_loader)
avg_d = total_d / len(self.val_loader)
avg_l1 = total_l1 / len(self.val_loader)
avg_ssim = total_ssim / len(self.val_loader)
avg_edge = total_edge / len(self.val_loader)
avg_reconstruction = avg_l1 + avg_ssim + avg_edge
self.val_g_losses.append(avg_g)
self.val_d_losses.append(avg_d)
self.val_l1_losses.append(avg_l1)
self.val_ssim_losses.append(avg_ssim)
self.val_edge_losses.append(avg_edge)
self.val_reconstruction_losses.append(avg_reconstruction)
return avg_g, avg_d
def train(self, num_epochs: int):
@@ -117,23 +157,30 @@ class GANTrainer:
train_g, train_d = self.train_epoch()
val_g, val_d = self.validate()
# Save best checkpoint
val_total = val_g + val_d
if val_total < self.best_val_loss:
self.best_val_loss = val_total
val_reconstruction = self.val_reconstruction_losses[-1]
if val_reconstruction < self.best_val_loss:
self.best_val_loss = val_reconstruction
self.save_checkpoint("best")
# Periodic checkpoint
if (epoch + 1) % self.config.get("save_interval", 5) == 0:
self.save_checkpoint(f"epoch_{epoch + 1}")
print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")
print(
f"Epoch {epoch + 1}: "
f"train_g={train_g:.4f}, train_d={train_d:.4f}, "
f"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, "
f"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, "
f"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, "
f"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}"
)
# Early stopping
patience = self.config.get("early_stopping_patience", 0)
if patience > 0 and len(self.val_g_losses) > patience:
recent = self.val_g_losses[-patience:]
if all(l >= min(self.val_g_losses[:-patience]) for l in recent):
if patience > 0 and len(self.val_reconstruction_losses) > patience:
recent = self.val_reconstruction_losses[-patience:]
previous_best = min(self.val_reconstruction_losses[:-patience])
if all(loss >= previous_best for loss in recent):
print(f"Early stopping at epoch {epoch + 1}")
break
@@ -188,4 +235,4 @@ if __name__ == "__main__":
g_loss, d_loss = trainer.train_epoch()
print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")
except Exception as e:
print(f"Training step failed: {e}")
print(f"Training step failed: {e}")