diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json
new file mode 100644
index 0000000..79ddb0c
--- /dev/null
+++ b/.vscode/c_cpp_properties.json
@@ -0,0 +1,21 @@
+{
+ "configurations": [
+ {
+ "name": "Win32",
+ "includePath": [
+ "${workspaceFolder}/**"
+ ],
+ "defines": [
+ "_DEBUG",
+ "UNICODE",
+ "_UNICODE"
+ ],
+ "windowsSdkVersion": "10.0.26100.0",
+ "compilerPath": "cl.exe",
+ "cStandard": "c17",
+ "cppStandard": "c++17",
+ "intelliSenseMode": "windows-msvc-x64"
+ }
+ ],
+ "version": 4
+}
\ No newline at end of file
diff --git a/autopilot.py b/autopilot.py
index 1383e99..8ff8752 100644
--- a/autopilot.py
+++ b/autopilot.py
@@ -8,6 +8,7 @@ import cv2
import numpy as np
from PIL import Image
+import sian_similarity
from timer import Timer
from vision_chunk import VisionChunk
@@ -58,8 +59,16 @@ class AutoPilot(Pilot):
# Положение на основе ориентира
reserved_pos: Position | None
proccessing_time: float
+ use_sian_similarity: bool
- def __init__(self, points = [], chunks = [], viz_manager=None, pixel_ratio: float = 1.):
+ def __init__(
+ self,
+ points = [],
+ chunks = [],
+ viz_manager=None,
+ pixel_ratio: float = 1.,
+ use_sian_similarity: bool = False,
+ ):
self.prev_chunk = None
self.pos = Position(0, 0, 1, 0, 0, 0)
self.chunks = chunks
@@ -67,6 +76,7 @@ class AutoPilot(Pilot):
self.vis_manager = viz_manager # Менеджер визуализации
self.reserved_pos = None
self.pixel_ratio = pixel_ratio
+ self.use_sian_similarity = use_sian_similarity
# Пороговые значения качества сопоставления/гомографии
self.min_inliers: int = 12
@@ -153,17 +163,33 @@ class AutoPilot(Pilot):
landmark_timer = Timer()
landmark_timer.start()
- cur_pos = np.array([self.pos.x, self.pos.y])
- closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
-
current_chunk = self.prev_chunk
- landmark_chunk = self.chunks[closest_chunk_idx]
+ if current_chunk is None or not self.chunks:
+ return None
+
+ if self.use_sian_similarity:
+ similarity_scores = sian_similarity.get_similarity_scores(current_chunk, self.chunks)
+ best_chunk_idx = int(np.argmax(similarity_scores))
+ best_similarity_score = similarity_scores[best_chunk_idx]
+
+ print(f"[LANDMARK]: best similarity={best_similarity_score:.4f}, idx={best_chunk_idx}")
+
+ if best_similarity_score < sian_similarity.get_threshold():
+ print("[LANDMARK]: not similar")
+ return None
+
+ print("[LANDMARK]: similar")
+ landmark_chunk = self.chunks[best_chunk_idx]
+ else:
+ cur_pos = np.array([self.pos.x, self.pos.y])
+ closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
+ landmark_chunk = self.chunks[closest_chunk_idx]
if constants.DEBUG_FPS:
- print(f"[LANDMARK]: Closest chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
+ print(f"[LANDMARK]: Landmark chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
# Краевой случай: отсутствие чанков
- if current_chunk is None or landmark_chunk is None:
+ if landmark_chunk is None:
return None
landmark_timer.start()
@@ -312,10 +338,10 @@ class AutoPilot(Pilot):
# Пытаемся найти ориентир на картинке:
self.prev_chunk = current_chunk
# Для улучшения среднего FPS
- # if self.frame_count % 5 == 0:
- # pos_by_chunk = self.get_position_by_chunk()
- # if pos_by_chunk is not None:
- # self.pos = pos_by_chunk
+ if self.frame_count % 5 == 0:
+ pos_by_chunk = self.get_position_by_chunk()
+ if pos_by_chunk is not None:
+ self.pos = pos_by_chunk
command = self.make_command()
self.timer.reset()
diff --git a/dissertation/_media/gan_patchgan_discriminator.svg b/dissertation/_media/gan_patchgan_discriminator.svg
new file mode 100644
index 0000000..c89e7d3
--- /dev/null
+++ b/dissertation/_media/gan_patchgan_discriminator.svg
@@ -0,0 +1,118 @@
+
diff --git a/dissertation/_media/gan_pipeline.svg b/dissertation/_media/gan_pipeline.svg
new file mode 100644
index 0000000..912246f
--- /dev/null
+++ b/dissertation/_media/gan_pipeline.svg
@@ -0,0 +1,119 @@
+
diff --git a/dissertation/_media/gan_unet_generator.svg b/dissertation/_media/gan_unet_generator.svg
new file mode 100644
index 0000000..765567e
--- /dev/null
+++ b/dissertation/_media/gan_unet_generator.svg
@@ -0,0 +1,104 @@
+
diff --git a/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md b/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md
index 1d864b7..8648653 100644
--- a/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md
+++ b/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md
@@ -1,3 +1,39 @@
-2.3.3 Применение архитектуры сиамских близнецов для вычисления матрицы гомографии между двумя кадрами
+2.3.3 Генеративно-состязательная сеть для приведения карт к единому домену
+При сопоставлении кадров БПЛА с эталонной картой возникает не только геометрическое, но и доменное различие изображений. Даже если два фрагмента описывают один и тот же участок местности, снимки из разных источников могут иметь разные цвета, толщину дорог, подписи, условные обозначения, контраст зданий и набор отображаемых ориентиров. Например, один и тот же район на карте Google и на карте Яндекс визуально отличается настолько, что классические алгоритмы поиска ключевых точек могут находить мало устойчивых совпадений или формировать большое число ложных соответствий.
+Одним из способов уменьшить этот доменный разрыв является предварительное преобразование изображения из одного картографического домена в другой. В данной работе рассматривается генеративно-состязательная сеть (Generative Adversarial Network, GAN), которая переводит фрагмент карты Google в визуальный стиль карты Яндекс. После такого преобразования исходный фрагмент Google становится ближе к эталонному домену Яндекс, а значит для дальнейшей локализации можно применять классические методы выделения и сопоставления ключевых точек: ORB, SIFT, AKAZE, BRISK и последующую оценку матрицы гомографии при помощи RANSAC.
+
+В отличие от модели сиамских близнецов, которая напрямую оценивает схожесть пары изображений, GAN решает вспомогательную задачу нормализации домена. То есть нейросеть не заменяет классический алгоритм сопоставления, а подготавливает данные так, чтобы классический алгоритм работал в более благоприятных условиях.
+
+
+
+Рисунок 8 – Применение GAN для приведения картографических изображений к единому домену
+
+Архитектура модели построена по принципу pix2pix, так как для обучения доступны парные изображения одного и того же участка местности в двух доменах: Google Maps и Яндекс.Карты. На вход генератора подается изображение Google размером $\left(B,3,256,256\right)$, где $B$ – размер пакета данных. Генератор формирует изображение $\hat{Y}$, визуально соответствующее стилю Яндекс.Карт. Дискриминатор получает на вход пару изображений и должен определить, является ли пара реальной $\left(G,Y\right)$ или сгенерированной $\left(G,\hat{Y}\right)$, где $G$ – исходный фрагмент Google, $Y$ – настоящий фрагмент Яндекс, $\hat{Y}$ – результат работы генератора.
+
+Генератор реализован в виде U-Net. Энкодер последовательно уменьшает пространственное разрешение изображения и извлекает признаки высокого уровня: структуру дорог, контуры кварталов, границы зданий, водные объекты и другие устойчивые элементы карты. Декодер восстанавливает изображение в целевом стиле. Между симметричными уровнями энкодера и декодера используются skip-соединения, которые передают локальную геометрию напрямую из ранних слоев в поздние. Это важно для навигационной задачи: модель должна изменить стиль карты, но не должна смещать дороги, перекрестки и контуры объектов, так как именно они затем используются как ориентиры.
+
+
+
+Рисунок 9 – Архитектура генератора U-Net
+
+Дискриминатор реализован как PatchGAN. В отличие от обычного дискриминатора, который выдает одно значение для всего изображения, PatchGAN оценивает реалистичность локальных областей. На вход дискриминатора подается конкатенация исходного изображения Google и изображения Яндекс по каналам, поэтому входной тензор имеет размер $\left(B,6,256,256\right)$. Далее изображение проходит через сверточные блоки с постепенным уменьшением разрешения, а выходом является карта оценок для отдельных фрагментов. Такой способ оценки хорошо подходит для картографических изображений, потому что локальные признаки – ширина дорог, стиль подписей, границы объектов, цветовые переходы – важнее глобальной художественной реалистичности.
+
+
+
+Рисунок 10 – Архитектура дискриминатора PatchGAN
+
+Обучение модели является состязательным. Генератор стремится сформировать такое изображение $\hat{Y}$, чтобы дискриминатор считал пару $\left(G,\hat{Y}\right)$ реальной. Дискриминатор, наоборот, учится отличать настоящие пары $\left(G,Y\right)$ от сгенерированных. Для сохранения геометрии карты одной только состязательной функции потерь недостаточно, поэтому итоговая функция потерь генератора включает несколько компонентов:
+
+$$L_G = \lambda_{GAN}L_{GAN}\left(D\left(G,\hat{Y}\right),1\right)+\lambda_{L1}\left\|\hat{Y}-Y\right\|_1+\lambda_{SSIM}L_{SSIM}\left(\hat{Y},Y\right)+\lambda_{edge}L_{edge}\left(\hat{Y},Y\right) \qquad (1)$$
+
+где $L_{GAN}$ – состязательная функция потерь, $\left\|\hat{Y}-Y\right\|_1$ – попиксельная L1-ошибка между сгенерированным и настоящим изображением Яндекс, $L_{SSIM}$ – структурная ошибка, сохраняющая сходство локальной структуры, $L_{edge}$ – ошибка по картам границ, вычисленным оператором Собеля. Коэффициенты $\lambda_{GAN}$, $\lambda_{L1}$, $\lambda_{SSIM}$ и $\lambda_{edge}$ задают вклад каждого компонента. В реализованной модели используются значения $\lambda_{GAN}=0.5$, $\lambda_{L1}=150$, $\lambda_{SSIM}=25$, $\lambda_{edge}=20$. Усиленные L1, SSIM и edge-компоненты делают модель менее «творческой», но лучше сохраняют контуры дорог и объектов, что важнее для последующего поиска ключевых точек.
+
+Функция потерь дискриминатора имеет вид:
+
+$$L_D = \frac{1}{2}\left(L_{GAN}\left(D\left(G,Y\right),1\right)+L_{GAN}\left(D\left(G,\hat{Y}\right),0\right)\right) \qquad (2)$$
+
+После обучения GAN может использоваться в навигационном пайплайне следующим образом. Сначала для области предполагаемого положения БПЛА загружается или выбирается фрагмент Google Maps. Затем генератор переводит этот фрагмент в стиль Яндекс.Карт. На полученном изображении и на эталонном фрагменте Яндекс.Карт выделяются ключевые точки и дескрипторы. Далее дескрипторы сопоставляются, ложные соответствия отбрасываются при помощи RANSAC, а по оставшимся точкам оценивается матрица гомографии. Полученная матрица позволяет связать координаты текущего изображения с координатами эталонной карты и уточнить положение БПЛА.
+
+Таким образом, GAN выступает как промежуточный модуль доменной адаптации. Его применение особенно полезно в ситуации, когда источник доступной карты и источник эталонных ориентиров различаются. В рассматриваемой задаче это позволяет перевести изображения Google в представление, близкое к Яндекс.Картам, где ориентиры визуально согласованы между собой. Благодаря этому классические методы компьютерного зрения получают более похожие изображения и могут устойчивее находить ключевые точки, не требуя полного отказа от интерпретируемого геометрического пайплайна.
diff --git a/dissertation/chapter_2/2.3_deep_learning/readme.md b/dissertation/chapter_2/2.3_deep_learning/readme.md
index 01c5cbe..2ef72d5 100644
--- a/dissertation/chapter_2/2.3_deep_learning/readme.md
+++ b/dissertation/chapter_2/2.3_deep_learning/readme.md
@@ -6,3 +6,4 @@
|-----------|----------|------|
| 2.3.1 | Применение архитектуры сиамских близнецов для сопоставления кадров из различных доменов | 2.3.1_siamese_match.md |
| 2.3.2 | Применение архитектуры сиамских близнецов для вычисления матрицы гомографии | 2.3.2_siamese_homography.md |
+| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3.3_additional.md |
diff --git a/dissertation/chapter_2/readme.md b/dissertation/chapter_2/readme.md
index b40136d..4804abe 100644
--- a/dissertation/chapter_2/readme.md
+++ b/dissertation/chapter_2/readme.md
@@ -9,5 +9,6 @@
| 2.3 | Методы глубокого обучения | 2.3_deep_learning/ |
| 2.3.1 | Сиамские близнецы для сопоставления кадров из различных доменов | 2.3_deep_learning/2.3.1_siamese_match.md |
| 2.3.2 | Сиамские близнецы для вычисления матрицы гомографии | 2.3_deep_learning/2.3.2_siamese_homography.md |
+| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3_deep_learning/2.3.3_additional.md |
| 2.4 | Датасет | 2.4_dataset/ |
| 2.5 | Обучение моделей глубокого обучения | 2.5_training/ |
diff --git a/docs/google-city-1 --use-sian-similarity.txt b/docs/google-city-1 --use-sian-similarity.txt
new file mode 100644
index 0000000..5c93abf
--- /dev/null
+++ b/docs/google-city-1 --use-sian-similarity.txt
@@ -0,0 +1,41 @@
+google-city-1 --use-sian-similarity
+MSE: 1006.2278793057307
+RMSE: 31.721095178220608
+Average FPS: 2.1113510203551025
+
+google-city-2 --use-sian-similarity
+MSE: 0.018811735940268307
+RMSE: 0.13715588190182842
+Average FPS: 4.383734633135686
+
+google-city-3 --use-sian-similarity
+MSE: 0.06464743825147634
+RMSE: 0.2542586050686905
+Average FPS: 4.047345287474747
+
+google-city-4 --use-sian-similarity
+MSE: 0.3089718382086463
+RMSE: 0.5558523528857697
+Average FPS: 2.9425785696773636
+
+
+
+google-city-1
+MSE: 716.9078171684257
+RMSE: 26.775134307196776
+Average FPS: 38.766897877580504
+
+google-city-2
+MSE: 0.06237516673056363
+RMSE: 0.2497502086697099
+Average FPS: 25.821371544698586
+
+google-city-3
+MSE: 0.07598762118336233
+RMSE: 0.2756585227838282
+Average FPS: 26.37680633915316
+
+google-city-4
+MSE: 0.7199724273833397
+RMSE: 0.8485118899481254
+Average FPS: 29.534802136097262
\ No newline at end of file
diff --git a/main.py b/main.py
index 31d3d90..807491c 100644
--- a/main.py
+++ b/main.py
@@ -120,7 +120,7 @@ def build(name: str, map_name: str, lat: float, lon: float):
sleep(15)
online_map.destroy()
-def run(name: str, map_name: str, ref_min_distance: float):
+def run(name: str, map_name: str, ref_min_distance: float, use_sian_similarity: bool = False):
dir = Path('trajectories')
assert dir.exists()
dir /= name
@@ -158,7 +158,13 @@ def run(name: str, map_name: str, ref_min_distance: float):
print("READ POINTS:", points)
vis_manager = VisualizationManager()
- pilot = autopilot.AutoPilot(points, chunks, vis_manager, online_map.pixel_ratio)
+ pilot = autopilot.AutoPilot(
+ points,
+ chunks,
+ vis_manager,
+ online_map.pixel_ratio,
+ use_sian_similarity=use_sian_similarity,
+ )
simulator = Simulator(online_map)
pilot.target_idx = 0
@@ -300,6 +306,12 @@ def parse_args():
help='Включить отладку эталонов'
)
+ parser.add_argument(
+ '--use-sian-similarity',
+ action='store_true',
+ help='Выбирать ориентир через SiaN similarity вместо ближайшего по текущей позиции'
+ )
+
# Парсим аргументы
args = parser.parse_args()
@@ -327,4 +339,4 @@ if __name__ == "__main__":
build(name, ref, lat, lon)
if mode == 'run' or mode == 'standalone':
- run(name, sim, rmd)
+ run(name, sim, rmd, args.use_sian_similarity)
diff --git a/map.jpg b/map.jpg
index 071d7b8..d8f8ff8 100644
Binary files a/map.jpg and b/map.jpg differ
diff --git a/models/GAN/README.md b/models/GAN/README.md
index b7c7824..dae5679 100644
--- a/models/GAN/README.md
+++ b/models/GAN/README.md
@@ -1,254 +1,111 @@
-# GAN Trainer для преобразования изображений Yandex → Google
-
-Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
-
-## Структура проекта
-
-```
-autopilot/models/GAN/
-├── gan.py # Основная реализация GAN модели
-├── trainer.py # Тренер для обучения GAN
-├── test_trainer.py # Тесты для тренера
-├── train_example.py # Пример использования тренера
-└── README.md # Этот файл
-```
-
-## Модель GAN
-
-Модель состоит из двух основных компонентов:
-
-### 1. Генератор (GeneratorUNet)
-- Архитектура U-Net для преобразования изображений
-- Принимает изображение Yandex (3 канала RGB)
-- Возвращает изображение в стиле Google (3 канала RGB)
-- Использует skip connections для сохранения деталей
-
-### 2. Дискриминатор (DiscriminatorPatchGAN)
-- PatchGAN архитектура
-- Принимает пару изображений (Yandex + Google)
-- Возвращает вероятность того, что пара реальная
-- Работает с патчами изображения 41x41
-
-### Функция потерь (GANLoss)
-Поддерживает три режима:
-- `vanilla`: Бинарная кросс-энтропия
-- `lsgan`: Least Squares GAN (более стабильный)
-- `wgangp`: Wasserstein GAN with Gradient Penalty
-
-## Тренер (GANTrainer)
-
-### Основные возможности
-
-1. **Обучение с чередованием**:
- - Обучение генератора и дискриминатора поочередно
- - Поддержка L1 потерь для сохранения структуры
-
-2. **Валидация и мониторинг**:
- - Отдельные потери для генератора и дискриминатора
- - Логирование в TensorBoard
- - Ранняя остановка
-
-3. **Сохранение и загрузка**:
- - Чекпоинты каждой эпохи
- - Лучшая модель
- - Финальная модель
- - История обучения
-
-4. **Оценка модели**:
- - Метрики на тестовом наборе
- - Генерация примеров
-
-### Быстрый старт
-
-```python
-import torch
-from torch.utils.data import DataLoader
-from models.GAN.gan import create_image_gan
-from models.GAN.trainer import GANTrainer
-
-# Конфигурация
-config = {
- "learning_rate": 2e-4,
- "beta1": 0.5,
- "beta2": 0.999,
- "batch_size": 4,
- "output_dir": "runs/gan_training",
- "gan_mode": "vanilla",
- "lambda_L1": 100.0,
- "early_stopping_patience": 20,
-}
-
-# Устройство
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Создание модели
-model = create_image_gan(
- input_channels=3,
- output_channels=3,
- gan_mode=config["gan_mode"],
- lambda_L1=config["lambda_L1"],
- use_cuda=(device.type == "cuda"),
-)
-
-# Создание даталоадеров (замените на свои данные)
-train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
-val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
-
-# Создание тренера
-trainer = GANTrainer(
- model=model,
- train_loader=train_loader,
- val_loader=val_loader,
- device=device,
- config=config,
-)
-
-# Обучение
-trainer.train(num_epochs=100)
-
-# Оценка
-metrics = trainer.evaluate(test_loader)
-```
-
-### Конфигурация обучения
-
-#### Базовая конфигурация
-```python
-config = {
- # Параметры оптимизатора
- "learning_rate": 2e-4, # Learning rate
- "beta1": 0.5, # Adam beta1
- "beta2": 0.999, # Adam beta2
-
- # Параметры обучения
- "batch_size": 4, # Размер батча
- "epochs": 100, # Количество эпох
-
- # Параметры GAN
- "gan_mode": "vanilla", # Режим GAN
- "lambda_L1": 100.0, # Вес L1 потерь
-
- # Регуляризация
- "grad_clip": 1.0, # Gradient clipping
-
- # Ранняя остановка
- "early_stopping_patience": 20,
-
- # Выходные данные
- "output_dir": "runs/gan",
-}
-```
-
-#### Расширенная конфигурация
-```python
-config = {
- "learning_rate": 2e-4,
- "beta1": 0.5,
- "beta2": 0.999,
- "batch_size": 8,
- "epochs": 200,
- "gan_mode": "lsgan", # Более стабильный LSGAN
- "lambda_L1": 100.0,
- "grad_clip": 1.0,
- "weight_decay": 1e-4, # Weight decay
- "early_stopping_patience": 30,
- "early_stopping_min_delta": 1e-4,
- "output_dir": "runs/gan_advanced",
-}
-```
-
-### Методы тренера
-
-#### Основные методы
-- `train_epoch()`: Обучение на одной эпохе
-- `validate()`: Валидация модели
-- `train(num_epochs)`: Полное обучение
-- `evaluate(test_loader)`: Оценка на тестовых данных
-
-#### Управление чекпоинтами
-- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
-- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
-
-### Выходные файлы
-
-После обучения создаются следующие файлы:
-
-```
-runs/gan_training/
-├── config.json # Конфигурация обучения
-├── training_history.json # История потерь
-├── model_best.pth # Лучшая модель
-├── model_final.pth # Финальная модель
-├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
-├── checkpoint_epoch_2.pth
-├── ...
-└── tensorboard/ # Логи TensorBoard
- ├── events.out.tfevents...
- └── ...
-```
-
-### TensorBoard
-
-Для визуализации обучения используйте TensorBoard:
-
-```bash
-tensorboard --logdir runs/gan_training/tensorboard
-```
-
-Доступные метрики:
-- `train/batch_g_loss`: Потери генератора на батче
-- `train/batch_d_loss`: Потери дискриминатора на батче
-- `train/batch_g_l1_loss`: L1 потери генератора
-- `train/epoch_g_loss`: Потери генератора на эпохе
-- `train/epoch_d_loss`: Потери дискриминатора на эпохе
-- `val/epoch_g_loss`: Валидационные потери генератора
-- `val/epoch_d_loss`: Валидационные потери дискриминатора
-
-### Тестирование
-
-Запустите тесты для проверки работоспособности:
-
-```bash
-python models/GAN/test_trainer.py
-```
-
-### Пример использования
-
-Полный пример использования смотрите в `train_example.py`.
-
-### Советы по обучению
-
-1. **Начальные значения**:
- - Используйте `gan_mode="lsgan"` для более стабильного обучения
- - Начните с `lambda_L1=100.0` и регулируйте по необходимости
- - Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
-
-2. **Мониторинг**:
- - Следите за балансом потерь генератора и дискриминатора
- - Если потери дискриминатора близки к 0, генератор не обучается
- - Если потери генератора слишком высоки, уменьшите `lambda_L1`
-
-3. **Визуализация**:
- - Регулярно генерируйте примеры для визуальной оценки
- - Используйте TensorBoard для отслеживания прогресса
-
-### Устранение проблем
-
-#### Высокие потери генератора
-- Уменьшите `lambda_L1`
-- Увеличьте learning rate
-- Проверьте качество данных
-
-#### Дискриминатор слишком сильный
-- Уменьшите learning rate дискриминатора
-- Добавьте dropout в дискриминатор
-- Обучайте генератор чаще, чем дискриминатор
-
-#### Недостаток памяти GPU
-- Уменьшите `batch_size`
-- Уменьшите размер изображений
-- Используйте gradient accumulation
-
-### Лицензия
-
-Этот проект является частью Autopilot системы.
\ No newline at end of file
+# GAN для преобразования Google -> Yandex
+
+Модуль реализует pix2pix-подобный GAN для парных изображений карт. Генератор получает изображение Google из пары и пытается сгенерировать изображение в стиле Yandex. Сгенерированная картинка сравнивается со вторым изображением пары, то есть с оригинальным `*_yandex.png`.
+
+## Структура
+
+```text
+models/GAN/
+├── build.py # генератор notebook.gen.ipynb по схеме
+├── _schema.py # структура генерируемого ноутбука
+├── _schema.md # описание формата схемы
+├── notebook.gen.ipynb
+├── src/
+│ ├── config.py
+│ ├── dataloader.py
+│ ├── model.py
+│ ├── trainer.py
+│ ├── analyze.py
+│ └── main.py
+└── README.md
+```
+
+## Архитектура
+
+`GeneratorUNet` принимает `google_img` с 3 RGB-каналами и возвращает `fake_yandex`.
+
+`DiscriminatorPatchGAN` получает пару `(google_img, yandex_img)` и отличает настоящую пару от `(google_img, fake_yandex)`.
+
+`ImageGAN.generator_step()` считает:
+
+- adversarial loss: дискриминатор должен принять `(google_img, fake_yandex)` за реальную пару;
+- L1 loss: `fake_yandex` сравнивается с оригинальным `yandex_img` из той же пары;
+- SSIM loss: штрафует структурные отличия, чтобы карта не расплывалась;
+- Sobel edge loss: сохраняет контуры дорог и объектов для последующего поиска ключевых точек.
+
+Итоговая функция потерь генератора:
+
+```text
+G_loss =
+ lambda_GAN * GAN_loss(D(google_img, fake_yandex), real)
+ + lambda_L1 * L1(fake_yandex, yandex_img)
+ + lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)
+ + lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)
+```
+
+По умолчанию используется `lsgan`, усиленная реконструкция и более медленный
+дискриминатор. Это менее “креативно”, зато обычно даёт более чистые контуры.
+
+## Запуск
+
+Из папки `models/GAN`:
+
+```bash
+python src/main.py
+```
+
+`src/main.py` специально написан без функции `main()`: при генерации ноутбука
+все переменные остаются доступны после запуска ячейки, как в SiaN.
+
+Чтобы пересобрать ноутбук из схемы:
+
+```bash
+python build.py
+```
+
+По умолчанию датасет берется из:
+
+```text
+C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images
+```
+
+Путь, размер изображений, batch size и число эпох меняются в `src/config.py`.
+
+Если CUDA видна, но установленный PyTorch не поддерживает архитектуру GPU
+(например, Tesla P100 `sm_60`), код автоматически переключится на CPU. Это
+управляется параметром `prefer_cuda` в `src/config.py`.
+
+## Ожидаемый формат данных
+
+В директории датасета должны лежать пары:
+
+```text
+0000_google.png
+0000_yandex.png
+0001_google.png
+0001_yandex.png
+...
+```
+
+Даталоадер возвращает:
+
+- `google_img`: вход генератора;
+- `yandex_img`: целевое изображение для сравнения;
+- `idx`: номер пары.
+
+## Чекпоинты
+
+Тренер сохраняет чекпоинты в:
+
+```text
+models/GAN/runs/checkpoints
+```
+
+Сохраняются `best.pth`, периодические `epoch_N.pth` и `final.pth`.
+
+После обучения `src/main.py` также строит:
+
+- графики `G/D/L1/SSIM/edge loss`;
+- сетку `Google input -> Generated Yandex -> Yandex target`.
+
+Файлы сохраняются в `models/GAN/runs/images`.
diff --git a/models/GAN/_schema.md b/models/GAN/_schema.md
new file mode 100644
index 0000000..407668c
--- /dev/null
+++ b/models/GAN/_schema.md
@@ -0,0 +1,32 @@
+# GAN Schema
+
+Notebook structure definition for the Google -> Yandex GAN model.
+
+---
+
+## Format
+
+```text
+# === IMPORTS ===
+
+
+# markdown
+"""Description"""
+
+# code: ./src/file.py
+
+# # shell:
+
+```
+
+## Directives
+
+| Directive | Description |
+|----------------|------------------------------------|
+| `# code:` | Include file from `src/` |
+| `# markdown` | Add markdown cell |
+| `# # shell:` | Add notebook shell command cell |
+
+---
+
+`build.py` generates `notebook.gen.ipynb` from `_schema.py`.
diff --git a/models/GAN/_schema.py b/models/GAN/_schema.py
new file mode 100644
index 0000000..424259a
--- /dev/null
+++ b/models/GAN/_schema.py
@@ -0,0 +1,121 @@
+# _schema.py
+
+# === IMPORTS ===
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset, Subset
+from torchvision import transforms
+from tqdm import tqdm
+
+# markdown
+"""# Configuration
+
+Global settings for the Google -> Yandex GAN:
+- Dataset path and image size
+- Optimizer and training hyperparameters
+- Device preference with safe CUDA compatibility checks
+- GAN, L1, SSIM and edge reconstruction weights
+- Output directories for checkpoints and generated samples"""
+# code: ./src/config.py
+
+# markdown
+"""## Dataset
+
+Google/Yandex paired image loader.
+
+**Direction:**
+- `google_img` is the generator input
+- `yandex_img` is the target image from the same pair
+
+**Returns:**
+- Batch dict with `google_img`, `yandex_img`, `idx`"""
+# code: ./src/dataloader.py
+
+# code: ./src/test_dataloader.py
+
+# markdown
+"""## Model
+
+Pix2pix-style GAN for Google -> Yandex map translation.
+
+**Generator:**
+- `GeneratorUNet`
+- Input: Google image `(B, 3, H, W)`
+- Output: generated Yandex image `(B, 3, H, W)`
+
+**Discriminator:**
+- `DiscriminatorPatchGAN`
+- Input pair: `(google_img, yandex_img)`
+- Learns to distinguish real pairs from `(google_img, fake_yandex)`
+
+**Generator loss:**
+- adversarial loss
+- `lambda_L1 * L1(fake_yandex, yandex_img)`
+- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`
+- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`
+
+The generator uses bilinear upsampling followed by convolution to avoid
+checkerboard artifacts from transposed convolutions."""
+# code: ./src/model.py
+
+# markdown
+"""## Training
+
+`GANTrainer` trains discriminator and generator alternately.
+
+**Training step:**
+1. Generate `fake_yandex = G(google_img)`
+2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`
+3. Train generator against discriminator and paired Yandex target
+
+**Checkpoint saving:**
+- `best.pth`
+- `epoch_N.pth`
+- `final.pth`"""
+# code: ./src/trainer.py
+
+# markdown
+"""## Analysis
+
+Visualization helpers for generated samples and collected training metrics.
+
+Training history plot contains:
+1. Generator loss
+2. Discriminator loss
+3. L1 loss against the paired Yandex target
+4. SSIM structure loss
+5. Sobel edge loss
+6. Best-checkpoint reconstruction score
+
+The sample grid contains:
+1. Google input
+2. Generated Yandex
+3. Original Yandex target"""
+# code: ./src/analyze.py
+
+# markdown
+"""## Main Pipeline
+
+Executes the full GAN workflow:
+1. Create config
+2. Build paired data loaders
+3. Initialize Google -> Yandex GAN
+4. Train with validation
+5. Save checkpoints in `runs/checkpoints/`
+6. Show loss plots and generated sample grid
+
+This block is intentionally top-level, not wrapped in `main()`, so notebook
+variables such as `model`, `trainer`, `train_loader`, `val_loader`, and
+`training_analysis` remain available for debugging."""
+# code: ./src/main.py
+
+# # shell:
+# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png
diff --git a/models/GAN/build.py b/models/GAN/build.py
new file mode 100644
index 0000000..cdf5fbd
--- /dev/null
+++ b/models/GAN/build.py
@@ -0,0 +1,260 @@
+import json
+import os
+import re
+
+
+# Prefixes that open a new directive or a ``# === ... ===`` section header.
+_DIRECTIVE_PREFIXES = ('# code:', '# inline:', '# # shell:', '# ===')
+
+
+def _is_directive(stripped):
+    """Return True if *stripped* starts a new directive or section header."""
+    return stripped == '# markdown' or stripped.startswith(_DIRECTIVE_PREFIXES)
+
+
+def _collect_body(lines, start, directive):
+    """Collect the body of an ``inline`` or ``shell`` directive.
+
+    Scanning begins at ``lines[start]`` and stops at the next directive or
+    section header (or end of file). Blank lines are dropped. For ``inline``
+    every comment line is dropped and other lines are kept verbatim. For
+    ``shell`` lines beginning with ``# `` (or ``# # ``) are kept with that
+    prefix removed, other comment lines are dropped, and non-comment lines
+    are kept verbatim.
+
+    Returns:
+        ``(body_lines, next_index)`` where ``next_index`` is the first line
+        not consumed by this directive.
+    """
+    body = []
+    i = start
+    while i < len(lines):
+        stripped = lines[i].strip()
+        # Check the boundary first: every directive also starts with "# ",
+        # so testing the shell prefix first would swallow later directives.
+        if _is_directive(stripped):
+            break
+        if directive == 'shell' and stripped.startswith('# '):
+            if stripped.startswith('# # '):
+                body.append(stripped[3:].lstrip())
+            else:
+                body.append(stripped[2:])
+        elif stripped and not stripped.startswith('#'):
+            body.append(lines[i])
+        i += 1
+    return body, i
+
+
+def _read_markdown(lines, start):
+    """Read the triple-quoted string that opens on ``lines[start]``.
+
+    Text that shares a line with the opening or the closing delimiter is part
+    of the content. An unterminated string yields an empty text and consumes
+    only the opening line.
+
+    Returns:
+        ``(text, next_index)``.
+    """
+    first = lines[start].strip()[3:]
+    if '"""' in first:
+        # Opening and closing delimiter on the same line.
+        return first.split('"""', 1)[0].strip(), start + 1
+
+    parts = [first]
+    for j in range(start + 1, len(lines)):
+        before, delimiter, _ = lines[j].partition('"""')
+        if delimiter:
+            # Keep whatever precedes the closing delimiter on its own line.
+            parts.append(before)
+            return '\n'.join(parts).strip(), j + 1
+        parts.append(lines[j])
+    return "", start + 1
+
+
+def parse_schema(schema_path):
+    """Parse a notebook schema file into import lines and ordered cell items.
+
+    Recognised constructs (matched against the whitespace-stripped line):
+
+    * ``# === IMPORTS ===`` starts the imports section; following non-comment,
+      non-blank lines are collected until any comment line or another
+      ``# ===`` header.
+    * ``# code: <ref>`` yields ``('code', ref)``.
+    * ``# inline:`` / ``# # shell:`` yield ``('inline', text)`` /
+      ``('shell', text)`` built from the lines up to the next directive.
+    * ``# markdown`` followed by a triple-quoted string yields
+      ``('markdown', text)``.
+
+    Args:
+        schema_path: Path of the schema file (read as UTF-8).
+
+    Returns:
+        ``(imports, items)``: a list of import statements and a list of
+        ``(kind, content)`` tuples in file order.
+    """
+    with open(schema_path, 'r', encoding='utf-8') as f:
+        lines = f.read().split('\n')
+
+    imports = []
+    items = []
+    in_imports_section = False
+
+    i = 0
+    while i < len(lines):
+        stripped = lines[i].strip()
+
+        if stripped == '# === IMPORTS ===':
+            in_imports_section = True
+            i += 1
+            continue
+        if stripped.startswith('# ==='):
+            in_imports_section = False
+            i += 1
+            continue
+
+        if in_imports_section:
+            if stripped and not stripped.startswith('#'):
+                imports.append(stripped)
+                i += 1
+                continue
+            if stripped.startswith('#'):
+                # Any comment (typically the first directive) ends the section.
+                in_imports_section = False
+
+        if stripped.startswith('# code:'):
+            ref = stripped[len('# code:'):].strip()
+            if ref:
+                items.append(('code', ref))
+            i += 1
+            continue
+
+        if stripped.startswith(('# inline:', '# # shell:')):
+            directive = 'inline' if stripped.startswith('# inline:') else 'shell'
+            body, i = _collect_body(lines, i + 1, directive)
+            if body:
+                items.append((directive, '\n'.join(body).rstrip()))
+            continue
+
+        if stripped == '# markdown':
+            i += 1
+            if i < len(lines) and lines[i].strip().startswith('"""'):
+                text, i = _read_markdown(lines, i)
+                items.append(('markdown', text))
+            # Otherwise re-examine line i normally instead of skipping it.
+            continue
+
+        i += 1
+
+    return imports, items
+
+
+def strip_all_imports(content):
+    """Return *content* without import statements and ``if __name__`` guards.
+
+    Removed:
+
+    * every statement whose first line starts with ``import `` or ``from ``,
+      including the continuation lines of parenthesised or
+      backslash-continued imports;
+    * each ``if __name__ ...`` guard together with its indented body (and any
+      ``elif``/``else`` attached to it). Code that follows the guard at the
+      guard's own indentation or shallower is kept.
+
+    Trailing blank lines are trimmed from the result.
+
+    Args:
+        content: Source text of a Python module.
+
+    Returns:
+        The filtered source text, lines joined with ``\\n``.
+    """
+    kept = []
+    guard_indent = None       # indent of the guard being skipped, else None
+    import_depth = 0          # unbalanced "(" of a multi-line import
+    import_continues = False  # previous import line ended with a backslash
+
+    for line in content.split('\n'):
+        stripped = line.strip()
+        indent = len(line) - len(line.lstrip())
+        # Ignore trailing comments when counting parentheses/backslashes.
+        code = line.split('#', 1)[0].rstrip()
+
+        if import_depth > 0 or import_continues:
+            import_depth = max(0, import_depth + code.count('(') - code.count(')'))
+            import_continues = code.endswith('\\')
+            continue
+
+        if guard_indent is not None:
+            if not stripped:
+                continue
+            if indent > guard_indent or stripped.startswith(('elif ', 'else:')):
+                continue
+            # First real line back at the guard's level: the guard is over.
+            guard_indent = None
+
+        if stripped.startswith('if __name__'):
+            guard_indent = indent
+            continue
+
+        if stripped.startswith(('import ', 'from ')):
+            import_depth = max(0, code.count('(') - code.count(')'))
+            import_continues = code.endswith('\\')
+            continue
+
+        kept.append(line)
+
+    while kept and kept[-1].strip() == '':
+        kept.pop()
+
+    return '\n'.join(kept)
+
+
+def read_src_file(ref, base_dir):
+    """Read the source file that a ``# code:`` reference points to.
+
+    ``ref`` is resolved relative to ``<base_dir>/src``. A leading ``./`` and a
+    leading ``src/`` are removed, but only as prefixes, so directory names
+    that merely contain ``src/`` (for example ``lib_src/``) are left intact.
+    A ``.py`` suffix is appended when missing.
+
+    Args:
+        ref: Reference such as ``./src/model.py``, ``src/model`` or ``model``.
+        base_dir: Directory that contains the ``src`` folder.
+
+    Returns:
+        The file content decoded as UTF-8.
+
+    Raises:
+        FileNotFoundError: If the resolved file does not exist.
+    """
+    rel_path = ref
+    for prefix in ('./', 'src/'):
+        if rel_path.startswith(prefix):
+            rel_path = rel_path[len(prefix):]
+    rel_path = rel_path.lstrip('/')
+
+    full_path = os.path.join(base_dir, 'src', rel_path)
+    if not full_path.endswith('.py'):
+        full_path += '.py'
+
+    with open(full_path, 'r', encoding='utf-8') as f:
+        return f.read()
+
+
+def build_notebook(schema_path, output_path=None):
+    """Generate a Jupyter notebook (nbformat 4.5) from a schema file.
+
+    The schema is parsed with ``parse_schema``. Its imports become the first
+    code cell; every item then becomes one cell in order: ``markdown`` items
+    become markdown cells, ``inline`` and ``shell`` items become code cells,
+    and ``code`` items become code cells holding the referenced source file
+    after ``strip_all_imports``. A missing source file only prints a warning.
+
+    Args:
+        schema_path: Path of the schema file.
+        output_path: Destination ``.ipynb`` path. Defaults to
+            ``notebook.gen.ipynb`` next to the schema file.
+    """
+
+    def as_code_cell(source_lines):
+        # Every source line, including the trailing empty one, gets a newline.
+        return {
+            "cell_type": "code",
+            "execution_count": None,
+            "metadata": {},
+            "outputs": [],
+            "source": [text + "\n" for text in source_lines],
+        }
+
+    def normalized(text):
+        # Drop trailing blank lines, then terminate non-empty content with ''.
+        rows = text.split('\n')
+        while rows and rows[-1].strip() == '':
+            rows.pop()
+        if rows:
+            rows.append('')
+        return rows
+
+    schema_dir = os.path.dirname(os.path.abspath(schema_path))
+    if output_path is None:
+        output_path = os.path.join(schema_dir, 'notebook.gen.ipynb')
+
+    imports, items = parse_schema(schema_path)
+
+    cells = [as_code_cell(imports)] if imports else []
+
+    for kind, payload in items:
+        if kind == 'markdown':
+            cells.append({
+                "cell_type": "markdown",
+                "metadata": {},
+                "source": payload + "\n",
+            })
+        elif kind in ('inline', 'shell'):
+            cells.append(as_code_cell(normalized(payload)))
+        elif kind == 'code':
+            try:
+                raw = read_src_file(payload, schema_dir)
+            except FileNotFoundError:
+                print(f"Warning: Could not find: {payload}")
+            else:
+                cells.append(as_code_cell(normalized(strip_all_imports(raw))))
+
+    kernelspec = {
+        "display_name": ".venv",
+        "language": "python",
+        "name": "python3",
+    }
+    language_info = {
+        "name": "python",
+        "version": "3.11.0",
+    }
+    notebook = {
+        "cells": cells,
+        "metadata": {
+            "kernelspec": kernelspec,
+            "language_info": language_info,
+        },
+        "nbformat": 4,
+        "nbformat_minor": 5,
+    }
+
+    with open(output_path, 'w', encoding='utf-8') as f:
+        json.dump(notebook, f, indent=1, ensure_ascii=False)
+
+    print(f"Notebook generated: {output_path}")
+
+
+# Script entry point: build the notebook from the `_schema.py` that sits next
+# to this file, or report that the schema is missing.
+if __name__ == "__main__":
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    schema_path = os.path.join(script_dir, '_schema.py')
+
+    if not os.path.exists(schema_path):
+        print(f"Error: _schema.py not found at {schema_path}")
+    else:
+        build_notebook(schema_path)
diff --git a/models/GAN/main.py b/models/GAN/main.py
deleted file mode 100644
index 5534ac9..0000000
--- a/models/GAN/main.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""Main entry point for GAN training."""
-
-from config import create_config
-from dataloader import create_data_loaders
-from model import create_gan
-from trainer import create_trainer
-
-
-def main():
- """Run training pipeline."""
- config = create_config()
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # Create components
- model = create_gan(use_cuda=False) # Set to True to use GPU
- train_loader, val_loader = create_data_loaders(
- root_dir=config["data_dir"],
- batch_size=config["batch_size"],
- image_size=tuple(config["image_size"]),
- num_workers=config["num_workers"],
- )
- trainer = create_trainer(model, train_loader, val_loader, config)
-
- # Train
- trainer.train(config["epochs"])
-
-
-if __name__ == "__main__":
- import torch
- main()
\ No newline at end of file
diff --git a/models/GAN/notebook.gen.ipynb b/models/GAN/notebook.gen.ipynb
new file mode 100644
index 0000000..49dfcaa
--- /dev/null
+++ b/models/GAN/notebook.gen.ipynb
@@ -0,0 +1,1064 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import os\n",
+ "from pathlib import Path\n",
+ "from typing import Any, Dict, List, Tuple\n",
+ "import matplotlib.pyplot as plt\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from PIL import Image\n",
+ "from torch.utils.data import DataLoader, Dataset, Subset\n",
+ "from torchvision import transforms\n",
+ "from tqdm import tqdm\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "# Configuration\n\nGlobal settings for the Google -> Yandex GAN:\n- Dataset path and image size\n- Optimizer and training hyperparameters\n- Device preference with safe CUDA compatibility checks\n- GAN, L1, SSIM and edge reconstruction weights\n"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"Configuration for GAN training.\"\"\"\n",
+ "\n",
+ "\n",
+ "def create_config():\n",
+ " \"\"\"Create default configuration dictionary.\"\"\"\n",
+ " return {\n",
+ " # Optimizer params\n",
+ " \"learning_rate\": 2e-4,\n",
+ " \"discriminator_lr_factor\": 0.5,\n",
+ " \"beta1\": 0.5,\n",
+ " \"beta2\": 0.999,\n",
+ " # Training params\n",
+ " \"batch_size\": 32,\n",
+ " \"epochs\": 100,\n",
+ " \"prefer_cuda\": True,\n",
+ " # GAN params\n",
+ " \"gan_mode\": \"lsgan\",\n",
+ " \"lambda_GAN\": 0.5,\n",
+ " \"lambda_L1\": 150.0,\n",
+ " \"lambda_SSIM\": 25.0,\n",
+ " \"lambda_edge\": 20.0,\n",
+ " \"discriminator_update_interval\": 1,\n",
+ " # Regularization\n",
+ " \"grad_clip\": 1.0,\n",
+ " # Early stopping\n",
+ " \"early_stopping_patience\": 25,\n",
+ " # Output\n",
+ " \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\GAN\\runs\",\n",
+ " # Logging\n",
+ " \"log_interval\": 10,\n",
+ " \"save_interval\": 5,\n",
+ " \"num_visual_samples\": 4,\n",
+ " # Data\n",
+ " \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n",
+ " \"image_size\": [256, 256],\n",
+ " \"train_split\": 0.8,\n",
+ " \"num_workers\": 0,\n",
+ " }\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Dataset\n\nGoogle/Yandex paired image loader.\n\n**Direction:**\n- `google_img` is the generator input\n- `yandex_img` is the target image from the same pair\n\n**Returns:**\n"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"Data loader for Google-to-Yandex image translation.\"\"\"\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "class YaGoDataset(Dataset):\n",
+ " \"\"\"Dataset loading paired Google and Yandex map images.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " root_dir: str,\n",
+ " image_size: Tuple[int, int] = (256, 256),\n",
+ " augment: bool = False,\n",
+ " ):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png\n",
+ " image_size: Target image size (height, width)\n",
+ " augment: Whether to apply augmentation (not implemented for simplicity)\n",
+ " \"\"\"\n",
+ " self.root_dir = root_dir\n",
+ " self.image_size = image_size\n",
+ " self.augment = augment\n",
+ "\n",
+ " # Discover image pairs\n",
+ " self.pairs = self._find_pairs()\n",
+ "\n",
+ " # Transform to tensor + normalization\n",
+ " self.transform = transforms.Compose(\n",
+ " [\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " def _find_pairs(self) -> List[Dict]:\n",
+ " \"\"\"Find all matching Google-Yandex image pairs.\"\"\"\n",
+ " pairs = []\n",
+ " google_files = [f for f in os.listdir(self.root_dir) if f.endswith(\"_google.png\")]\n",
+ "\n",
+ " for google_file in sorted(google_files):\n",
+ " idx_str = google_file.split(\"_\")[0]\n",
+ " try:\n",
+ " idx = int(idx_str)\n",
+ " except ValueError:\n",
+ " continue\n",
+ "\n",
+ " yandex_file = f\"{idx:04d}_yandex.png\"\n",
+ " yandex_path = os.path.join(self.root_dir, yandex_file)\n",
+ "\n",
+ " if os.path.exists(yandex_path):\n",
+ " pairs.append(\n",
+ " {\n",
+ " \"idx\": idx,\n",
+ " \"google_path\": os.path.join(self.root_dir, google_file),\n",
+ " \"yandex_path\": yandex_path,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " return pairs\n",
+ "\n",
+ " def __len__(self) -> int:\n",
+ " return len(self.pairs)\n",
+ "\n",
+ " def __getitem__(self, idx: int) -> dict:\n",
+ " pair = self.pairs[idx]\n",
+ "\n",
+ " # Load images\n",
+ " google_img = Image.open(pair[\"google_path\"]).convert(\"RGB\")\n",
+ " yandex_img = Image.open(pair[\"yandex_path\"]).convert(\"RGB\")\n",
+ "\n",
+ " # Resize\n",
+ " google_img = google_img.resize((self.image_size[1], self.image_size[0]))\n",
+ " yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0]))\n",
+ "\n",
+ " # Apply transforms\n",
+ " google_tensor = self.transform(google_img)\n",
+ " yandex_tensor = self.transform(yandex_img)\n",
+ "\n",
+ " return {\n",
+ " \"google_img\": google_tensor,\n",
+ " \"yandex_img\": yandex_tensor,\n",
+ " \"idx\": torch.tensor(pair[\"idx\"], dtype=torch.long),\n",
+ " }\n",
+ "\n",
+ "\n",
+ "def create_data_loaders(\n",
+ " root_dir: str,\n",
+ " batch_size: int = 32,\n",
+ " train_split: float = 0.8,\n",
+ " num_workers: int = 0,\n",
+ " image_size: Tuple[int, int] = (256, 256),\n",
+ ") -> Tuple[DataLoader, DataLoader]:\n",
+ " \"\"\"\n",
+ " Create train and validation data loaders.\n",
+ "\n",
+ " Args:\n",
+ " root_dir: Directory with image pairs\n",
+ " batch_size: Batch size\n",
+ " train_split: Fraction for training (0.0-1.0)\n",
+ " num_workers: DataLoader workers\n",
+ " image_size: Target image size\n",
+ "\n",
+ " Returns:\n",
+ " (train_loader, val_loader)\n",
+ " \"\"\"\n",
+ " # Full dataset\n",
+ " dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)\n",
+ "\n",
+ " # Split\n",
+ " dataset_size = len(dataset)\n",
+ " train_size = int(train_split * dataset_size)\n",
+ " indices = torch.randperm(dataset_size).tolist()\n",
+ " train_indices = indices[:train_size]\n",
+ " val_indices = indices[train_size:]\n",
+ "\n",
+ " # Subsets\n",
+ "\n",
+ " train_dataset = Subset(dataset, train_indices)\n",
+ " val_dataset = Subset(dataset, val_indices)\n",
+ "\n",
+ " # DataLoaders\n",
+ " train_loader = DataLoader(\n",
+ " train_dataset,\n",
+ " batch_size=batch_size,\n",
+ " shuffle=True,\n",
+ " num_workers=num_workers,\n",
+ " pin_memory=True,\n",
+ " )\n",
+ "\n",
+ " val_loader = DataLoader(\n",
+ " val_dataset,\n",
+ " batch_size=batch_size,\n",
+ " shuffle=False,\n",
+ " num_workers=num_workers,\n",
+ " pin_memory=True,\n",
+ " )\n",
+ "\n",
+ " return train_loader, val_loader\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "def get_dataset_info():\n",
+ " config = create_config()\n",
+ " dataset = YaGoDataset(\n",
+ " root_dir=config[\"data_dir\"],\n",
+ " image_size=tuple(config[\"image_size\"]),\n",
+ " )\n",
+ " sample = dataset[0] if len(dataset) else {}\n",
+ " return {\n",
+ " \"size\": len(dataset),\n",
+ " \"sample_keys\": list(sample.keys()),\n",
+ " \"google_shape\": tuple(sample[\"google_img\"].shape) if sample else None,\n",
+ " \"yandex_shape\": tuple(sample[\"yandex_img\"].shape) if sample else None,\n",
+ " }\n",
+ "\n",
+ "\n",
+ "def smoke_test_dataloader(batch_size=4):\n",
+ " config = create_config()\n",
+ " train_loader, val_loader = create_data_loaders(\n",
+ " root_dir=config[\"data_dir\"],\n",
+ " batch_size=batch_size,\n",
+ " train_split=config[\"train_split\"],\n",
+ " num_workers=config[\"num_workers\"],\n",
+ " image_size=tuple(config[\"image_size\"]),\n",
+ " )\n",
+ " batch = next(iter(train_loader))\n",
+ " return {\n",
+ " \"train_size\": len(train_loader.dataset),\n",
+ " \"val_size\": len(val_loader.dataset),\n",
+ " \"google_batch_shape\": tuple(batch[\"google_img\"].shape),\n",
+ " \"yandex_batch_shape\": tuple(batch[\"yandex_img\"].shape),\n",
+ " }\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Model\n\nPix2pix-style GAN for Google -> Yandex map translation.\n\n**Generator:**\n- `GeneratorUNet`\n- Input: Google image `(B, 3, H, W)`\n- Output: generated Yandex image `(B, 3, H, W)`\n\n**Discriminator:**\n- `DiscriminatorPatchGAN`\n- Input pair: `(google_img, yandex_img)`\n- Learns to distinguish real pairs from `(google_img, fake_yandex)`\n\n**Generator loss:**\n- adversarial loss\n- `lambda_L1 * L1(fake_yandex, yandex_img)`\n- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`\n- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`\n\nThe generator uses bilinear upsampling followed by convolution to avoid\ncheckerboard artifacts from transposed convolutions.\n"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"GAN model for image translation Google -> Yandex.\"\"\"\n",
+ "\n",
+ "\n",
+ "\n",
+ "def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:\n",
+ " \"\"\"Return CUDA only when the current PyTorch build supports the GPU arch.\"\"\"\n",
+ " if not prefer_cuda or not torch.cuda.is_available():\n",
+ " return torch.device(\"cpu\")\n",
+ "\n",
+ " try:\n",
+ " major, minor = torch.cuda.get_device_capability()\n",
+ " arch = f\"sm_{major}{minor}\"\n",
+ " supported_arches = set(torch.cuda.get_arch_list())\n",
+ " gpu_name = torch.cuda.get_device_name()\n",
+ " except Exception as exc:\n",
+ " if verbose:\n",
+ " print(f\"CUDA is visible but cannot be inspected ({exc}); using CPU.\")\n",
+ " return torch.device(\"cpu\")\n",
+ "\n",
+ " if supported_arches and arch not in supported_arches:\n",
+ " if verbose:\n",
+ " supported = \", \".join(sorted(supported_arches))\n",
+ " print(\n",
+ " f\"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build \"\n",
+ " f\"supports only: {supported}. Using CPU.\"\n",
+ " )\n",
+ " return torch.device(\"cpu\")\n",
+ "\n",
+ " return torch.device(\"cuda\")\n",
+ "\n",
+ "\n",
+ "class UNetDownBlock(nn.Module):\n",
+ " \"\"\"Downsampling block for U-Net.\"\"\"\n",
+ "\n",
+ " def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):\n",
+ " super().__init__()\n",
+ " layers = [\n",
+ " nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)\n",
+ " ]\n",
+ " if normalize:\n",
+ " layers.append(nn.BatchNorm2d(out_channels))\n",
+ " layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
+ " if dropout > 0:\n",
+ " layers.append(nn.Dropout2d(dropout))\n",
+ " self.model = nn.Sequential(*layers)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.model(x)\n",
+ "\n",
+ "\n",
+ "class UNetUpBlock(nn.Module):\n",
+ " \"\"\"Upsampling block for U-Net.\"\"\"\n",
+ "\n",
+ " def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):\n",
+ " super().__init__()\n",
+ " self.upconv = nn.Sequential(\n",
+ " nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False),\n",
+ " nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),\n",
+ " )\n",
+ " self.norm = nn.BatchNorm2d(out_channels)\n",
+ " self.relu = nn.ReLU(inplace=True)\n",
+ " if dropout > 0:\n",
+ " self.dropout = nn.Dropout2d(dropout)\n",
+ " else:\n",
+ " self.dropout = None\n",
+ "\n",
+ " def forward(self, x, skip_input):\n",
+ " x = self.upconv(x)\n",
+ " # Pad if needed to match skip connection size\n",
+ " if x.shape != skip_input.shape:\n",
+ " diff_h = skip_input.size(2) - x.size(2)\n",
+ " diff_w = skip_input.size(3) - x.size(3)\n",
+ " x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])\n",
+ " x = self.norm(x)\n",
+ " x = self.relu(x)\n",
+ " if self.dropout:\n",
+ " x = self.dropout(x)\n",
+ " x = torch.cat([x, skip_input], dim=1)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "class GeneratorUNet(nn.Module):\n",
+ " \"\"\"U-Net generator for Google -> Yandex translation.\"\"\"\n",
+ "\n",
+ " def __init__(self, in_channels: int = 3, out_channels: int = 3):\n",
+ " super().__init__()\n",
+ "\n",
+ " # Downsampling\n",
+ " self.down1 = UNetDownBlock(in_channels, 64, normalize=False)\n",
+ " self.down2 = UNetDownBlock(64, 128)\n",
+ " self.down3 = UNetDownBlock(128, 256)\n",
+ " self.down4 = UNetDownBlock(256, 512)\n",
+ " self.down5 = UNetDownBlock(512, 512)\n",
+ " self.down6 = UNetDownBlock(512, 512)\n",
+ " self.down7 = UNetDownBlock(512, 512)\n",
+ "\n",
+ " # Bottleneck\n",
+ " self.bottleneck = nn.Sequential(\n",
+ " nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),\n",
+ " nn.ReLU(inplace=True),\n",
+ " )\n",
+ "\n",
+ " # Upsampling - input channels from previous layer, output before concat\n",
+ " self.up1 = UNetUpBlock(512, 512, dropout=0.5) # in: 512 (bottleneck) -> out: 512, concat with d7 (512) = 1024\n",
+ " self.up2 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d6 (512) = 1024\n",
+ " self.up3 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d5 (512) = 1024\n",
+ " self.up4 = UNetUpBlock(1024, 512) # in: 1024 -> out: 512, concat with d4 (512) = 1024\n",
+ " self.up5 = UNetUpBlock(1024, 256) # in: 1024 -> out: 256, concat with d3 (256) = 512\n",
+ " self.up6 = UNetUpBlock(512, 128) # in: 512 -> out: 128, concat with d2 (128) = 256\n",
+ " self.up7 = UNetUpBlock(256, 64) # in: 256 -> out: 64, concat with d1 (64) = 128\n",
+ "\n",
+ " # Final\n",
+ " self.final = nn.Sequential(\n",
+ " nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False),\n",
+ " nn.Conv2d(128, out_channels, kernel_size=3, padding=1),\n",
+ " nn.Tanh(),\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " # Down\n",
+ " d1 = self.down1(x)\n",
+ " d2 = self.down2(d1)\n",
+ " d3 = self.down3(d2)\n",
+ " d4 = self.down4(d3)\n",
+ " d5 = self.down5(d4)\n",
+ " d6 = self.down6(d5)\n",
+ " d7 = self.down7(d6)\n",
+ "\n",
+ " # Bottleneck\n",
+ " u = self.bottleneck(d7)\n",
+ "\n",
+ " # Up with skip connections\n",
+ " u = self.up1(u, d7)\n",
+ " u = self.up2(u, d6)\n",
+ " u = self.up3(u, d5)\n",
+ " u = self.up4(u, d4)\n",
+ " u = self.up5(u, d3)\n",
+ " u = self.up6(u, d2)\n",
+ " u = self.up7(u, d1)\n",
+ "\n",
+ " return self.final(u)\n",
+ "\n",
+ "\n",
+ "class DiscriminatorPatchGAN(nn.Module):\n",
+ " \"\"\"PatchGAN discriminator for paired source/target images.\"\"\"\n",
+ "\n",
+ " def __init__(self, in_channels: int = 6):\n",
+ " super().__init__()\n",
+ " self.model = nn.Sequential(\n",
+ " nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),\n",
+ " nn.LeakyReLU(0.2, inplace=True),\n",
+ " nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
+ " nn.BatchNorm2d(128),\n",
+ " nn.LeakyReLU(0.2, inplace=True),\n",
+ " nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),\n",
+ " nn.BatchNorm2d(256),\n",
+ " nn.LeakyReLU(0.2, inplace=True),\n",
+ " nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),\n",
+ " nn.BatchNorm2d(512),\n",
+ " nn.LeakyReLU(0.2, inplace=True),\n",
+ " nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),\n",
+ " )\n",
+ "\n",
+ " def forward(self, img_A, img_B):\n",
+ " x = torch.cat([img_A, img_B], dim=1)\n",
+ " return self.model(x)\n",
+ "\n",
+ "\n",
+ "class GANLoss(nn.Module):\n",
+ " \"\"\"GAN loss supporting different GAN modes.\"\"\"\n",
+ "\n",
+ " def __init__(self, gan_mode: str = \"vanilla\", target_real: float = 1.0, target_fake: float = 0.0):\n",
+ " super().__init__()\n",
+ " self.gan_mode = gan_mode\n",
+ " self.register_buffer(\"real_label\", torch.tensor(target_real))\n",
+ " self.register_buffer(\"fake_label\", torch.tensor(target_fake))\n",
+ "\n",
+ " if gan_mode == \"vanilla\":\n",
+ " self.loss_fn = nn.BCEWithLogitsLoss()\n",
+ " elif gan_mode == \"lsgan\":\n",
+ " self.loss_fn = nn.MSELoss()\n",
+ " elif gan_mode == \"wgangp\":\n",
+ " self.loss_fn = None\n",
+ " else:\n",
+ " raise ValueError(f\"Unknown GAN mode: {gan_mode}\")\n",
+ "\n",
+ " def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:\n",
+ " if self.gan_mode in [\"vanilla\", \"lsgan\"]:\n",
+ " target = self.real_label if target_is_real else self.fake_label\n",
+ " target = target.expand_as(prediction)\n",
+ " return self.loss_fn(prediction, target)\n",
+ " elif self.gan_mode == \"wgangp\":\n",
+ " return -prediction.mean() if target_is_real else prediction.mean()\n",
+ "\n",
+ "\n",
+ "class SSIMLoss(nn.Module):\n",
+ " \"\"\"Local SSIM loss for normalized image tensors in [-1, 1].\"\"\"\n",
+ "\n",
+ " def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):\n",
+ " super().__init__()\n",
+ " self.window_size = window_size\n",
+ " self.c1 = c1\n",
+ " self.c2 = c2\n",
+ "\n",
+ " def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
+ " pred = (pred + 1.0) * 0.5\n",
+ " target = (target + 1.0) * 0.5\n",
+ " padding = self.window_size // 2\n",
+ "\n",
+ " mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding)\n",
+ " mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding)\n",
+ " mu_pred_sq = mu_pred.pow(2)\n",
+ " mu_target_sq = mu_target.pow(2)\n",
+ " mu_pred_target = mu_pred * mu_target\n",
+ "\n",
+ " sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq\n",
+ " sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq\n",
+ " sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target\n",
+ "\n",
+ " ssim_map = (\n",
+ " (2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)\n",
+ " ) / (\n",
+ " (mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)\n",
+ " )\n",
+ " return (1.0 - ssim_map.clamp(0, 1)).mean()\n",
+ "\n",
+ "\n",
+ "class SobelEdgeLoss(nn.Module):\n",
+ " \"\"\"L1 loss between Sobel edge maps, useful for stable keypoint structure.\"\"\"\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " kernel_x = torch.tensor(\n",
+ " [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],\n",
+ " dtype=torch.float32,\n",
+ " ).view(1, 1, 3, 3)\n",
+ " kernel_y = torch.tensor(\n",
+ " [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],\n",
+ " dtype=torch.float32,\n",
+ " ).view(1, 1, 3, 3)\n",
+ " self.register_buffer(\"kernel_x\", kernel_x)\n",
+ " self.register_buffer(\"kernel_y\", kernel_y)\n",
+ "\n",
+ " @staticmethod\n",
+ " def _to_gray(x: torch.Tensor) -> torch.Tensor:\n",
+ " x = (x + 1.0) * 0.5\n",
+ " weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)\n",
+ " return (x * weights).sum(dim=1, keepdim=True)\n",
+ "\n",
+ " def _edges(self, x: torch.Tensor) -> torch.Tensor:\n",
+ " gray = self._to_gray(x)\n",
+ " grad_x = F.conv2d(gray, self.kernel_x, padding=1)\n",
+ " grad_y = F.conv2d(gray, self.kernel_y, padding=1)\n",
+ " return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)\n",
+ "\n",
+ " def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
+ " return F.l1_loss(self._edges(pred), self._edges(target))\n",
+ "\n",
+ "\n",
+ "class ImageGAN(nn.Module):\n",
+ " \"\"\"Complete pix2pix-style GAN for Google -> Yandex image translation.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " input_channels: int = 3,\n",
+ " output_channels: int = 3,\n",
+ " gan_mode: str = \"lsgan\",\n",
+ " lambda_L1: float = 150.0,\n",
+ " lambda_GAN: float = 0.5,\n",
+ " lambda_SSIM: float = 25.0,\n",
+ " lambda_edge: float = 20.0,\n",
+ " use_cuda: bool = True,\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.generator = GeneratorUNet(input_channels, output_channels)\n",
+ " self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)\n",
+ " self.gan_loss = GANLoss(gan_mode)\n",
+ " self.l1_loss = nn.L1Loss()\n",
+ " self.ssim_loss = SSIMLoss()\n",
+ " self.edge_loss = SobelEdgeLoss()\n",
+ " self.lambda_L1 = lambda_L1\n",
+ " self.lambda_GAN = lambda_GAN\n",
+ " self.lambda_SSIM = lambda_SSIM\n",
+ " self.lambda_edge = lambda_edge\n",
+ "\n",
+ " self.device = get_compatible_device(prefer_cuda=use_cuda)\n",
+ " self.to(self.device)\n",
+ "\n",
+ " def forward(self, google_image):\n",
+ " \"\"\"Generate a Yandex-style image from a Google image.\"\"\"\n",
+ " return self.generator(google_image)\n",
+ "\n",
+ " def generator_step(self, google_img, real_yandex_img):\n",
+ " \"\"\"Compute generator losses against the paired original Yandex image.\"\"\"\n",
+ " fake_yandex = self.generator(google_img)\n",
+ " fake_pred = self.discriminator(google_img, fake_yandex)\n",
+ " gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN\n",
+ " l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1\n",
+ " ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM\n",
+ " edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge\n",
+ " total_loss = gan_loss + l1_loss + ssim_loss + edge_loss\n",
+ " return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss\n",
+ "\n",
+ " def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):\n",
+ " \"\"\"Compute discriminator losses for real and generated Yandex targets.\"\"\"\n",
+ " real_pred = self.discriminator(google_img, real_yandex_img)\n",
+ " real_loss = self.gan_loss(real_pred, True)\n",
+ " fake_pred = self.discriminator(google_img, fake_yandex_img.detach())\n",
+ " fake_loss = self.gan_loss(fake_pred, False)\n",
+ " total_loss = (real_loss + fake_loss) * 0.5\n",
+ " return total_loss, real_loss, fake_loss\n",
+ "\n",
+ "\n",
+ "def create_gan(\n",
+ " input_channels: int = 3,\n",
+ " output_channels: int = 3,\n",
+ " gan_mode: str = \"lsgan\",\n",
+ " lambda_L1: float = 150.0,\n",
+ " lambda_GAN: float = 0.5,\n",
+ " lambda_SSIM: float = 25.0,\n",
+ " lambda_edge: float = 20.0,\n",
+ " use_cuda: bool = True,\n",
+ ") -> ImageGAN:\n",
+ " \"\"\"Create a GAN model.\"\"\"\n",
+ " return ImageGAN(\n",
+ " input_channels=input_channels,\n",
+ " output_channels=output_channels,\n",
+ " gan_mode=gan_mode,\n",
+ " lambda_L1=lambda_L1,\n",
+ " lambda_GAN=lambda_GAN,\n",
+ " lambda_SSIM=lambda_SSIM,\n",
+ " lambda_edge=lambda_edge,\n",
+ " use_cuda=use_cuda,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def initialize_weights(model: nn.Module):\n",
+ " \"\"\"Initialize model weights.\"\"\"\n",
+ " for m in model.modules():\n",
+ " if isinstance(m, nn.Conv2d):\n",
+ " nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
+ " elif isinstance(m, nn.BatchNorm2d):\n",
+ " nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
+ " nn.init.constant_(m.bias.data, 0.0)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Training\n\n`GANTrainer` trains discriminator and generator alternately.\n\n**Training step:**\n1. Generate `fake_yandex = G(google_img)`\n2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`\n3. Train generator against discriminator and paired Yandex target\n\n**Checkpoint saving:**\n- `best.pth`\n- `epoch_N.pth`\n- `final.pth`\n"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"Trainer for GAN model.\"\"\"\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "class GANTrainer:\n",
+ " \"\"\"Simple GAN trainer.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " model: torch.nn.Module,\n",
+ " train_loader: DataLoader,\n",
+ " val_loader: DataLoader,\n",
+ " config: Dict[str, Any],\n",
+ " ):\n",
+ " self.model = model\n",
+ " self.train_loader = train_loader\n",
+ " self.val_loader = val_loader\n",
+ " self.config = config\n",
+ " self.device = model.device\n",
+ "\n",
+ " # Optimizers\n",
+ " lr = config.get(\"learning_rate\", 2e-4)\n",
+ " lr_d = config.get(\"discriminator_learning_rate\", lr * config.get(\"discriminator_lr_factor\", 0.5))\n",
+ " beta1 = config.get(\"beta1\", 0.5)\n",
+ " beta2 = config.get(\"beta2\", 0.999)\n",
+ " self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))\n",
+ " self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))\n",
+ "\n",
+ " # Training state\n",
+ " self.current_epoch = 0\n",
+ " self.best_val_loss = float(\"inf\")\n",
+ " self.g_losses = []\n",
+ " self.d_losses = []\n",
+ " self.l1_losses = []\n",
+ " self.ssim_losses = []\n",
+ " self.edge_losses = []\n",
+ " self.val_g_losses = []\n",
+ " self.val_d_losses = []\n",
+ " self.val_l1_losses = []\n",
+ " self.val_ssim_losses = []\n",
+ " self.val_edge_losses = []\n",
+ " self.val_reconstruction_losses = []\n",
+ "\n",
+ " # Output dir\n",
+ " self.output_dir = Path(config.get(\"output_dir\", \"runs/gan\"))\n",
+ " self.output_dir.mkdir(parents=True, exist_ok=True)\n",
+ " (self.output_dir / \"checkpoints\").mkdir(exist_ok=True)\n",
+ "\n",
+ " # Save config\n",
+ " with open(self.output_dir / \"config.json\", \"w\") as f:\n",
+ " json.dump(config, f, indent=2)\n",
+ "\n",
+ " def train_epoch(self) -> Tuple[float, float]:\n",
+ " \"\"\"Train for one epoch.\"\"\"\n",
+ " self.model.train()\n",
+ " total_g = total_d = 0.0\n",
+ " total_l1 = total_ssim = total_edge = 0.0\n",
+ " num_batches = len(self.train_loader)\n",
+ " d_update_interval = max(1, self.config.get(\"discriminator_update_interval\", 1))\n",
+ "\n",
+ " pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f\"Epoch {self.current_epoch + 1}\")\n",
+ " for batch_idx, batch in pbar:\n",
+ " google_img = batch[\"google_img\"].to(self.device)\n",
+ " yandex_img = batch[\"yandex_img\"].to(self.device)\n",
+ "\n",
+ " # Train D\n",
+ " if batch_idx % d_update_interval == 0:\n",
+ " self.opt_D.zero_grad()\n",
+ " with torch.no_grad():\n",
+ " fake_img = self.model.generator(google_img)\n",
+ " d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]\n",
+ " d_loss.backward()\n",
+ " self.opt_D.step()\n",
+ " else:\n",
+ " d_loss = google_img.new_tensor(0.0)\n",
+ "\n",
+ " # Train G\n",
+ " self.opt_G.zero_grad()\n",
+ " g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)\n",
+ " g_loss.backward()\n",
+ " self.opt_G.step()\n",
+ "\n",
+ " total_g += g_loss.item()\n",
+ " total_d += d_loss.item()\n",
+ " total_l1 += l1_loss.item()\n",
+ " total_ssim += ssim_loss.item()\n",
+ " total_edge += edge_loss.item()\n",
+ " pbar.set_postfix({\n",
+ " \"g_loss\": g_loss.item(),\n",
+ " \"d_loss\": d_loss.item(),\n",
+ " \"l1\": l1_loss.item(),\n",
+ " \"ssim\": ssim_loss.item(),\n",
+ " \"edge\": edge_loss.item(),\n",
+ " })\n",
+ "\n",
+ " avg_g = total_g / num_batches\n",
+ " avg_d = total_d / num_batches\n",
+ " avg_l1 = total_l1 / num_batches\n",
+ " avg_ssim = total_ssim / num_batches\n",
+ " avg_edge = total_edge / num_batches\n",
+ " self.g_losses.append(avg_g)\n",
+ " self.d_losses.append(avg_d)\n",
+ " self.l1_losses.append(avg_l1)\n",
+ " self.ssim_losses.append(avg_ssim)\n",
+ " self.edge_losses.append(avg_edge)\n",
+ " return avg_g, avg_d\n",
+ "\n",
+ " @torch.no_grad()\n",
+ " def validate(self) -> Tuple[float, float]:\n",
+ " \"\"\"Validate the model.\"\"\"\n",
+ " self.model.eval()\n",
+ " total_g = total_d = 0.0\n",
+ " total_l1 = total_ssim = total_edge = 0.0\n",
+ "\n",
+ " for batch in tqdm(self.val_loader, desc=\"Val\"):\n",
+ " google_img = batch[\"google_img\"].to(self.device)\n",
+ " yandex_img = batch[\"yandex_img\"].to(self.device)\n",
+ " fake_img = self.model.generator(google_img)\n",
+ " g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)\n",
+ " d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]\n",
+ " total_g += g_loss.item()\n",
+ " total_d += d_loss.item()\n",
+ " total_l1 += l1_loss.item()\n",
+ " total_ssim += ssim_loss.item()\n",
+ " total_edge += edge_loss.item()\n",
+ "\n",
+ " avg_g = total_g / len(self.val_loader)\n",
+ " avg_d = total_d / len(self.val_loader)\n",
+ " avg_l1 = total_l1 / len(self.val_loader)\n",
+ " avg_ssim = total_ssim / len(self.val_loader)\n",
+ " avg_edge = total_edge / len(self.val_loader)\n",
+ " avg_reconstruction = avg_l1 + avg_ssim + avg_edge\n",
+ " self.val_g_losses.append(avg_g)\n",
+ " self.val_d_losses.append(avg_d)\n",
+ " self.val_l1_losses.append(avg_l1)\n",
+ " self.val_ssim_losses.append(avg_ssim)\n",
+ " self.val_edge_losses.append(avg_edge)\n",
+ " self.val_reconstruction_losses.append(avg_reconstruction)\n",
+ " return avg_g, avg_d\n",
+ "\n",
+ " def train(self, num_epochs: int):\n",
+ " \"\"\"Train the model.\"\"\"\n",
+ " print(f\"Training for {num_epochs} epochs...\")\n",
+ "\n",
+ " for epoch in range(num_epochs):\n",
+ " self.current_epoch = epoch\n",
+ "\n",
+ " # Train & validate\n",
+ " train_g, train_d = self.train_epoch()\n",
+ " val_g, val_d = self.validate()\n",
+ "\n",
+ " val_reconstruction = self.val_reconstruction_losses[-1]\n",
+ " if val_reconstruction < self.best_val_loss:\n",
+ " self.best_val_loss = val_reconstruction\n",
+ " self.save_checkpoint(\"best\")\n",
+ "\n",
+ " # Periodic checkpoint\n",
+ " if (epoch + 1) % self.config.get(\"save_interval\", 5) == 0:\n",
+ " self.save_checkpoint(f\"epoch_{epoch + 1}\")\n",
+ "\n",
+ " print(\n",
+ " f\"Epoch {epoch + 1}: \"\n",
+ " f\"train_g={train_g:.4f}, train_d={train_d:.4f}, \"\n",
+ " f\"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, \"\n",
+ " f\"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, \"\n",
+ " f\"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, \"\n",
+ " f\"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}\"\n",
+ " )\n",
+ "\n",
+ " # Early stopping\n",
+ " patience = self.config.get(\"early_stopping_patience\", 0)\n",
+ " if patience > 0 and len(self.val_reconstruction_losses) > patience:\n",
+ " recent = self.val_reconstruction_losses[-patience:]\n",
+ " previous_best = min(self.val_reconstruction_losses[:-patience])\n",
+ " if all(loss >= previous_best for loss in recent):\n",
+ " print(f\"Early stopping at epoch {epoch + 1}\")\n",
+ " break\n",
+ "\n",
+ " # Save final\n",
+ " self.save_checkpoint(\"final\")\n",
+ " print(f\"Training finished. Best val loss: {self.best_val_loss:.4f}\")\n",
+ "\n",
+ " def save_checkpoint(self, name: str):\n",
+ " \"\"\"Save model checkpoint.\"\"\"\n",
+ " path = self.output_dir / \"checkpoints\" / f\"{name}.pth\"\n",
+ " torch.save({\n",
+ " \"epoch\": self.current_epoch,\n",
+ " \"generator\": self.model.generator.state_dict(),\n",
+ " \"discriminator\": self.model.discriminator.state_dict(),\n",
+ " \"opt_G\": self.opt_G.state_dict(),\n",
+ " \"opt_D\": self.opt_D.state_dict(),\n",
+ " }, path)\n",
+ "\n",
+ "\n",
+ "def create_trainer(\n",
+ " model: torch.nn.Module,\n",
+ " train_loader: DataLoader,\n",
+ " val_loader: DataLoader,\n",
+ " config: Dict[str, Any],\n",
+ ") -> GANTrainer:\n",
+ " \"\"\"Create a trainer instance.\"\"\"\n",
+ " return GANTrainer(model, train_loader, val_loader, config)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Analysis\n\nVisualization helpers for generated samples and collected training metrics.\n\nTraining history plot contains:\n1. Generator loss\n2. Discriminator loss\n3. L1 loss against the paired Yandex target\n4. SSIM structure loss\n5. Sobel edge loss\n6. Best-checkpoint reconstruction score\n\nThe sample grid contains one row per sample with three columns:\n1. Google input\n2. Generated Yandex\n3. Yandex target\n"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "def denormalize_image(tensor):\n",
+ " return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)\n",
+ "\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):\n",
+ " device = device or model.device\n",
+ " model.eval()\n",
+ "\n",
+ " batch = next(iter(data_loader))\n",
+ " google_img = batch[\"google_img\"][:num_samples].to(device)\n",
+ " yandex_img = batch[\"yandex_img\"][:num_samples].to(device)\n",
+ " fake_yandex = model.generator(google_img)\n",
+ "\n",
+ " output_dir = Path(output_dir)\n",
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
+ " output_path = output_dir / \"generation_samples.png\"\n",
+ "\n",
+ " fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))\n",
+ " if num_samples == 1:\n",
+ " axes = axes.reshape(1, 3)\n",
+ "\n",
+ " titles = [\"Google input\", \"Generated Yandex\", \"Yandex target\"]\n",
+ " for row in range(num_samples):\n",
+ " images = [google_img[row], fake_yandex[row], yandex_img[row]]\n",
+ " for col, image in enumerate(images):\n",
+ " axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))\n",
+ " axes[row, col].set_title(titles[col])\n",
+ " axes[row, col].axis(\"off\")\n",
+ "\n",
+ " fig.tight_layout()\n",
+ " fig.savefig(output_path, dpi=150)\n",
+ " if show:\n",
+ " plt.show()\n",
+ " plt.close(fig)\n",
+ " return output_path\n",
+ "\n",
+ "\n",
+ "def plot_training_history(trainer, output_dir, show=True):\n",
+ " output_dir = Path(output_dir)\n",
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
+ " output_path = output_dir / \"training_history.png\"\n",
+ "\n",
+ " epochs = range(1, len(trainer.g_losses) + 1)\n",
+ " fig, axes = plt.subplots(2, 3, figsize=(15, 8))\n",
+ " axes = axes.ravel()\n",
+ "\n",
+ " axes[0].plot(epochs, trainer.g_losses, label=\"train G\")\n",
+ " axes[0].plot(epochs, trainer.val_g_losses, label=\"val G\")\n",
+ " axes[0].set_title(\"Generator loss\")\n",
+ " axes[0].set_xlabel(\"Epoch\")\n",
+ " axes[0].legend()\n",
+ " axes[0].grid(True, alpha=0.3)\n",
+ "\n",
+ " axes[1].plot(epochs, trainer.d_losses, label=\"train D\")\n",
+ " axes[1].plot(epochs, trainer.val_d_losses, label=\"val D\")\n",
+ " axes[1].set_title(\"Discriminator loss\")\n",
+ " axes[1].set_xlabel(\"Epoch\")\n",
+ " axes[1].legend()\n",
+ " axes[1].grid(True, alpha=0.3)\n",
+ "\n",
+ " axes[2].plot(epochs, trainer.l1_losses, label=\"train L1\")\n",
+ " axes[2].plot(epochs, trainer.val_l1_losses, label=\"val L1\")\n",
+ " axes[2].set_title(\"Paired Yandex L1 loss\")\n",
+ " axes[2].set_xlabel(\"Epoch\")\n",
+ " axes[2].legend()\n",
+ " axes[2].grid(True, alpha=0.3)\n",
+ "\n",
+ " axes[3].plot(epochs, trainer.ssim_losses, label=\"train SSIM\")\n",
+ " axes[3].plot(epochs, trainer.val_ssim_losses, label=\"val SSIM\")\n",
+ " axes[3].set_title(\"SSIM structure loss\")\n",
+ " axes[3].set_xlabel(\"Epoch\")\n",
+ " axes[3].legend()\n",
+ " axes[3].grid(True, alpha=0.3)\n",
+ "\n",
+ " axes[4].plot(epochs, trainer.edge_losses, label=\"train edge\")\n",
+ " axes[4].plot(epochs, trainer.val_edge_losses, label=\"val edge\")\n",
+ " axes[4].set_title(\"Sobel edge loss\")\n",
+ " axes[4].set_xlabel(\"Epoch\")\n",
+ " axes[4].legend()\n",
+ " axes[4].grid(True, alpha=0.3)\n",
+ "\n",
+ " axes[5].plot(epochs, trainer.val_reconstruction_losses, label=\"val reconstruction\")\n",
+ " axes[5].set_title(\"Best-checkpoint score\")\n",
+ " axes[5].set_xlabel(\"Epoch\")\n",
+ " axes[5].legend()\n",
+ " axes[5].grid(True, alpha=0.3)\n",
+ "\n",
+ " fig.tight_layout()\n",
+ " fig.savefig(output_path, dpi=150)\n",
+ " if show:\n",
+ " plt.show()\n",
+ " plt.close(fig)\n",
+ " return output_path\n",
+ "\n",
+ "\n",
+ "def analyze_training(trainer):\n",
+ " return {\n",
+ " \"best_val_loss\": trainer.best_val_loss,\n",
+ " \"g_losses\": trainer.g_losses,\n",
+ " \"d_losses\": trainer.d_losses,\n",
+ " \"l1_losses\": trainer.l1_losses,\n",
+ " \"ssim_losses\": trainer.ssim_losses,\n",
+ " \"edge_losses\": trainer.edge_losses,\n",
+ " \"val_g_losses\": trainer.val_g_losses,\n",
+ " \"val_d_losses\": trainer.val_d_losses,\n",
+ " \"val_l1_losses\": trainer.val_l1_losses,\n",
+ " \"val_ssim_losses\": trainer.val_ssim_losses,\n",
+ " \"val_edge_losses\": trainer.val_edge_losses,\n",
+ " \"val_reconstruction_losses\": trainer.val_reconstruction_losses,\n",
+ " }\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Main Pipeline\n\nExecutes the full GAN workflow:\n1. Create config\n2. Build paired data loaders\n3. Initialize Google -> Yandex GAN\n4. Train with validation\n5. Save checkpoints in `runs/checkpoints/`\n6. Show loss plots and generated sample grid\n\nThis block is intentionally top-level, not wrapped in `main()`, so notebook\nvariables such as `model`, `trainer`, `train_loader`, `val_loader`, and\n`training_analysis` remain available for debugging and interactive experiments.\n"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"Executable GAN training pipeline.\n",
+ "\n",
+ "The code is intentionally top-level, mirroring the SiaN notebook style:\n",
+ "when this file is included in the generated notebook, variables remain\n",
+ "available for debugging and interactive experiments.\n",
+ "\"\"\"\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "config = create_config()\n",
+ "device = get_compatible_device(prefer_cuda=config[\"prefer_cuda\"])\n",
+ "print(f\"Using device: {device}\")\n",
+ "\n",
+ "train_loader, val_loader = create_data_loaders(\n",
+ " root_dir=config[\"data_dir\"],\n",
+ " batch_size=config[\"batch_size\"],\n",
+ " train_split=config[\"train_split\"],\n",
+ " image_size=tuple(config[\"image_size\"]),\n",
+ " num_workers=config[\"num_workers\"],\n",
+ ")\n",
+ "print(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
+ "\n",
+ "model = create_gan(\n",
+ " gan_mode=config[\"gan_mode\"],\n",
+ " lambda_GAN=config[\"lambda_GAN\"],\n",
+ " lambda_L1=config[\"lambda_L1\"],\n",
+ " lambda_SSIM=config[\"lambda_SSIM\"],\n",
+ " lambda_edge=config[\"lambda_edge\"],\n",
+ " use_cuda=(device.type == \"cuda\"),\n",
+ ")\n",
+ "\n",
+ "generator_params = sum(p.numel() for p in model.generator.parameters())\n",
+ "discriminator_params = sum(p.numel() for p in model.discriminator.parameters())\n",
+ "print(f\"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}\")\n",
+ "\n",
+ "trainer = create_trainer(model, train_loader, val_loader, config)\n",
+ "trainer.train(config[\"epochs\"])\n",
+ "\n",
+ "training_analysis = analyze_training(trainer)\n",
+ "images_dir = Path(config[\"output_dir\"]) / \"images\"\n",
+ "history_plot_path = plot_training_history(trainer, images_dir)\n",
+ "generation_samples_path = visualize_generation(\n",
+ " model=model,\n",
+ " data_loader=val_loader,\n",
+ " output_dir=images_dir,\n",
+ " device=device,\n",
+ " num_samples=config[\"num_visual_samples\"],\n",
+ ")\n",
+ "\n",
+ "print(f\"Training history plot: {history_plot_path}\")\n",
+ "print(f\"Generation samples: {generation_samples_path}\")\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/models/GAN/src/__init__.py b/models/GAN/src/__init__.py
new file mode 100644
index 0000000..bfa6849
--- /dev/null
+++ b/models/GAN/src/__init__.py
@@ -0,0 +1 @@
+"""GAN package for Google-to-Yandex map image translation."""
diff --git a/models/GAN/src/analyze.py b/models/GAN/src/analyze.py
new file mode 100644
index 0000000..bbe8f30
--- /dev/null
+++ b/models/GAN/src/analyze.py
@@ -0,0 +1,117 @@
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def denormalize_image(tensor):
+ return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)
+
+
+@torch.no_grad()
+def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):
+ device = device or model.device
+ model.eval()
+
+ batch = next(iter(data_loader))
+ google_img = batch["google_img"][:num_samples].to(device)
+ yandex_img = batch["yandex_img"][:num_samples].to(device)
+ fake_yandex = model.generator(google_img)
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_path = output_dir / "generation_samples.png"
+
+ fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))
+ if num_samples == 1:
+ axes = axes.reshape(1, 3)
+
+ titles = ["Google input", "Generated Yandex", "Yandex target"]
+ for row in range(num_samples):
+ images = [google_img[row], fake_yandex[row], yandex_img[row]]
+ for col, image in enumerate(images):
+ axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))
+ axes[row, col].set_title(titles[col])
+ axes[row, col].axis("off")
+
+ fig.tight_layout()
+ fig.savefig(output_path, dpi=150)
+ if show:
+ plt.show()
+ plt.close(fig)
+ return output_path
+
+
+def plot_training_history(trainer, output_dir, show=True):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_path = output_dir / "training_history.png"
+
+ epochs = range(1, len(trainer.g_losses) + 1)
+ fig, axes = plt.subplots(2, 3, figsize=(15, 8))
+ axes = axes.ravel()
+
+ axes[0].plot(epochs, trainer.g_losses, label="train G")
+ axes[0].plot(epochs, trainer.val_g_losses, label="val G")
+ axes[0].set_title("Generator loss")
+ axes[0].set_xlabel("Epoch")
+ axes[0].legend()
+ axes[0].grid(True, alpha=0.3)
+
+ axes[1].plot(epochs, trainer.d_losses, label="train D")
+ axes[1].plot(epochs, trainer.val_d_losses, label="val D")
+ axes[1].set_title("Discriminator loss")
+ axes[1].set_xlabel("Epoch")
+ axes[1].legend()
+ axes[1].grid(True, alpha=0.3)
+
+ axes[2].plot(epochs, trainer.l1_losses, label="train L1")
+ axes[2].plot(epochs, trainer.val_l1_losses, label="val L1")
+ axes[2].set_title("Paired Yandex L1 loss")
+ axes[2].set_xlabel("Epoch")
+ axes[2].legend()
+ axes[2].grid(True, alpha=0.3)
+
+ axes[3].plot(epochs, trainer.ssim_losses, label="train SSIM")
+ axes[3].plot(epochs, trainer.val_ssim_losses, label="val SSIM")
+ axes[3].set_title("SSIM structure loss")
+ axes[3].set_xlabel("Epoch")
+ axes[3].legend()
+ axes[3].grid(True, alpha=0.3)
+
+ axes[4].plot(epochs, trainer.edge_losses, label="train edge")
+ axes[4].plot(epochs, trainer.val_edge_losses, label="val edge")
+ axes[4].set_title("Sobel edge loss")
+ axes[4].set_xlabel("Epoch")
+ axes[4].legend()
+ axes[4].grid(True, alpha=0.3)
+
+ axes[5].plot(epochs, trainer.val_reconstruction_losses, label="val reconstruction")
+ axes[5].set_title("Best-checkpoint score")
+ axes[5].set_xlabel("Epoch")
+ axes[5].legend()
+ axes[5].grid(True, alpha=0.3)
+
+ fig.tight_layout()
+ fig.savefig(output_path, dpi=150)
+ if show:
+ plt.show()
+ plt.close(fig)
+ return output_path
+
+
+def analyze_training(trainer):
+ return {
+ "best_val_loss": trainer.best_val_loss,
+ "g_losses": trainer.g_losses,
+ "d_losses": trainer.d_losses,
+ "l1_losses": trainer.l1_losses,
+ "ssim_losses": trainer.ssim_losses,
+ "edge_losses": trainer.edge_losses,
+ "val_g_losses": trainer.val_g_losses,
+ "val_d_losses": trainer.val_d_losses,
+ "val_l1_losses": trainer.val_l1_losses,
+ "val_ssim_losses": trainer.val_ssim_losses,
+ "val_edge_losses": trainer.val_edge_losses,
+ "val_reconstruction_losses": trainer.val_reconstruction_losses,
+ }
diff --git a/models/GAN/config.py b/models/GAN/src/config.py
similarity index 62%
rename from models/GAN/config.py
rename to models/GAN/src/config.py
index b97e1b5..69fa681 100644
--- a/models/GAN/config.py
+++ b/models/GAN/src/config.py
@@ -6,23 +6,30 @@ def create_config():
return {
# Optimizer params
"learning_rate": 2e-4,
+ "discriminator_lr_factor": 0.5,
"beta1": 0.5,
"beta2": 0.999,
# Training params
"batch_size": 32,
"epochs": 100,
+ "prefer_cuda": True,
# GAN params
- "gan_mode": "vanilla",
- "lambda_L1": 100.0,
+ "gan_mode": "lsgan",
+ "lambda_GAN": 0.5,
+ "lambda_L1": 150.0,
+ "lambda_SSIM": 25.0,
+ "lambda_edge": 20.0,
+ "discriminator_update_interval": 1,
# Regularization
"grad_clip": 1.0,
# Early stopping
- "early_stopping_patience": 20,
+ "early_stopping_patience": 25,
# Output
- "output_dir": "runs/gan_training",
+ "output_dir": r"C:\Users\admin\Projects\autopilot\models\GAN\runs",
# Logging
"log_interval": 10,
"save_interval": 5,
+ "num_visual_samples": 4,
# Data
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
"image_size": [256, 256],
@@ -33,4 +40,4 @@ def create_config():
if __name__ == "__main__":
config = create_config()
- print("Default config:", config)
\ No newline at end of file
+ print("Default config:", config)
diff --git a/models/GAN/dataloader.py b/models/GAN/src/dataloader.py
similarity index 97%
rename from models/GAN/dataloader.py
rename to models/GAN/src/dataloader.py
index 3472db1..c21ba3f 100644
--- a/models/GAN/dataloader.py
+++ b/models/GAN/src/dataloader.py
@@ -1,4 +1,4 @@
-"""Data loader for Yandex-to-Google image translation."""
+"""Data loader for Google-to-Yandex image translation."""
import os
from typing import Dict, List, Tuple
@@ -10,7 +10,7 @@ from torchvision import transforms
class YaGoDataset(Dataset):
- """Dataset loading pairs of Yandex and Google map images."""
+ """Dataset loading paired Google and Yandex map images."""
def __init__(
self,
@@ -159,4 +159,4 @@ if __name__ == "__main__":
batch = next(iter(train_loader))
print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
- print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
\ No newline at end of file
+ print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
diff --git a/models/GAN/src/main.py b/models/GAN/src/main.py
new file mode 100644
index 0000000..4e2966d
--- /dev/null
+++ b/models/GAN/src/main.py
@@ -0,0 +1,60 @@
+"""Executable GAN training pipeline.
+
+The code is intentionally top-level, mirroring the SiaN notebook style:
+when this file is included in the generated notebook, variables remain
+available for debugging and interactive experiments.
+"""
+
+from pathlib import Path
+
+import torch
+
+from analyze import analyze_training, plot_training_history, visualize_generation
+from config import create_config
+from dataloader import create_data_loaders
+from model import create_gan, get_compatible_device
+from trainer import create_trainer
+
+
+# Step 1: load the configuration dict and choose the compute device from it.
+config = create_config()
+device = get_compatible_device(prefer_cuda=config["prefer_cuda"])
+print(f"Using device: {device}")
+
+# Step 2: build the train/validation loaders from the configured data directory.
+train_loader, val_loader = create_data_loaders(
+ root_dir=config["data_dir"],
+ batch_size=config["batch_size"],
+ train_split=config["train_split"],
+ image_size=tuple(config["image_size"]),
+ num_workers=config["num_workers"],
+)
+print(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
+
+# Step 3: create the GAN; CUDA is requested only when the device chosen above is CUDA.
+model = create_gan(
+ gan_mode=config["gan_mode"],
+ lambda_GAN=config["lambda_GAN"],
+ lambda_L1=config["lambda_L1"],
+ lambda_SSIM=config["lambda_SSIM"],
+ lambda_edge=config["lambda_edge"],
+ use_cuda=(device.type == "cuda"),
+)
+
+# Report the parameter counts of both networks.
+generator_params = sum(p.numel() for p in model.generator.parameters())
+discriminator_params = sum(p.numel() for p in model.discriminator.parameters())
+print(f"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}")
+
+# Step 4: train for the configured number of epochs.
+trainer = create_trainer(model, train_loader, val_loader, config)
+trainer.train(config["epochs"])
+
+# Step 5: collect the metrics and write the plots under <output_dir>/images.
+training_analysis = analyze_training(trainer)
+images_dir = Path(config["output_dir"]) / "images"
+history_plot_path = plot_training_history(trainer, images_dir)
+generation_samples_path = visualize_generation(
+ model=model,
+ data_loader=val_loader,
+ output_dir=images_dir,
+ device=device,
+ num_samples=config["num_visual_samples"],
+)
+
+print(f"Training history plot: {history_plot_path}")
+print(f"Generation samples: {generation_samples_path}")
diff --git a/models/GAN/model.py b/models/GAN/src/model.py
similarity index 56%
rename from models/GAN/model.py
rename to models/GAN/src/model.py
index dfb0787..f250c13 100644
--- a/models/GAN/model.py
+++ b/models/GAN/src/model.py
@@ -1,10 +1,37 @@
-"""GAN model for image translation Yandex -> Google."""
+"""GAN model for image translation Google -> Yandex."""
import torch
import torch.nn as nn
import torch.nn.functional as F
+def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:
+ """Return CUDA only when the current PyTorch build supports the GPU arch."""
+ if not prefer_cuda or not torch.cuda.is_available():
+ return torch.device("cpu")
+
+ try:
+ major, minor = torch.cuda.get_device_capability()
+ arch = f"sm_{major}{minor}"
+ supported_arches = set(torch.cuda.get_arch_list())
+ gpu_name = torch.cuda.get_device_name()
+ except Exception as exc:
+ if verbose:
+ print(f"CUDA is visible but cannot be inspected ({exc}); using CPU.")
+ return torch.device("cpu")
+
+ if supported_arches and arch not in supported_arches:
+ if verbose:
+ supported = ", ".join(sorted(supported_arches))
+ print(
+ f"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build "
+ f"supports only: {supported}. Using CPU."
+ )
+ return torch.device("cpu")
+
+ return torch.device("cuda")
+
+
class UNetDownBlock(nn.Module):
"""Downsampling block for U-Net."""
@@ -29,7 +56,10 @@ class UNetUpBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
super().__init__()
- self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
+ self.upconv = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ )
self.norm = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if dropout > 0:
@@ -38,7 +68,6 @@ class UNetUpBlock(nn.Module):
self.dropout = None
def forward(self, x, skip_input):
-
x = self.upconv(x)
# Pad if needed to match skip connection size
if x.shape != skip_input.shape:
@@ -54,7 +83,7 @@ class UNetUpBlock(nn.Module):
class GeneratorUNet(nn.Module):
- """U-Net generator for Yandex -> Google translation."""
+ """U-Net generator for Google -> Yandex translation."""
def __init__(self, in_channels: int = 3, out_channels: int = 3):
super().__init__()
@@ -85,7 +114,8 @@ class GeneratorUNet(nn.Module):
# Final
self.final = nn.Sequential(
- nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+ nn.Conv2d(128, out_channels, kernel_size=3, padding=1),
nn.Tanh(),
)
@@ -115,7 +145,7 @@ class GeneratorUNet(nn.Module):
class DiscriminatorPatchGAN(nn.Module):
- """PatchGAN discriminator."""
+ """PatchGAN discriminator for paired source/target images."""
def __init__(self, in_channels: int = 6):
super().__init__()
@@ -132,7 +162,6 @@ class DiscriminatorPatchGAN(nn.Module):
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
- nn.Sigmoid(),
)
def forward(self, img_A, img_B):
@@ -167,15 +196,82 @@ class GANLoss(nn.Module):
return -prediction.mean() if target_is_real else prediction.mean()
+class SSIMLoss(nn.Module):
+ """Local SSIM loss for normalized image tensors in [-1, 1]."""
+
+ def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):
+ super().__init__()
+ self.window_size = window_size
+ self.c1 = c1
+ self.c2 = c2
+
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ pred = (pred + 1.0) * 0.5
+ target = (target + 1.0) * 0.5
+ padding = self.window_size // 2
+
+ mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding)
+ mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding)
+ mu_pred_sq = mu_pred.pow(2)
+ mu_target_sq = mu_target.pow(2)
+ mu_pred_target = mu_pred * mu_target
+
+ sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq
+ sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq
+ sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target
+
+ ssim_map = (
+ (2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)
+ ) / (
+ (mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)
+ )
+ return (1.0 - ssim_map.clamp(0, 1)).mean()
+
+
+class SobelEdgeLoss(nn.Module):
+ """L1 loss between Sobel edge maps, useful for stable keypoint structure."""
+
+ def __init__(self):
+ super().__init__()
+ kernel_x = torch.tensor(
+ [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
+ dtype=torch.float32,
+ ).view(1, 1, 3, 3)
+ kernel_y = torch.tensor(
+ [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
+ dtype=torch.float32,
+ ).view(1, 1, 3, 3)
+ self.register_buffer("kernel_x", kernel_x)
+ self.register_buffer("kernel_y", kernel_y)
+
+ @staticmethod
+ def _to_gray(x: torch.Tensor) -> torch.Tensor:
+ x = (x + 1.0) * 0.5
+ weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
+ return (x * weights).sum(dim=1, keepdim=True)
+
+ def _edges(self, x: torch.Tensor) -> torch.Tensor:
+ gray = self._to_gray(x)
+ grad_x = F.conv2d(gray, self.kernel_x, padding=1)
+ grad_y = F.conv2d(gray, self.kernel_y, padding=1)
+ return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)
+
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ return F.l1_loss(self._edges(pred), self._edges(target))
+
+
class ImageGAN(nn.Module):
- """Complete GAN model for image translation."""
+ """Complete pix2pix-style GAN for Google -> Yandex image translation."""
def __init__(
self,
input_channels: int = 3,
output_channels: int = 3,
- gan_mode: str = "vanilla",
- lambda_L1: float = 100.0,
+ gan_mode: str = "lsgan",
+ lambda_L1: float = 150.0,
+ lambda_GAN: float = 0.5,
+ lambda_SSIM: float = 25.0,
+ lambda_edge: float = 20.0,
use_cuda: bool = True,
):
super().__init__()
@@ -183,29 +279,36 @@ class ImageGAN(nn.Module):
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
self.gan_loss = GANLoss(gan_mode)
self.l1_loss = nn.L1Loss()
+ self.ssim_loss = SSIMLoss()
+ self.edge_loss = SobelEdgeLoss()
self.lambda_L1 = lambda_L1
+ self.lambda_GAN = lambda_GAN
+ self.lambda_SSIM = lambda_SSIM
+ self.lambda_edge = lambda_edge
- self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
+ self.device = get_compatible_device(prefer_cuda=use_cuda)
self.to(self.device)
- def forward(self, yandex_image):
- """Generate Google image from Yandex."""
- return self.generator(yandex_image)
+ def forward(self, google_image):
+ """Generate a Yandex-style image from a Google image."""
+ return self.generator(google_image)
- def generator_step(self, yandex_img, real_google_img):
- """Compute generator losses."""
- fake_google = self.generator(yandex_img)
- fake_pred = self.discriminator(yandex_img, fake_google)
- gan_loss = self.gan_loss(fake_pred, True)
- l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1
- total_loss = gan_loss + l1_loss
- return total_loss, gan_loss, l1_loss
+ def generator_step(self, google_img, real_yandex_img):
+ """Compute generator losses against the paired original Yandex image."""
+ fake_yandex = self.generator(google_img)
+ fake_pred = self.discriminator(google_img, fake_yandex)
+ gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN
+ l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1
+ ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM
+ edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge
+ total_loss = gan_loss + l1_loss + ssim_loss + edge_loss
+ return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss
- def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
- """Compute discriminator losses."""
- real_pred = self.discriminator(yandex_img, real_google_img)
+ def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):
+ """Compute discriminator losses for real and generated Yandex targets."""
+ real_pred = self.discriminator(google_img, real_yandex_img)
real_loss = self.gan_loss(real_pred, True)
- fake_pred = self.discriminator(yandex_img, fake_google_img.detach())
+ fake_pred = self.discriminator(google_img, fake_yandex_img.detach())
fake_loss = self.gan_loss(fake_pred, False)
total_loss = (real_loss + fake_loss) * 0.5
return total_loss, real_loss, fake_loss
@@ -214,8 +317,11 @@ class ImageGAN(nn.Module):
def create_gan(
input_channels: int = 3,
output_channels: int = 3,
- gan_mode: str = "vanilla",
- lambda_L1: float = 100.0,
+ gan_mode: str = "lsgan",
+ lambda_L1: float = 150.0,
+ lambda_GAN: float = 0.5,
+ lambda_SSIM: float = 25.0,
+ lambda_edge: float = 20.0,
use_cuda: bool = True,
) -> ImageGAN:
"""Create a GAN model."""
@@ -224,6 +330,9 @@ def create_gan(
output_channels=output_channels,
gan_mode=gan_mode,
lambda_L1=lambda_L1,
+ lambda_GAN=lambda_GAN,
+ lambda_SSIM=lambda_SSIM,
+ lambda_edge=lambda_edge,
use_cuda=use_cuda,
)
@@ -252,4 +361,4 @@ if __name__ == "__main__":
# Count parameters
gen_params = sum(p.numel() for p in model.generator.parameters())
disc_params = sum(p.numel() for p in model.discriminator.parameters())
- print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")
\ No newline at end of file
+ print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")
diff --git a/models/GAN/src/test_dataloader.py b/models/GAN/src/test_dataloader.py
new file mode 100644
index 0000000..e7b949f
--- /dev/null
+++ b/models/GAN/src/test_dataloader.py
@@ -0,0 +1,40 @@
+from config import create_config
+from dataloader import YaGoDataset, create_data_loaders
+
+
+def get_dataset_info():
+    """Describe the configured dataset.
+
+    Returns:
+        A dict with the dataset length (``size``), the keys of the first
+        sample (``sample_keys``), and the shapes of that sample's Google and
+        Yandex images. The shapes are ``None`` and the key list is empty when
+        there is no sample to inspect.
+    """
+    settings = create_config()
+    dataset = YaGoDataset(
+        root_dir=settings["data_dir"],
+        image_size=tuple(settings["image_size"]),
+    )
+    first = dataset[0] if len(dataset) else {}
+
+    google_shape = yandex_shape = None
+    if first:
+        google_shape = tuple(first["google_img"].shape)
+        yandex_shape = tuple(first["yandex_img"].shape)
+
+    return {
+        "size": len(dataset),
+        "sample_keys": list(first.keys()),
+        "google_shape": google_shape,
+        "yandex_shape": yandex_shape,
+    }
+
+
+def smoke_test_dataloader(batch_size=4):
+    """Build the train/val loaders and report split sizes and first-batch shapes.
+
+    Args:
+        batch_size: Batch size passed to ``create_data_loaders``.
+
+    Returns:
+        A dict with ``train_size`` and ``val_size`` (lengths of the loaders'
+        datasets) and the shapes of the first training batch's Google and
+        Yandex tensors. Both shapes are ``None`` when the training loader
+        yields no batch at all.
+    """
+    config = create_config()
+    train_loader, val_loader = create_data_loaders(
+        root_dir=config["data_dir"],
+        batch_size=batch_size,
+        train_split=config["train_split"],
+        num_workers=config["num_workers"],
+        image_size=tuple(config["image_size"]),
+    )
+    # A loader may yield nothing (for example, a training split too small to
+    # form a batch). Report that as None shapes instead of letting a bare
+    # StopIteration escape, matching how get_dataset_info treats no sample.
+    batch = next(iter(train_loader), None)
+    has_batch = batch is not None
+    return {
+        "train_size": len(train_loader.dataset),
+        "val_size": len(val_loader.dataset),
+        "google_batch_shape": tuple(batch["google_img"].shape) if has_batch else None,
+        "yandex_batch_shape": tuple(batch["yandex_img"].shape) if has_batch else None,
+    }
+
+
+# Script entry point: run both checks and print their summary dicts.
+if __name__ == "__main__":
+    print(get_dataset_info())
+    print(smoke_test_dataloader())
diff --git a/models/GAN/trainer.py b/models/GAN/src/trainer.py
similarity index 58%
rename from models/GAN/trainer.py
rename to models/GAN/src/trainer.py
index 77fb0b0..7bac290 100644
--- a/models/GAN/trainer.py
+++ b/models/GAN/src/trainer.py
@@ -28,18 +28,26 @@ class GANTrainer:
# Optimizers
lr = config.get("learning_rate", 2e-4)
+ lr_d = config.get("discriminator_learning_rate", lr * config.get("discriminator_lr_factor", 0.5))
beta1 = config.get("beta1", 0.5)
beta2 = config.get("beta2", 0.999)
self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))
- self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2))
+ self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
# Training state
self.current_epoch = 0
self.best_val_loss = float("inf")
self.g_losses = []
self.d_losses = []
+ self.l1_losses = []
+ self.ssim_losses = []
+ self.edge_losses = []
self.val_g_losses = []
self.val_d_losses = []
+ self.val_l1_losses = []
+ self.val_ssim_losses = []
+ self.val_edge_losses = []
+ self.val_reconstruction_losses = []
# Output dir
self.output_dir = Path(config.get("output_dir", "runs/gan"))
@@ -54,35 +62,55 @@ class GANTrainer:
"""Train for one epoch."""
self.model.train()
total_g = total_d = 0.0
+ total_l1 = total_ssim = total_edge = 0.0
num_batches = len(self.train_loader)
+ d_update_interval = max(1, self.config.get("discriminator_update_interval", 1))
- pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
- for batch in pbar:
- yandex_img = batch["yandex_img"].to(self.device)
+ pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f"Epoch {self.current_epoch + 1}")
+ for batch_idx, batch in pbar:
google_img = batch["google_img"].to(self.device)
+ yandex_img = batch["yandex_img"].to(self.device)
# Train D
- self.opt_D.zero_grad()
- with torch.no_grad():
- fake_img = self.model.generator(yandex_img)
- d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
- d_loss.backward()
- self.opt_D.step()
+ if batch_idx % d_update_interval == 0:
+ self.opt_D.zero_grad()
+ with torch.no_grad():
+ fake_img = self.model.generator(google_img)
+ d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
+ d_loss.backward()
+ self.opt_D.step()
+ else:
+ d_loss = google_img.new_tensor(0.0)
# Train G
self.opt_G.zero_grad()
- g_loss = self.model.generator_step(yandex_img, google_img)[0]
+ g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
g_loss.backward()
self.opt_G.step()
total_g += g_loss.item()
total_d += d_loss.item()
- pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
+ total_l1 += l1_loss.item()
+ total_ssim += ssim_loss.item()
+ total_edge += edge_loss.item()
+ pbar.set_postfix({
+ "g_loss": g_loss.item(),
+ "d_loss": d_loss.item(),
+ "l1": l1_loss.item(),
+ "ssim": ssim_loss.item(),
+ "edge": edge_loss.item(),
+ })
avg_g = total_g / num_batches
avg_d = total_d / num_batches
+ avg_l1 = total_l1 / num_batches
+ avg_ssim = total_ssim / num_batches
+ avg_edge = total_edge / num_batches
self.g_losses.append(avg_g)
self.d_losses.append(avg_d)
+ self.l1_losses.append(avg_l1)
+ self.ssim_losses.append(avg_ssim)
+ self.edge_losses.append(avg_edge)
return avg_g, avg_d
@torch.no_grad()
@@ -90,20 +118,32 @@ class GANTrainer:
"""Validate the model."""
self.model.eval()
total_g = total_d = 0.0
+ total_l1 = total_ssim = total_edge = 0.0
for batch in tqdm(self.val_loader, desc="Val"):
- yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
- fake_img = self.model.generator(yandex_img)
- g_loss = self.model.generator_step(yandex_img, google_img)[0]
- d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
+ yandex_img = batch["yandex_img"].to(self.device)
+ fake_img = self.model.generator(google_img)
+ g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
+ d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
total_g += g_loss.item()
total_d += d_loss.item()
+ total_l1 += l1_loss.item()
+ total_ssim += ssim_loss.item()
+ total_edge += edge_loss.item()
avg_g = total_g / len(self.val_loader)
avg_d = total_d / len(self.val_loader)
+ avg_l1 = total_l1 / len(self.val_loader)
+ avg_ssim = total_ssim / len(self.val_loader)
+ avg_edge = total_edge / len(self.val_loader)
+ avg_reconstruction = avg_l1 + avg_ssim + avg_edge
self.val_g_losses.append(avg_g)
self.val_d_losses.append(avg_d)
+ self.val_l1_losses.append(avg_l1)
+ self.val_ssim_losses.append(avg_ssim)
+ self.val_edge_losses.append(avg_edge)
+ self.val_reconstruction_losses.append(avg_reconstruction)
return avg_g, avg_d
def train(self, num_epochs: int):
@@ -117,23 +157,30 @@ class GANTrainer:
train_g, train_d = self.train_epoch()
val_g, val_d = self.validate()
- # Save best checkpoint
- val_total = val_g + val_d
- if val_total < self.best_val_loss:
- self.best_val_loss = val_total
+ val_reconstruction = self.val_reconstruction_losses[-1]
+ if val_reconstruction < self.best_val_loss:
+ self.best_val_loss = val_reconstruction
self.save_checkpoint("best")
# Periodic checkpoint
if (epoch + 1) % self.config.get("save_interval", 5) == 0:
self.save_checkpoint(f"epoch_{epoch + 1}")
- print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")
+ print(
+ f"Epoch {epoch + 1}: "
+ f"train_g={train_g:.4f}, train_d={train_d:.4f}, "
+ f"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, "
+ f"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, "
+ f"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, "
+ f"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}"
+ )
# Early stopping
patience = self.config.get("early_stopping_patience", 0)
- if patience > 0 and len(self.val_g_losses) > patience:
- recent = self.val_g_losses[-patience:]
- if all(l >= min(self.val_g_losses[:-patience]) for l in recent):
+ if patience > 0 and len(self.val_reconstruction_losses) > patience:
+ recent = self.val_reconstruction_losses[-patience:]
+ previous_best = min(self.val_reconstruction_losses[:-patience])
+ if all(loss >= previous_best for loss in recent):
print(f"Early stopping at epoch {epoch + 1}")
break
@@ -188,4 +235,4 @@ if __name__ == "__main__":
g_loss, d_loss = trainer.train_epoch()
print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")
except Exception as e:
- print(f"Training step failed: {e}")
\ No newline at end of file
+ print(f"Training step failed: {e}")
diff --git a/practice/Indiv_zadanie_Сагитов.docx b/practice/Indiv_zadanie_Сагитов.docx
new file mode 100644
index 0000000..81166db
Binary files /dev/null and b/practice/Indiv_zadanie_Сагитов.docx differ
diff --git a/practice/otchet_po_praktike.docx b/practice/otchet_po_praktike.docx
new file mode 100644
index 0000000..6d895b4
Binary files /dev/null and b/practice/otchet_po_praktike.docx differ
diff --git a/practice/Дневник практики.docx b/practice/Дневник практики.docx
new file mode 100644
index 0000000..b543ed6
Binary files /dev/null and b/practice/Дневник практики.docx differ
diff --git a/sian_similarity.py b/sian_similarity.py
new file mode 100644
index 0000000..2ca2622
--- /dev/null
+++ b/sian_similarity.py
@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+import importlib.util
+import os
+from pathlib import Path
+from typing import Optional, Sequence
+
+import numpy as np
+import torch
+from PIL import Image
+
+from vision_chunk import VisionChunk
+
+
+# Directory that contains this module; all model paths are built from it.
+ROOT_DIR = Path(__file__).resolve().parent
+# Source file defining SimilarityCNN. It is loaded by file path in
+# _load_similarity_class (the "SiaN-similarity" directory name contains a
+# hyphen, so it is not a valid package name for a regular import).
+MODEL_FILE = ROOT_DIR / "models" / "SiaN-similarity" / "model.py"
+# Checkpoint used when the CHECKPOINT_ENV variable is unset or empty.
+DEFAULT_CHECKPOINT_PATH = (
+    ROOT_DIR
+    / "models"
+    / "SiaN-similarity"
+    / "runs"
+    / "gan_training"
+    / "checkpoints"
+    / "best_model.pt"
+)
+
+# Size passed to PIL's resize() before an image is fed to the model.
+IMAGE_SIZE = (256, 256)
+# Threshold returned by get_threshold() when THRESHOLD_ENV is not set.
+DEFAULT_THRESHOLD = 0.5
+# Names of the environment variables that override the checkpoint path and threshold.
+CHECKPOINT_ENV = "SIAN_SIMILARITY_CHECKPOINT"
+THRESHOLD_ENV = "SIAN_SIMILARITY_THRESHOLD"
+
+# Lazily initialised module-level caches, filled by _get_model() and _get_device().
+_model: Optional[torch.nn.Module] = None
+_device: Optional[torch.device] = None
+
+
+def _load_similarity_class():
+    """Execute the model source file and return its ``SimilarityCNN`` class.
+
+    The file at ``MODEL_FILE`` is loaded under the module name
+    ``sian_similarity_model`` through ``importlib``.
+
+    Raises:
+        ImportError: If no import spec, or no loader for it, can be created.
+    """
+    module_spec = importlib.util.spec_from_file_location("sian_similarity_model", MODEL_FILE)
+    loader = None if module_spec is None else module_spec.loader
+    if loader is None:
+        raise ImportError(f"Cannot load similarity model from {MODEL_FILE}")
+
+    model_module = importlib.util.module_from_spec(module_spec)
+    loader.exec_module(model_module)
+    return model_module.SimilarityCNN
+
+
+def _get_checkpoint_path() -> Path:
+    """Return the checkpoint location, preferring the environment override.
+
+    A non-empty ``CHECKPOINT_ENV`` value has ``~`` expanded and is resolved to
+    an absolute path; otherwise ``DEFAULT_CHECKPOINT_PATH`` is returned as is.
+    """
+    override = os.environ.get(CHECKPOINT_ENV)
+    if not override:
+        return DEFAULT_CHECKPOINT_PATH
+    return Path(override).expanduser().resolve()
+
+
+def get_threshold() -> float:
+    """Return the similarity decision threshold.
+
+    The value is read from the ``THRESHOLD_ENV`` environment variable. An
+    unset, empty, or whitespace-only value falls back to
+    ``DEFAULT_THRESHOLD``, the same way ``_get_checkpoint_path`` ignores an
+    empty override.
+
+    Raises:
+        ValueError: If the variable holds text that cannot be parsed as a float.
+    """
+    raw = os.getenv(THRESHOLD_ENV)
+    if raw is None or not raw.strip():
+        return DEFAULT_THRESHOLD
+    try:
+        return float(raw)
+    except ValueError as err:
+        # Name the variable so a misconfiguration is easy to trace.
+        raise ValueError(f"{THRESHOLD_ENV} must be a number, got {raw!r}") from err
+
+
+def _get_device() -> torch.device:
+    """Return the module-wide torch device, choosing it on first use.
+
+    CUDA is picked when ``torch.cuda.is_available()`` reports it, CPU
+    otherwise. The choice is stored in ``_device`` and reused afterwards.
+    """
+    global _device
+
+    if _device is not None:
+        return _device
+
+    backend = "cuda" if torch.cuda.is_available() else "cpu"
+    _device = torch.device(backend)
+    return _device
+
+
+def _get_model() -> torch.nn.Module:
+    """Return the cached similarity network, building it on first use.
+
+    On the first call the checkpoint path is resolved, a ``SimilarityCNN`` is
+    constructed on the device from ``_get_device``, its weights are loaded and
+    it is switched to eval mode. Later calls return the same instance.
+
+    Raises:
+        FileNotFoundError: If the resolved checkpoint file does not exist.
+    """
+    global _model
+
+    if _model is None:
+        weights_path = _get_checkpoint_path()
+        if not weights_path.exists():
+            raise FileNotFoundError(
+                f"SiaN similarity checkpoint not found: {weights_path}. "
+                f"Set {CHECKPOINT_ENV} to another .pt file if needed."
+            )
+
+        similarity_cls = _load_similarity_class()
+        target_device = _get_device()
+
+        network = similarity_cls(
+            input_channels=3,
+            backbone_name="resnet18",
+            pretrained=False,
+            dropout_rate=0.3,
+            use_batch_norm=True,
+        ).to(target_device)
+
+        # NOTE(review): torch.load unpickles the file, and the path can come
+        # from an environment variable. Only point it at trusted checkpoints,
+        # or consider weights_only=True if the checkpoint format allows it.
+        payload = torch.load(weights_path, map_location=target_device)
+        # Accept either a training checkpoint dict or a bare state dict.
+        network.load_state_dict(payload.get("model_state_dict", payload))
+        network.eval()
+
+        _model = network
+
+    return _model
+
+
+def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor:
+    """Convert a chunk's image into a single-item batch tensor for the model.
+
+    The image is converted to RGB, resized to ``IMAGE_SIZE`` with bilinear
+    filtering, divided by 255 as float32, permuted from channels-last to
+    channels-first, given a leading batch axis, and moved to the device from
+    ``_get_device``.
+    """
+    rgb_image = chunk.image.convert("RGB")
+    resized = rgb_image.resize(IMAGE_SIZE, Image.BILINEAR)
+    pixels = np.asarray(resized, dtype=np.float32) / 255.0
+    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
+    return channels_first.unsqueeze(0).to(_get_device())
+
+
+def get_similarity_score(chunk1: VisionChunk, chunk2: VisionChunk) -> float:
+    """Return the model's similarity score for a pair of chunks.
+
+    Returns ``0.0`` without touching the model when either chunk is ``None``.
+    Otherwise both images are run through the network under
+    ``torch.inference_mode`` and the output is reduced to a Python float.
+    """
+    if chunk1 is None or chunk2 is None:
+        return 0.0
+
+    network = _get_model()
+    first = _chunk_to_tensor(chunk1)
+    second = _chunk_to_tensor(chunk2)
+
+    with torch.inference_mode():
+        output = network(first, second)
+
+    score = output.squeeze().item()
+    return float(score)
+
+
+def get_similarity_scores(chunk: VisionChunk, candidates: Sequence[VisionChunk]) -> list[float]:
+    """Score one chunk against every candidate in a single batched forward pass.
+
+    Args:
+        chunk: Query chunk compared with each candidate.
+        candidates: Chunks to compare against; their order defines the order
+            of the returned scores.
+
+    Returns:
+        One float per candidate, or an empty list when ``chunk`` is ``None``
+        or ``candidates`` is empty.
+
+    Raises:
+        ValueError: If the model output does not hold exactly one score per
+            candidate.
+    """
+    if chunk is None or not candidates:
+        return []
+
+    model = _get_model()
+    query = _chunk_to_tensor(chunk)
+    candidate_batch = torch.cat([_chunk_to_tensor(candidate) for candidate in candidates], dim=0)
+    num_candidates = candidate_batch.shape[0]
+    # expand() repeats the single query along the batch axis as a view.
+    query_batch = query.expand(num_candidates, -1, -1, -1)
+
+    with torch.inference_mode():
+        similarities = model(query_batch, candidate_batch)
+
+    # Flatten rather than squeeze(1) so both (N, 1) and (N,) outputs work,
+    # consistent with the shape-agnostic squeeze() in get_similarity_score.
+    flat = similarities.reshape(-1)
+    if flat.numel() != num_candidates:
+        raise ValueError(
+            f"Expected {num_candidates} similarity scores, "
+            f"got output of shape {tuple(similarities.shape)}"
+        )
+
+    return [float(score) for score in flat.cpu().tolist()]
+
+
+def is_similar(
+    chunk1: VisionChunk,
+    chunk2: VisionChunk,
+    threshold: Optional[float] = None,
+) -> bool:
+    """Tell whether two chunks score at or above the similarity threshold.
+
+    Args:
+        chunk1: First chunk of the pair.
+        chunk2: Second chunk of the pair.
+        threshold: Minimum score counted as similar; when ``None`` the value
+            from ``get_threshold`` is used.
+    """
+    limit = get_threshold() if threshold is None else threshold
+    score = get_similarity_score(chunk1, chunk2)
+    return score >= limit