diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 0000000..79ddb0c --- /dev/null +++ b/.vscode/c_cpp_properties.json @@ -0,0 +1,21 @@ +{ + "configurations": [ + { + "name": "Win32", + "includePath": [ + "${workspaceFolder}/**" + ], + "defines": [ + "_DEBUG", + "UNICODE", + "_UNICODE" + ], + "windowsSdkVersion": "10.0.26100.0", + "compilerPath": "cl.exe", + "cStandard": "c17", + "cppStandard": "c++17", + "intelliSenseMode": "windows-msvc-x64" + } + ], + "version": 4 +} \ No newline at end of file diff --git a/autopilot.py b/autopilot.py index 1383e99..8ff8752 100644 --- a/autopilot.py +++ b/autopilot.py @@ -8,6 +8,7 @@ import cv2 import numpy as np from PIL import Image +import sian_similarity from timer import Timer from vision_chunk import VisionChunk @@ -58,8 +59,16 @@ class AutoPilot(Pilot): # Положение на основе ориентира reserved_pos: Position | None proccessing_time: float + use_sian_similarity: bool - def __init__(self, points = [], chunks = [], viz_manager=None, pixel_ratio: float = 1.): + def __init__( + self, + points = [], + chunks = [], + viz_manager=None, + pixel_ratio: float = 1., + use_sian_similarity: bool = False, + ): self.prev_chunk = None self.pos = Position(0, 0, 1, 0, 0, 0) self.chunks = chunks @@ -67,6 +76,7 @@ class AutoPilot(Pilot): self.vis_manager = viz_manager # Менеджер визуализации self.reserved_pos = None self.pixel_ratio = pixel_ratio + self.use_sian_similarity = use_sian_similarity # Пороговые значения качества сопоставления/гомографии self.min_inliers: int = 12 @@ -153,17 +163,33 @@ class AutoPilot(Pilot): landmark_timer = Timer() landmark_timer.start() - cur_pos = np.array([self.pos.x, self.pos.y]) - closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin() - current_chunk = self.prev_chunk - landmark_chunk = self.chunks[closest_chunk_idx] + if current_chunk is None or not self.chunks: + return None + + if self.use_sian_similarity: + similarity_scores = sian_similarity.get_similarity_scores(current_chunk, self.chunks) + best_chunk_idx = int(np.argmax(similarity_scores)) + best_similarity_score = similarity_scores[best_chunk_idx] + + print(f"[LANDMARK]: best similarity={best_similarity_score:.4f}, idx={best_chunk_idx}") + + if best_similarity_score < sian_similarity.get_threshold(): + print("[LANDMARK]: not similar") + return None + + print("[LANDMARK]: similar") + landmark_chunk = self.chunks[best_chunk_idx] + else: + cur_pos = np.array([self.pos.x, self.pos.y]) + closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin() + landmark_chunk = self.chunks[closest_chunk_idx] if constants.DEBUG_FPS: - print(f"[LANDMARK]: Closest chunk finding: {landmark_timer.loop() * 1000:.2f} ms") + print(f"[LANDMARK]: Landmark chunk finding: {landmark_timer.loop() * 1000:.2f} ms") # Краевой случай: отсутствие чанков - if current_chunk is None or landmark_chunk is None: + if landmark_chunk is None: return None landmark_timer.start() @@ -312,10 +338,10 @@ class AutoPilot(Pilot): # Пытаемся найти ориентир на картинке: self.prev_chunk = current_chunk # Для улучшения среднего FPS - # if self.frame_count % 5 == 0: - # pos_by_chunk = self.get_position_by_chunk() - # if pos_by_chunk is not None: - # self.pos = pos_by_chunk + if self.frame_count % 5 == 0: + pos_by_chunk = self.get_position_by_chunk() + if pos_by_chunk is not None: + self.pos = pos_by_chunk command = self.make_command() self.timer.reset() diff --git a/dissertation/_media/gan_patchgan_discriminator.svg b/dissertation/_media/gan_patchgan_discriminator.svg new file mode 100644 index 0000000..c89e7d3 --- /dev/null +++ b/dissertation/_media/gan_patchgan_discriminator.svg @@ -0,0 +1,118 @@ + + + + + + + + + + Дискриминатор PatchGAN + Проверяет пару изображений локальными патчами: настоящая ли это пара Google + Яндекс или результат генератора + + Входные пары + Объединение + PatchGAN + Оценка + + + + Real pair + Google + настоящий Яндекс + целевая метка: 1 + + + + + + Fake pair + Google + Generated Яндекс + целевая метка: 0 + + + + + + + + + Concat + 6 каналов + RGB Google + RGB Yandex + B x 6 x 256 x 256 + + + + + + + Сверточные блоки + Conv + BatchNorm + LeakyReLU + + + + + + + + + + + + 64 -> 128 -> 256 -> 512 + + + + + + + Patch map + оценка real/fake + + + + + + + + + + + + + каждая ячейка + соответствует локальной + области изображения + + + + + + + Функция потерь + реальная пара -> 1, сгенерированная пара -> 0 + + + Главная идея PatchGAN: проверять не всю карту одним числом, а локальные признаки стиля и структуры. + diff --git a/dissertation/_media/gan_pipeline.svg b/dissertation/_media/gan_pipeline.svg new file mode 100644 index 0000000..912246f --- /dev/null +++ b/dissertation/_media/gan_pipeline.svg @@ -0,0 +1,119 @@ + + + + + + + + + + Применение GAN для приведения карт к единому домену + Нейросеть меняет стиль карты, а дальнейшее сопоставление выполняется классическими геометрическими методами + + Исходные данные + Доменная адаптация + Сопоставление + Результат + + + + Google Maps + фрагмент области полета + + + + + + + + + Яндекс.Карты + эталон с ориентирами + + + + + + + + + + + GAN + Google -> Yandex + + сохраняет геометрию + меняет визуальный стиль + + + + + + + Сгенерированный кадр + карта в целевом домене + + + + + + + + + + + + + Ключевые точки + ORB / SIFT / AKAZE + + + + + + + + + + + + RANSAC + отбор совпадений + + + + + + + + + + Гомография H + связь координат + + + коррекция позиции БПЛА + + + Итог: изображения становятся визуально ближе, поэтому классические методы получают больше устойчивых совпадений. + diff --git a/dissertation/_media/gan_unet_generator.svg b/dissertation/_media/gan_unet_generator.svg new file mode 100644 index 0000000..765567e --- /dev/null +++ b/dissertation/_media/gan_unet_generator.svg @@ -0,0 +1,104 @@ + + + + + + + + + + Архитектура генератора U-Net: Google Maps -> Яндекс.Карты + Сжатие изображения до признакового представления и восстановление в целевом стиле с сохранением геометрии через skip-соединения + + + + Input + 3 x 256 x 256 + Google RGB + + + Down 1 + 64 + 128 x 128 + + + Down 2 + 128 + 64 x 64 + + + Down 3 + 256 + 32 x 32 + + + Down 4 + 512 + 16 x 16 + + + + Bottleneck + 512 + + + + Up 4 + 512 + 16 x 16 + + + Up 5 + 256 + 32 x 32 + + + Up 6 + 128 + 64 x 64 + + + Up 7 + 64 + 128 x 128 + + + + Final Conv + Tanh + 3 x 256 x 256 + Generated Yandex RGB + + + + + + + + + + + + + + + + + + Down-блок: Conv2d, BatchNorm, LeakyReLU. Up-блок: Upsample, Conv2d, BatchNorm, ReLU, затем конкатенация со skip-признаками. + Назначение skip-соединений: сохранить дороги, перекрестки и контуры объектов при изменении визуального стиля карты. + diff --git a/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md b/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md index 1d864b7..8648653 100644 --- a/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md +++ b/dissertation/chapter_2/2.3_deep_learning/2.3.3_additional.md @@ -1,3 +1,39 @@ -2.3.3 Применение архитектуры сиамских близнецов для вычисления матрицы гомографии между двумя кадрами +2.3.3 Генеративно-состязательная сеть для приведения карт к единому домену +При сопоставлении кадров БПЛА с эталонной картой возникает не только геометрическое, но и доменное различие изображений. Даже если два фрагмента описывают один и тот же участок местности, снимки из разных источников могут иметь разные цвета, толщину дорог, подписи, условные обозначения, контраст зданий и набор отображаемых ориентиров. Например, один и тот же район на карте Google и на карте Яндекс визуально отличается настолько, что классические алгоритмы поиска ключевых точек могут находить мало устойчивых совпадений или формировать большое число ложных соответствий. +Одним из способов уменьшить этот доменный разрыв является предварительное преобразование изображения из одного картографического домена в другой. В данной работе рассматривается генеративно-состязательная сеть (Generative Adversarial Network, GAN), которая переводит фрагмент карты Google в визуальный стиль карты Яндекс. После такого преобразования исходный фрагмент Google становится ближе к эталонному домену Яндекс, а значит для дальнейшей локализации можно применять классические методы выделения и сопоставления ключевых точек: ORB, SIFT, AKAZE, BRISK и последующую оценку матрицы гомографии при помощи RANSAC. + +В отличие от модели сиамских близнецов, которая напрямую оценивает схожесть пары изображений, GAN решает вспомогательную задачу нормализации домена. То есть нейросеть не заменяет классический алгоритм сопоставления, а подготавливает данные так, чтобы классический алгоритм работал в более благоприятных условиях. + +![Схема применения GAN в задаче навигации БПЛА](../../_media/gan_pipeline.svg) + +Рисунок 8 – Применение GAN для приведения картографических изображений к единому домену + +Архитектура модели построена по принципу pix2pix, так как для обучения доступны парные изображения одного и того же участка местности в двух доменах: Google Maps и Яндекс.Карты. На вход генератора подается изображение Google размером \left(B,3,256,256\right), где B – размер пакета данных. Генератор формирует изображение \hat{Y}, визуально соответствующее стилю Яндекс.Карт. Дискриминатор получает на вход пару изображений и должен определить, является ли пара реальной \left(G,Y\right) или сгенерированной \left(G,\hat{Y}\right), где G – исходный фрагмент Google, Y – настоящий фрагмент Яндекс, \hat{Y} – результат работы генератора. + +Генератор реализован в виде U-Net. Энкодер последовательно уменьшает пространственное разрешение изображения и извлекает признаки высокого уровня: структуру дорог, контуры кварталов, границы зданий, водные объекты и другие устойчивые элементы карты. Декодер восстанавливает изображение в целевом стиле. Между симметричными уровнями энкодера и декодера используются skip-соединения, которые передают локальную геометрию напрямую из ранних слоев в поздние. Это важно для навигационной задачи: модель должна изменить стиль карты, но не должна смещать дороги, перекрестки и контуры объектов, так как именно они затем используются как ориентиры. + +![Архитектура генератора U-Net для преобразования Google в Яндекс](../../_media/gan_unet_generator.svg) + +Рисунок 9 – Архитектура генератора U-Net + +Дискриминатор реализован как PatchGAN. В отличие от обычного дискриминатора, который выдает одно значение для всего изображения, PatchGAN оценивает реалистичность локальных областей. На вход дискриминатора подается конкатенация исходного изображения Google и изображения Яндекс по каналам, поэтому входной тензор имеет размер \left(B,6,256,256\right). Далее изображение проходит через сверточные блоки с постепенным уменьшением разрешения, а выходом является карта оценок для отдельных фрагментов. Такой подход подходит для картографических изображений, потому что локальные признаки – ширина дорог, стиль подписей, границы объектов, цветовые переходы – важнее глобальной художественной реалистичности. + +![Архитектура дискриминатора PatchGAN](../../_media/gan_patchgan_discriminator.svg) + +Рисунок 10 – Архитектура дискриминатора PatchGAN + +Обучение модели является состязательным. Генератор стремится сформировать такое изображение \hat{Y}, чтобы дискриминатор считал пару \left(G,\hat{Y}\right) реальной. Дискриминатор, наоборот, учится отличать настоящие пары \left(G,Y\right) от сгенерированных. Для сохранения геометрии карты одной только состязательной функции потерь недостаточно, поэтому итоговая функция потерь генератора включает несколько компонентов: + +L_G = \lambda_{GAN}L_{GAN}\left(D\left(G,\hat{Y}\right),1\right)+\lambda_{L1}\left\|\hat{Y}-Y\right\|_1+\lambda_{SSIM}L_{SSIM}\left(\hat{Y},Y\right)+\lambda_{edge}L_{edge}\left(\hat{Y},Y\right) (1) + +где L_{GAN} – состязательная функция потерь, L1 – попиксельная ошибка между сгенерированным и настоящим изображением Яндекс, L_{SSIM} – структурная ошибка, сохраняющая сходство локальной структуры, L_{edge} – ошибка по картам границ, вычисленным оператором Собеля. Коэффициенты \lambda_{GAN}, \lambda_{L1}, \lambda_{SSIM} и \lambda_{edge} задают вклад каждого компонента. В реализованной модели используются значения \lambda_{GAN}=0.5, \lambda_{L1}=150, \lambda_{SSIM}=25, \lambda_{edge}=20. Усиленные L1, SSIM и edge-компоненты делают модель менее «творческой», но лучше сохраняют контуры дорог и объектов, что важнее для последующего поиска ключевых точек. + +Функция потерь дискриминатора имеет вид: + +L_D = \frac{1}{2}\left(L_{GAN}\left(D\left(G,Y\right),1\right)+L_{GAN}\left(D\left(G,\hat{Y}\right),0\right)\right) (2) + +После обучения GAN может использоваться в навигационном пайплайне следующим образом. Сначала для области предполагаемого положения БПЛА загружается или выбирается фрагмент Google Maps. Затем генератор переводит этот фрагмент в стиль Яндекс.Карт. На полученном изображении и на эталонном фрагменте Яндекс.Карт выделяются ключевые точки и дескрипторы. Далее дескрипторы сопоставляются, ложные соответствия отбрасываются при помощи RANSAC, а по оставшимся точкам оценивается матрица гомографии. Полученная матрица позволяет связать координаты текущего изображения с координатами эталонной карты и уточнить положение БПЛА. + +Таким образом, GAN выступает как промежуточный модуль доменной адаптации. Его применение особенно полезно в ситуации, когда источник доступной карты и источник эталонных ориентиров различаются. В рассматриваемой задаче это позволяет перевести изображения Google в представление, близкое к Яндекс.Картам, где ориентиры визуально согласованы между собой. Благодаря этому классические методы компьютерного зрения получают более похожие изображения и могут устойчивее находить ключевые точки, не требуя полного отказа от интерпретируемого геометрического пайплайна. diff --git a/dissertation/chapter_2/2.3_deep_learning/readme.md b/dissertation/chapter_2/2.3_deep_learning/readme.md index 01c5cbe..2ef72d5 100644 --- a/dissertation/chapter_2/2.3_deep_learning/readme.md +++ b/dissertation/chapter_2/2.3_deep_learning/readme.md @@ -6,3 +6,4 @@ |-----------|----------|------| | 2.3.1 | Применение архитектуры сиамских близнецов для сопоставления кадров из различных доменов | 2.3.1_siamese_match.md | | 2.3.2 | Применение архитектуры сиамских близнецов для вычисления матрицы гомографии | 2.3.2_siamese_homography.md | +| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3.3_additional.md | diff --git a/dissertation/chapter_2/readme.md b/dissertation/chapter_2/readme.md index b40136d..4804abe 100644 --- a/dissertation/chapter_2/readme.md +++ b/dissertation/chapter_2/readme.md @@ -9,5 +9,6 @@ | 2.3 | Методы глубокого обучения | 2.3_deep_learning/ | | 2.3.1 | Сиамские близнецы для сопоставления кадров из различных доменов | 2.3_deep_learning/2.3.1_siamese_match.md | | 2.3.2 | Сиамские близнецы для вычисления матрицы гомографии | 2.3_deep_learning/2.3.2_siamese_homography.md | +| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3_deep_learning/2.3.3_additional.md | | 2.4 | Датасет | 2.4_dataset/ | | 2.5 | Обучение моделей глубокого обучения | 2.5_training/ | diff --git a/docs/google-city-1 --use-sian-similarity.txt b/docs/google-city-1 --use-sian-similarity.txt new file mode 100644 index 0000000..5c93abf --- /dev/null +++ b/docs/google-city-1 --use-sian-similarity.txt @@ -0,0 +1,41 @@ +google-city-1 --use-sian-similarity +MSE: 1006.2278793057307 +RMSE: 31.721095178220608 +Average FPS: 2.1113510203551025 + +google-city-2 --use-sian-similarity +MSE: 0.018811735940268307 +RMSE: 0.13715588190182842 +Average FPS: 4.383734633135686 + +google-city-3 --use-sian-similarity +MSE: 0.06464743825147634 +RMSE: 0.2542586050686905 +Average FPS: 4.047345287474747 + +google-city-4 --use-sian-similarity +MSE: 0.3089718382086463 +RMSE: 0.5558523528857697 +Average FPS: 2.9425785696773636 + + + +google-city-1 +MSE: 716.9078171684257 +RMSE: 26.775134307196776 +Average FPS: 38.766897877580504 + +google-city-2 +MSE: 0.06237516673056363 +RMSE: 0.2497502086697099 +Average FPS: 25.821371544698586 + +google-city-3 +MSE: 0.07598762118336233 +RMSE: 0.2756585227838282 +Average FPS: 26.37680633915316 + +google-city-3 +MSE: 0.7199724273833397 +RMSE: 0.8485118899481254 +Average FPS: 29.534802136097262 \ No newline at end of file diff --git a/main.py b/main.py index 31d3d90..807491c 100644 --- a/main.py +++ b/main.py @@ -120,7 +120,7 @@ def build(name: str, map_name: str, lat: float, lon: float): sleep(15) online_map.destroy() -def run(name: str, map_name: str, ref_min_distance: float): +def run(name: str, map_name: str, ref_min_distance: float, use_sian_similarity: bool = False): dir = Path('trajectories') assert dir.exists() dir /= name @@ -158,7 +158,13 @@ def run(name: str, map_name: str, ref_min_distance: float): print("READ POINTS:", points) vis_manager = VisualizationManager() - pilot = autopilot.AutoPilot(points, chunks, vis_manager, online_map.pixel_ratio) + pilot = autopilot.AutoPilot( + points, + chunks, + vis_manager, + online_map.pixel_ratio, + use_sian_similarity=use_sian_similarity, + ) simulator = Simulator(online_map) pilot.target_idx = 0 @@ -300,6 +306,12 @@ def parse_args(): help='Включить отладку эталонов' ) + parser.add_argument( + '--use-sian-similarity', + action='store_true', + help='Выбирать ориентир через SiaN similarity вместо ближайшего по текущей позиции' + ) + # Парсим аргументы args = parser.parse_args() @@ -327,4 +339,4 @@ if __name__ == "__main__": build(name, ref, lat, lon) if mode == 'run' or mode == 'standalone': - run(name, sim, rmd) + run(name, sim, rmd, args.use_sian_similarity) diff --git a/map.jpg b/map.jpg index 071d7b8..d8f8ff8 100644 Binary files a/map.jpg and b/map.jpg differ diff --git a/models/GAN/README.md b/models/GAN/README.md index b7c7824..dae5679 100644 --- a/models/GAN/README.md +++ b/models/GAN/README.md @@ -1,254 +1,111 @@ -# GAN Trainer для преобразования изображений Yandex → Google - -Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps. - -## Структура проекта - -``` -autopilot/models/GAN/ -├── gan.py # Основная реализация GAN модели -├── trainer.py # Тренер для обучения GAN -├── test_trainer.py # Тесты для тренера -├── train_example.py # Пример использования тренера -└── README.md # Этот файл -``` - -## Модель GAN - -Модель состоит из двух основных компонентов: - -### 1. Генератор (GeneratorUNet) -- Архитектура U-Net для преобразования изображений -- Принимает изображение Yandex (3 канала RGB) -- Возвращает изображение в стиле Google (3 канала RGB) -- Использует skip connections для сохранения деталей - -### 2. Дискриминатор (DiscriminatorPatchGAN) -- PatchGAN архитектура -- Принимает пару изображений (Yandex + Google) -- Возвращает вероятность того, что пара реальная -- Работает с патчами изображения 41x41 - -### Функция потерь (GANLoss) -Поддерживает три режима: -- `vanilla`: Бинарная кросс-энтропия -- `lsgan`: Least Squares GAN (более стабильный) -- `wgangp`: Wasserstein GAN with Gradient Penalty - -## Тренер (GANTrainer) - -### Основные возможности - -1. **Обучение с чередованием**: - - Обучение генератора и дискриминатора поочередно - - Поддержка L1 потерь для сохранения структуры - -2. **Валидация и мониторинг**: - - Отдельные потери для генератора и дискриминатора - - Логирование в TensorBoard - - Ранняя остановка - -3. **Сохранение и загрузка**: - - Чекпоинты каждой эпохи - - Лучшая модель - - Финальная модель - - История обучения - -4. **Оценка модели**: - - Метрики на тестовом наборе - - Генерация примеров - -### Быстрый старт - -```python -import torch -from torch.utils.data import DataLoader -from models.GAN.gan import create_image_gan -from models.GAN.trainer import GANTrainer - -# Конфигурация -config = { - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - "batch_size": 4, - "output_dir": "runs/gan_training", - "gan_mode": "vanilla", - "lambda_L1": 100.0, - "early_stopping_patience": 20, -} - -# Устройство -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# Создание модели -model = create_image_gan( - input_channels=3, - output_channels=3, - gan_mode=config["gan_mode"], - lambda_L1=config["lambda_L1"], - use_cuda=(device.type == "cuda"), -) - -# Создание даталоадеров (замените на свои данные) -train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True) -val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False) - -# Создание тренера -trainer = GANTrainer( - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - config=config, -) - -# Обучение -trainer.train(num_epochs=100) - -# Оценка -metrics = trainer.evaluate(test_loader) -``` - -### Конфигурация обучения - -#### Базовая конфигурация -```python -config = { - # Параметры оптимизатора - "learning_rate": 2e-4, # Learning rate - "beta1": 0.5, # Adam beta1 - "beta2": 0.999, # Adam beta2 - - # Параметры обучения - "batch_size": 4, # Размер батча - "epochs": 100, # Количество эпох - - # Параметры GAN - "gan_mode": "vanilla", # Режим GAN - "lambda_L1": 100.0, # Вес L1 потерь - - # Регуляризация - "grad_clip": 1.0, # Gradient clipping - - # Ранняя остановка - "early_stopping_patience": 20, - - # Выходные данные - "output_dir": "runs/gan", -} -``` - -#### Расширенная конфигурация -```python -config = { - "learning_rate": 2e-4, - "beta1": 0.5, - "beta2": 0.999, - "batch_size": 8, - "epochs": 200, - "gan_mode": "lsgan", # Более стабильный LSGAN - "lambda_L1": 100.0, - "grad_clip": 1.0, - "weight_decay": 1e-4, # Weight decay - "early_stopping_patience": 30, - "early_stopping_min_delta": 1e-4, - "output_dir": "runs/gan_advanced", -} -``` - -### Методы тренера - -#### Основные методы -- `train_epoch()`: Обучение на одной эпохе -- `validate()`: Валидация модели -- `train(num_epochs)`: Полное обучение -- `evaluate(test_loader)`: Оценка на тестовых данных - -#### Управление чекпоинтами -- `save_checkpoint(is_best=False)`: Сохранение чекпоинта -- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта - -### Выходные файлы - -После обучения создаются следующие файлы: - -``` -runs/gan_training/ -├── config.json # Конфигурация обучения -├── training_history.json # История потерь -├── model_best.pth # Лучшая модель -├── model_final.pth # Финальная модель -├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи -├── checkpoint_epoch_2.pth -├── ... -└── tensorboard/ # Логи TensorBoard - ├── events.out.tfevents... - └── ... -``` - -### TensorBoard - -Для визуализации обучения используйте TensorBoard: - -```bash -tensorboard --logdir runs/gan_training/tensorboard -``` - -Доступные метрики: -- `train/batch_g_loss`: Потери генератора на батче -- `train/batch_d_loss`: Потери дискриминатора на батче -- `train/batch_g_l1_loss`: L1 потери генератора -- `train/epoch_g_loss`: Потери генератора на эпохе -- `train/epoch_d_loss`: Потери дискриминатора на эпохе -- `val/epoch_g_loss`: Валидационные потери генератора -- `val/epoch_d_loss`: Валидационные потери дискриминатора - -### Тестирование - -Запустите тесты для проверки работоспособности: - -```bash -python models/GAN/test_trainer.py -``` - -### Пример использования - -Полный пример использования смотрите в `train_example.py`. - -### Советы по обучению - -1. **Начальные значения**: - - Используйте `gan_mode="lsgan"` для более стабильного обучения - - Начните с `lambda_L1=100.0` и регулируйте по необходимости - - Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU - -2. **Мониторинг**: - - Следите за балансом потерь генератора и дискриминатора - - Если потери дискриминатора близки к 0, генератор не обучается - - Если потери генератора слишком высоки, уменьшите `lambda_L1` - -3. **Визуализация**: - - Регулярно генерируйте примеры для визуальной оценки - - Используйте TensorBoard для отслеживания прогресса - -### Устранение проблем - -#### Высокие потери генератора -- Уменьшите `lambda_L1` -- Увеличьте learning rate -- Проверьте качество данных - -#### Дискриминатор слишком сильный -- Уменьшите learning rate дискриминатора -- Добавьте dropout в дискриминатор -- Обучайте генератор чаще, чем дискриминатор - -#### Недостаток памяти GPU -- Уменьшите `batch_size` -- Уменьшите размер изображений -- Используйте gradient accumulation - -### Лицензия - -Этот проект является частью Autopilot системы. \ No newline at end of file +# GAN для преобразования Google -> Yandex + +Модуль реализует pix2pix-подобный GAN для парных изображений карт. Генератор получает изображение Google из пары и пытается сгенерировать изображение в стиле Yandex. Сгенерированная картинка сравнивается со вторым изображением пары, то есть с оригинальным `*_yandex.png`. + +## Структура + +```text +models/GAN/ +├── build.py # генератор notebook.gen.ipynb по схеме +├── _schema.py # структура генерируемого ноутбука +├── _schema.md # описание формата схемы +├── notebook.gen.ipynb +├── src/ +│ ├── config.py +│ ├── dataloader.py +│ ├── model.py +│ ├── trainer.py +│ ├── analyze.py +│ └── main.py +└── README.md +``` + +## Архитектура + +`GeneratorUNet` принимает `google_img` с 3 RGB-каналами и возвращает `fake_yandex`. + +`DiscriminatorPatchGAN` получает пару `(google_img, yandex_img)` и отличает настоящую пару от `(google_img, fake_yandex)`. + +`ImageGAN.generator_step()` считает: + +- adversarial loss: дискриминатор должен принять `(google_img, fake_yandex)` за реальную пару; +- L1 loss: `fake_yandex` сравнивается с оригинальным `yandex_img` из той же пары. +- SSIM loss: штрафует структурные отличия, чтобы карта не расплывалась; +- Sobel edge loss: сохраняет контуры дорог и объектов для последующего поиска ключевых точек. + +Итоговая функция потерь генератора: + +```text +G_loss = + lambda_GAN * GAN_loss(D(google_img, fake_yandex), real) + + lambda_L1 * L1(fake_yandex, yandex_img) + + lambda_SSIM * SSIMLoss(fake_yandex, yandex_img) + + lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img) +``` + +По умолчанию используется `lsgan`, усиленная реконструкция и более медленный +дискриминатор. Это менее “креативно”, зато обычно даёт более чистые контуры. + +## Запуск + +Из папки `models/GAN`: + +```bash +python src/main.py +``` + +`src/main.py` специально написан без функции `main()`: при генерации ноутбука +все переменные остаются доступны после запуска ячейки, как в SiaN. + +Чтобы пересобрать ноутбук из схемы: + +```bash +python build.py +``` + +По умолчанию датасет берется из: + +```text +C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images +``` + +Путь, размер изображений, batch size и число эпох меняются в `src/config.py`. + +Если CUDA видна, но установленный PyTorch не поддерживает архитектуру GPU +например Tesla P100 `sm_60`, код автоматически переключится на CPU. Это +управляется параметром `prefer_cuda` в `src/config.py`. + +## Ожидаемый формат данных + +В директории датасета должны лежать пары: + +```text +0000_google.png +0000_yandex.png +0001_google.png +0001_yandex.png +... +``` + +Даталоадер возвращает: + +- `google_img`: вход генератора; +- `yandex_img`: целевое изображение для сравнения; +- `idx`: номер пары. + +## Чекпоинты + +Тренер сохраняет чекпоинты в: + +```text +models/GAN/runs/checkpoints +``` + +Сохраняются `best.pth`, периодические `epoch_N.pth` и `final.pth`. + +После обучения `src/main.py` также строит: + +- графики `G/D/L1/SSIM/edge loss`; +- сетку `Google input -> Generated Yandex -> Yandex target`. + +Файлы сохраняются в `models/GAN/runs/images`. diff --git a/models/GAN/_schema.md b/models/GAN/_schema.md new file mode 100644 index 0000000..407668c --- /dev/null +++ b/models/GAN/_schema.md @@ -0,0 +1,32 @@ +# GAN Schema + +Notebook structure definition for the Google -> Yandex GAN model. + +--- + +## Format + +```text +# === IMPORTS === + + +# markdown +"""Description""" + +# code: ./src/file.py + +# # shell: + +``` + +## Directives + +| Directive | Description | +|----------------|------------------------------------| +| `# code:` | Include file from `src/` | +| `# markdown` | Add markdown cell | +| `# # shell:` | Add notebook shell command cell | + +--- + +`build.py` generates `notebook.gen.ipynb` from `_schema.py`. diff --git a/models/GAN/_schema.py b/models/GAN/_schema.py new file mode 100644 index 0000000..424259a --- /dev/null +++ b/models/GAN/_schema.py @@ -0,0 +1,121 @@ +# _schema.py + +# === IMPORTS === +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torch.utils.data import DataLoader, Dataset, Subset +from torchvision import transforms +from tqdm import tqdm + +# markdown +"""# Configuration + +Global settings for the Google -> Yandex GAN: +- Dataset path and image size +- Optimizer and training hyperparameters +- Device preference with safe CUDA compatibility checks +- GAN, L1, SSIM and edge reconstruction weights +- Output directories for checkpoints and generated samples""" +# code: ./src/config.py + +# markdown +"""## Dataset + +Google/Yandex paired image loader. + +**Direction:** +- `google_img` is the generator input +- `yandex_img` is the target image from the same pair + +**Returns:** +- Batch dict with `google_img`, `yandex_img`, `idx`""" +# code: ./src/dataloader.py + +# code: ./src/test_dataloader.py + +# markdown +"""## Model + +Pix2pix-style GAN for Google -> Yandex map translation. + +**Generator:** +- `GeneratorUNet` +- Input: Google image `(B, 3, H, W)` +- Output: generated Yandex image `(B, 3, H, W)` + +**Discriminator:** +- `DiscriminatorPatchGAN` +- Input pair: `(google_img, yandex_img)` +- Learns to distinguish real pairs from `(google_img, fake_yandex)` + +**Generator loss:** +- adversarial loss +- `lambda_L1 * L1(fake_yandex, yandex_img)` +- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)` +- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)` + +The generator uses bilinear upsampling followed by convolution to avoid +checkerboard artifacts from transposed convolutions.""" +# code: ./src/model.py + +# markdown +"""## Training + +`GANTrainer` trains discriminator and generator alternately. + +**Training step:** +1. Generate `fake_yandex = G(google_img)` +2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)` +3. Train generator against discriminator and paired Yandex target + +**Checkpoint saving:** +- `best.pth` +- `epoch_N.pth` +- `final.pth`""" +# code: ./src/trainer.py + +# markdown +"""## Analysis + +Visualization helpers for generated samples and collected training metrics. + +Training history plot contains: +1. Generator loss +2. Discriminator loss +3. L1 loss against the paired Yandex target +4. SSIM structure loss +5. Sobel edge loss +6. Best-checkpoint reconstruction score + +The sample grid contains: +1. Google input +2. Generated Yandex +3. Original Yandex target""" +# code: ./src/analyze.py + +# markdown +"""## Main Pipeline + +Executes the full GAN workflow: +1. Create config +2. Build paired data loaders +3. Initialize Google -> Yandex GAN +4. Train with validation +5. Save checkpoints in `runs/checkpoints/` +6. Show loss plots and generated sample grid + +This block is intentionally top-level, not wrapped in `main()`, so notebook +variables such as `model`, `trainer`, `train_loader`, `val_loader`, and +`training_analysis` remain available for debugging.""" +# code: ./src/main.py + +# # shell: +# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png diff --git a/models/GAN/build.py b/models/GAN/build.py new file mode 100644 index 0000000..cdf5fbd --- /dev/null +++ b/models/GAN/build.py @@ -0,0 +1,260 @@ +import json +import os +import re + + +def parse_schema(schema_path): + with open(schema_path, 'r', encoding='utf-8') as f: + content = f.read() + + imports = [] + items = [] + + in_imports_section = False + lines = content.split('\n') + + i = 0 + while i < len(lines): + line = lines[i] + stripped = line.strip() + + if stripped == '# === IMPORTS ===': + in_imports_section = True + i += 1 + continue + elif stripped.startswith('# ==='): + in_imports_section = False + i += 1 + continue + + if in_imports_section and stripped and not stripped.startswith('#'): + imports.append(stripped) + i += 1 + continue + + if in_imports_section and stripped.startswith('#'): + in_imports_section = False + + if stripped.startswith('# code:'): + match = re.search(r'# code:\s*(.+)', stripped) + if match: + items.append(('code', match.group(1).strip())) + i += 1 + continue + + if stripped.startswith('# inline:') or stripped.startswith('# # shell:'): + directive = 'inline' if stripped.startswith('# inline:') else 'shell' + i += 1 + code_lines = [] + while i < len(lines): + stripped = lines[i].strip() + if directive == 'shell': + if stripped.startswith('# ') or stripped.startswith('# # '): + line_content = lines[i].strip() + if line_content.startswith('# # '): + line_content = line_content[3:].lstrip() + else: + line_content = line_content[2:] + code_lines.append(line_content) + i += 1 + continue + if stripped.startswith('#'): + if stripped.startswith('# code:') or stripped.startswith('# inline:') or stripped.startswith('# # shell:') or stripped == '# markdown' or stripped.startswith('# ==='): + break + i += 1 + continue + if stripped == '': + i += 1 + continue + code_lines.append(lines[i]) + i += 1 + if code_lines: + items.append((directive, '\n'.join(code_lines).rstrip())) + continue + + if stripped == '# markdown': + i += 1 + if i < len(lines): + next_line = lines[i].strip() + if next_line.startswith('"""'): + if next_line.endswith('"""') and len(next_line) > 3: + md_content = next_line[3:-3].strip() + i += 1 + else: + end_idx = None + for j in range(i + 1, len(lines)): + if '"""' in lines[j]: + end_idx = j + break + if end_idx: + md_content = '\n'.join(lines[i:end_idx]) + md_content = md_content.strip('"""').strip() + i = end_idx + 1 + else: + md_content = "" + i += 1 + items.append(('markdown', md_content)) + continue + + i += 1 + + return imports, items + + +def strip_all_imports(content): + lines = content.split('\n') + result_lines = [] + skip_block = False + if_block_indent = 0 + + for line in lines: + stripped = line.strip() + + if stripped.startswith('if __name__') or stripped.startswith('if __name__ =='): + skip_block = True + if_block_indent = len(line) - len(line.lstrip()) + continue + + if skip_block: + current_indent = len(line) - len(line.lstrip()) + if line.strip() == '': + continue + if current_indent < if_block_indent: + skip_block = False + elif current_indent == if_block_indent and stripped.startswith('if '): + skip_block = True + if_block_indent = current_indent + continue + else: + continue + + if stripped.startswith('import ') or stripped.startswith('from '): + continue + + result_lines.append(line) + + while result_lines and result_lines[-1].strip() == '': + result_lines.pop() + + return '\n'.join(result_lines) + + +def read_src_file(ref, base_dir): + ref_path = ref.replace('./src/', '').replace('src/', '') + full_path = os.path.join(base_dir, 'src', ref_path.replace('./src/', '').lstrip('/')) + + if not full_path.endswith('.py'): + full_path += '.py' + + with open(full_path, 'r', encoding='utf-8') as f: + return f.read() + + +def build_notebook(schema_path, output_path=None): + schema_dir = os.path.dirname(os.path.abspath(schema_path)) + if output_path is None: + output_path = os.path.join(schema_dir, 'notebook.gen.ipynb') + + imports, items = parse_schema(schema_path) + + cells = [] + + if imports: + imports_cell = { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": [imp + "\n" for imp in imports] + } + cells.append(imports_cell) + + for item_type, item_content in items: + if item_type == 'markdown': + md_cell = { + "cell_type": "markdown", + "metadata": {}, + "source": item_content + "\n" + } + cells.append(md_cell) + elif item_type == 'inline': + lines = item_content.split('\n') + while lines and lines[-1].strip() == '': + lines.pop() + if lines: + lines.append('') + cell = { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": [line + "\n" for line in lines] + } + cells.append(cell) + elif item_type == 'shell': + lines = item_content.split('\n') + while lines and lines[-1].strip() == '': + lines.pop() + if lines: + lines.append('') + cell = { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": [line + "\n" for line in lines] + } + cells.append(cell) + elif item_type == 'code': + try: + content = read_src_file(item_content, schema_dir) + content = strip_all_imports(content) + + lines = content.split('\n') + while lines and lines[-1].strip() == '': + lines.pop() + if lines: + lines.append('') + + cell = { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": [line + "\n" for line in lines] + } + cells.append(cell) + except FileNotFoundError: + print(f"Warning: Could not find: {item_content}") + + notebook = { + "cells": cells, + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 + } + + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(notebook, f, indent=1, ensure_ascii=False) + + print(f"Notebook generated: {output_path}") + + +if __name__ == "__main__": + script_dir = os.path.dirname(os.path.abspath(__file__)) + schema_path = os.path.join(script_dir, '_schema.py') + + if os.path.exists(schema_path): + build_notebook(schema_path) + else: + print(f"Error: _schema.py not found at {schema_path}") diff --git a/models/GAN/main.py b/models/GAN/main.py deleted file mode 100644 index 5534ac9..0000000 --- a/models/GAN/main.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Main entry point for GAN training.""" - -from config import create_config -from dataloader import create_data_loaders -from model import create_gan -from trainer import create_trainer - - -def main(): - """Run training pipeline.""" - config = create_config() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Create components - model = create_gan(use_cuda=False) # Set to True to use GPU - train_loader, val_loader = create_data_loaders( - root_dir=config["data_dir"], - batch_size=config["batch_size"], - image_size=tuple(config["image_size"]), - num_workers=config["num_workers"], - ) - trainer = create_trainer(model, train_loader, val_loader, config) - - # Train - trainer.train(config["epochs"]) - - -if __name__ == "__main__": - import torch - main() \ No newline at end of file diff --git a/models/GAN/notebook.gen.ipynb b/models/GAN/notebook.gen.ipynb new file mode 100644 index 0000000..49dfcaa --- /dev/null +++ b/models/GAN/notebook.gen.ipynb @@ -0,0 +1,1064 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from pathlib import Path\n", + "from typing import Any, Dict, List, Tuple\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from PIL import Image\n", + "from torch.utils.data import DataLoader, Dataset, Subset\n", + "from torchvision import transforms\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "# Configuration\n\nGlobal settings for the Google -> Yandex GAN:\n- Dataset path and image size\n- Optimizer and training hyperparameters\n- Device preference with safe CUDA compatibility checks\n- GAN, L1, SSIM and edge reconstruction weights\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Configuration for GAN training.\"\"\"\n", + "\n", + "\n", + "def create_config():\n", + " \"\"\"Create default configuration dictionary.\"\"\"\n", + " return {\n", + " # Optimizer params\n", + " \"learning_rate\": 2e-4,\n", + " \"discriminator_lr_factor\": 0.5,\n", + " \"beta1\": 0.5,\n", + " \"beta2\": 0.999,\n", + " # Training params\n", + " \"batch_size\": 32,\n", + " \"epochs\": 100,\n", + " \"prefer_cuda\": True,\n", + " # GAN params\n", + " \"gan_mode\": \"lsgan\",\n", + " \"lambda_GAN\": 0.5,\n", + " \"lambda_L1\": 150.0,\n", + " \"lambda_SSIM\": 25.0,\n", + " \"lambda_edge\": 20.0,\n", + " \"discriminator_update_interval\": 1,\n", + " # Regularization\n", + " \"grad_clip\": 1.0,\n", + " # Early stopping\n", + " \"early_stopping_patience\": 25,\n", + " # Output\n", + " \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\GAN\\runs\",\n", + " # Logging\n", + " \"log_interval\": 10,\n", + " \"save_interval\": 5,\n", + " \"num_visual_samples\": 4,\n", + " # Data\n", + " \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", + " \"image_size\": [256, 256],\n", + " \"train_split\": 0.8,\n", + " \"num_workers\": 0,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Dataset\n\nGoogle/Yandex paired image loader.\n\n**Direction:**\n- `google_img` is the generator input\n- `yandex_img` is the target image from the same pair\n\n**Returns:**\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Data loader for Google-to-Yandex image translation.\"\"\"\n", + "\n", + "\n", + "\n", + "\n", + "class YaGoDataset(Dataset):\n", + " \"\"\"Dataset loading paired Google and Yandex map images.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " root_dir: str,\n", + " image_size: Tuple[int, int] = (256, 256),\n", + " augment: bool = False,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png\n", + " image_size: Target image size (height, width)\n", + " augment: Whether to apply augmentation (not implemented for simplicity)\n", + " \"\"\"\n", + " self.root_dir = root_dir\n", + " self.image_size = image_size\n", + " self.augment = augment\n", + "\n", + " # Discover image pairs\n", + " self.pairs = self._find_pairs()\n", + "\n", + " # Transform to tensor + normalization\n", + " self.transform = transforms.Compose(\n", + " [\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n", + " ]\n", + " )\n", + "\n", + " def _find_pairs(self) -> List[Dict]:\n", + " \"\"\"Find all matching Google-Yandex image pairs.\"\"\"\n", + " pairs = []\n", + " google_files = [f for f in os.listdir(self.root_dir) if f.endswith(\"_google.png\")]\n", + "\n", + " for google_file in sorted(google_files):\n", + " idx_str = google_file.split(\"_\")[0]\n", + " try:\n", + " idx = int(idx_str)\n", + " except ValueError:\n", + " continue\n", + "\n", + " yandex_file = f\"{idx:04d}_yandex.png\"\n", + " yandex_path = os.path.join(self.root_dir, yandex_file)\n", + "\n", + " if os.path.exists(yandex_path):\n", + " pairs.append(\n", + " {\n", + " \"idx\": idx,\n", + " \"google_path\": os.path.join(self.root_dir, google_file),\n", + " \"yandex_path\": yandex_path,\n", + " }\n", + " )\n", + "\n", + " return pairs\n", + "\n", + " def __len__(self) -> int:\n", + " return len(self.pairs)\n", + "\n", + " def __getitem__(self, idx: int) -> dict:\n", + " pair = self.pairs[idx]\n", + "\n", + " # Load images\n", + " google_img = Image.open(pair[\"google_path\"]).convert(\"RGB\")\n", + " yandex_img = Image.open(pair[\"yandex_path\"]).convert(\"RGB\")\n", + "\n", + " # Resize\n", + " google_img = google_img.resize((self.image_size[1], self.image_size[0]))\n", + " yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0]))\n", + "\n", + " # Apply transforms\n", + " google_tensor = self.transform(google_img)\n", + " yandex_tensor = self.transform(yandex_img)\n", + "\n", + " return {\n", + " \"google_img\": google_tensor,\n", + " \"yandex_img\": yandex_tensor,\n", + " \"idx\": torch.tensor(pair[\"idx\"], dtype=torch.long),\n", + " }\n", + "\n", + "\n", + "def create_data_loaders(\n", + " root_dir: str,\n", + " batch_size: int = 32,\n", + " train_split: float = 0.8,\n", + " num_workers: int = 0,\n", + " image_size: Tuple[int, int] = (256, 256),\n", + ") -> Tuple[DataLoader, DataLoader]:\n", + " \"\"\"\n", + " Create train and validation data loaders.\n", + "\n", + " Args:\n", + " root_dir: Directory with image pairs\n", + " batch_size: Batch size\n", + " train_split: Fraction for training (0.0-1.0)\n", + " num_workers: DataLoader workers\n", + " image_size: Target image size\n", + "\n", + " Returns:\n", + " (train_loader, val_loader)\n", + " \"\"\"\n", + " # Full dataset\n", + " dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)\n", + "\n", + " # Split\n", + " dataset_size = len(dataset)\n", + " train_size = int(train_split * dataset_size)\n", + " indices = torch.randperm(dataset_size).tolist()\n", + " train_indices = indices[:train_size]\n", + " val_indices = indices[train_size:]\n", + "\n", + " # Subsets\n", + "\n", + " train_dataset = Subset(dataset, train_indices)\n", + " val_dataset = Subset(dataset, val_indices)\n", + "\n", + " # DataLoaders\n", + " train_loader = DataLoader(\n", + " train_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " num_workers=num_workers,\n", + " pin_memory=True,\n", + " )\n", + "\n", + " val_loader = DataLoader(\n", + " val_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=False,\n", + " num_workers=num_workers,\n", + " pin_memory=True,\n", + " )\n", + "\n", + " return train_loader, val_loader\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "def get_dataset_info():\n", + " config = create_config()\n", + " dataset = YaGoDataset(\n", + " root_dir=config[\"data_dir\"],\n", + " image_size=tuple(config[\"image_size\"]),\n", + " )\n", + " sample = dataset[0] if len(dataset) else {}\n", + " return {\n", + " \"size\": len(dataset),\n", + " \"sample_keys\": list(sample.keys()),\n", + " \"google_shape\": tuple(sample[\"google_img\"].shape) if sample else None,\n", + " \"yandex_shape\": tuple(sample[\"yandex_img\"].shape) if sample else None,\n", + " }\n", + "\n", + "\n", + "def smoke_test_dataloader(batch_size=4):\n", + " config = create_config()\n", + " train_loader, val_loader = create_data_loaders(\n", + " root_dir=config[\"data_dir\"],\n", + " batch_size=batch_size,\n", + " train_split=config[\"train_split\"],\n", + " num_workers=config[\"num_workers\"],\n", + " image_size=tuple(config[\"image_size\"]),\n", + " )\n", + " batch = next(iter(train_loader))\n", + " return {\n", + " \"train_size\": len(train_loader.dataset),\n", + " \"val_size\": len(val_loader.dataset),\n", + " \"google_batch_shape\": tuple(batch[\"google_img\"].shape),\n", + " \"yandex_batch_shape\": tuple(batch[\"yandex_img\"].shape),\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Model\n\nPix2pix-style GAN for Google -> Yandex map translation.\n\n**Generator:**\n- `GeneratorUNet`\n- Input: Google image `(B, 3, H, W)`\n- Output: generated Yandex image `(B, 3, H, W)`\n\n**Discriminator:**\n- `DiscriminatorPatchGAN`\n- Input pair: `(google_img, yandex_img)`\n- Learns to distinguish real pairs from `(google_img, fake_yandex)`\n\n**Generator loss:**\n- adversarial loss\n- `lambda_L1 * L1(fake_yandex, yandex_img)`\n- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`\n- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`\n\nThe generator uses bilinear upsampling followed by convolution to avoid\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"GAN model for image translation Google -> Yandex.\"\"\"\n", + "\n", + "\n", + "\n", + "def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:\n", + " \"\"\"Return CUDA only when the current PyTorch build supports the GPU arch.\"\"\"\n", + " if not prefer_cuda or not torch.cuda.is_available():\n", + " return torch.device(\"cpu\")\n", + "\n", + " try:\n", + " major, minor = torch.cuda.get_device_capability()\n", + " arch = f\"sm_{major}{minor}\"\n", + " supported_arches = set(torch.cuda.get_arch_list())\n", + " gpu_name = torch.cuda.get_device_name()\n", + " except Exception as exc:\n", + " if verbose:\n", + " print(f\"CUDA is visible but cannot be inspected ({exc}); using CPU.\")\n", + " return torch.device(\"cpu\")\n", + "\n", + " if supported_arches and arch not in supported_arches:\n", + " if verbose:\n", + " supported = \", \".join(sorted(supported_arches))\n", + " print(\n", + " f\"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build \"\n", + " f\"supports only: {supported}. Using CPU.\"\n", + " )\n", + " return torch.device(\"cpu\")\n", + "\n", + " return torch.device(\"cuda\")\n", + "\n", + "\n", + "class UNetDownBlock(nn.Module):\n", + " \"\"\"Downsampling block for U-Net.\"\"\"\n", + "\n", + " def __init__(self, in_channels: int, out_channels: int, normalize: bool = True, dropout: float = 0.0):\n", + " super().__init__()\n", + " layers = [\n", + " nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)\n", + " ]\n", + " if normalize:\n", + " layers.append(nn.BatchNorm2d(out_channels))\n", + " layers.append(nn.LeakyReLU(0.2, inplace=True))\n", + " if dropout > 0:\n", + " layers.append(nn.Dropout2d(dropout))\n", + " self.model = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)\n", + "\n", + "\n", + "class UNetUpBlock(nn.Module):\n", + " \"\"\"Upsampling block for U-Net.\"\"\"\n", + "\n", + " def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):\n", + " super().__init__()\n", + " self.upconv = nn.Sequential(\n", + " nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False),\n", + " nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),\n", + " )\n", + " self.norm = nn.BatchNorm2d(out_channels)\n", + " self.relu = nn.ReLU(inplace=True)\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout2d(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " def forward(self, x, skip_input):\n", + " x = self.upconv(x)\n", + " # Pad if needed to match skip connection size\n", + " if x.shape != skip_input.shape:\n", + " diff_h = skip_input.size(2) - x.size(2)\n", + " diff_w = skip_input.size(3) - x.size(3)\n", + " x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])\n", + " x = self.norm(x)\n", + " x = self.relu(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " x = torch.cat([x, skip_input], dim=1)\n", + " return x\n", + "\n", + "\n", + "class GeneratorUNet(nn.Module):\n", + " \"\"\"U-Net generator for Google -> Yandex translation.\"\"\"\n", + "\n", + " def __init__(self, in_channels: int = 3, out_channels: int = 3):\n", + " super().__init__()\n", + "\n", + " # Downsampling\n", + " self.down1 = UNetDownBlock(in_channels, 64, normalize=False)\n", + " self.down2 = UNetDownBlock(64, 128)\n", + " self.down3 = UNetDownBlock(128, 256)\n", + " self.down4 = UNetDownBlock(256, 512)\n", + " self.down5 = UNetDownBlock(512, 512)\n", + " self.down6 = UNetDownBlock(512, 512)\n", + " self.down7 = UNetDownBlock(512, 512)\n", + "\n", + " # Bottleneck\n", + " self.bottleneck = nn.Sequential(\n", + " nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),\n", + " nn.ReLU(inplace=True),\n", + " )\n", + "\n", + " # Upsampling - input channels from previous layer, output before concat\n", + " self.up1 = UNetUpBlock(512, 512, dropout=0.5) # in: 512 (bottleneck) -> out: 512, concat with d7 (512) = 1024\n", + " self.up2 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d6 (512) = 1024\n", + " self.up3 = UNetUpBlock(1024, 512, dropout=0.5) # in: 1024 -> out: 512, concat with d5 (512) = 1024\n", + " self.up4 = UNetUpBlock(1024, 512) # in: 1024 -> out: 512, concat with d4 (512) = 1024\n", + " self.up5 = UNetUpBlock(1024, 256) # in: 1024 -> out: 256, concat with d3 (256) = 512\n", + " self.up6 = UNetUpBlock(512, 128) # in: 512 -> out: 128, concat with d2 (128) = 256\n", + " self.up7 = UNetUpBlock(256, 64) # in: 256 -> out: 64, concat with d1 (64) = 128\n", + "\n", + " # Final\n", + " self.final = nn.Sequential(\n", + " nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=False),\n", + " nn.Conv2d(128, out_channels, kernel_size=3, padding=1),\n", + " nn.Tanh(),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " # Down\n", + " d1 = self.down1(x)\n", + " d2 = self.down2(d1)\n", + " d3 = self.down3(d2)\n", + " d4 = self.down4(d3)\n", + " d5 = self.down5(d4)\n", + " d6 = self.down6(d5)\n", + " d7 = self.down7(d6)\n", + "\n", + " # Bottleneck\n", + " u = self.bottleneck(d7)\n", + "\n", + " # Up with skip connections\n", + " u = self.up1(u, d7)\n", + " u = self.up2(u, d6)\n", + " u = self.up3(u, d5)\n", + " u = self.up4(u, d4)\n", + " u = self.up5(u, d3)\n", + " u = self.up6(u, d2)\n", + " u = self.up7(u, d1)\n", + "\n", + " return self.final(u)\n", + "\n", + "\n", + "class DiscriminatorPatchGAN(nn.Module):\n", + " \"\"\"PatchGAN discriminator for paired source/target images.\"\"\"\n", + "\n", + " def __init__(self, in_channels: int = 6):\n", + " super().__init__()\n", + " self.model = nn.Sequential(\n", + " nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n", + " nn.BatchNorm2d(128),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),\n", + " nn.BatchNorm2d(256),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),\n", + " nn.BatchNorm2d(512),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),\n", + " )\n", + "\n", + " def forward(self, img_A, img_B):\n", + " x = torch.cat([img_A, img_B], dim=1)\n", + " return self.model(x)\n", + "\n", + "\n", + "class GANLoss(nn.Module):\n", + " \"\"\"GAN loss supporting different GAN modes.\"\"\"\n", + "\n", + " def __init__(self, gan_mode: str = \"vanilla\", target_real: float = 1.0, target_fake: float = 0.0):\n", + " super().__init__()\n", + " self.gan_mode = gan_mode\n", + " self.register_buffer(\"real_label\", torch.tensor(target_real))\n", + " self.register_buffer(\"fake_label\", torch.tensor(target_fake))\n", + "\n", + " if gan_mode == \"vanilla\":\n", + " self.loss_fn = nn.BCEWithLogitsLoss()\n", + " elif gan_mode == \"lsgan\":\n", + " self.loss_fn = nn.MSELoss()\n", + " elif gan_mode == \"wgangp\":\n", + " self.loss_fn = None\n", + " else:\n", + " raise ValueError(f\"Unknown GAN mode: {gan_mode}\")\n", + "\n", + " def forward(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:\n", + " if self.gan_mode in [\"vanilla\", \"lsgan\"]:\n", + " target = self.real_label if target_is_real else self.fake_label\n", + " target = target.expand_as(prediction)\n", + " return self.loss_fn(prediction, target)\n", + " elif self.gan_mode == \"wgangp\":\n", + " return -prediction.mean() if target_is_real else prediction.mean()\n", + "\n", + "\n", + "class SSIMLoss(nn.Module):\n", + " \"\"\"Local SSIM loss for normalized image tensors in [-1, 1].\"\"\"\n", + "\n", + " def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):\n", + " super().__init__()\n", + " self.window_size = window_size\n", + " self.c1 = c1\n", + " self.c2 = c2\n", + "\n", + " def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", + " pred = (pred + 1.0) * 0.5\n", + " target = (target + 1.0) * 0.5\n", + " padding = self.window_size // 2\n", + "\n", + " mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding)\n", + " mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding)\n", + " mu_pred_sq = mu_pred.pow(2)\n", + " mu_target_sq = mu_target.pow(2)\n", + " mu_pred_target = mu_pred * mu_target\n", + "\n", + " sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq\n", + " sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq\n", + " sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target\n", + "\n", + " ssim_map = (\n", + " (2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)\n", + " ) / (\n", + " (mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)\n", + " )\n", + " return (1.0 - ssim_map.clamp(0, 1)).mean()\n", + "\n", + "\n", + "class SobelEdgeLoss(nn.Module):\n", + " \"\"\"L1 loss between Sobel edge maps, useful for stable keypoint structure.\"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " kernel_x = torch.tensor(\n", + " [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],\n", + " dtype=torch.float32,\n", + " ).view(1, 1, 3, 3)\n", + " kernel_y = torch.tensor(\n", + " [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],\n", + " dtype=torch.float32,\n", + " ).view(1, 1, 3, 3)\n", + " self.register_buffer(\"kernel_x\", kernel_x)\n", + " self.register_buffer(\"kernel_y\", kernel_y)\n", + "\n", + " @staticmethod\n", + " def _to_gray(x: torch.Tensor) -> torch.Tensor:\n", + " x = (x + 1.0) * 0.5\n", + " weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)\n", + " return (x * weights).sum(dim=1, keepdim=True)\n", + "\n", + " def _edges(self, x: torch.Tensor) -> torch.Tensor:\n", + " gray = self._to_gray(x)\n", + " grad_x = F.conv2d(gray, self.kernel_x, padding=1)\n", + " grad_y = F.conv2d(gray, self.kernel_y, padding=1)\n", + " return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)\n", + "\n", + " def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", + " return F.l1_loss(self._edges(pred), self._edges(target))\n", + "\n", + "\n", + "class ImageGAN(nn.Module):\n", + " \"\"\"Complete pix2pix-style GAN for Google -> Yandex image translation.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " input_channels: int = 3,\n", + " output_channels: int = 3,\n", + " gan_mode: str = \"lsgan\",\n", + " lambda_L1: float = 150.0,\n", + " lambda_GAN: float = 0.5,\n", + " lambda_SSIM: float = 25.0,\n", + " lambda_edge: float = 20.0,\n", + " use_cuda: bool = True,\n", + " ):\n", + " super().__init__()\n", + " self.generator = GeneratorUNet(input_channels, output_channels)\n", + " self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)\n", + " self.gan_loss = GANLoss(gan_mode)\n", + " self.l1_loss = nn.L1Loss()\n", + " self.ssim_loss = SSIMLoss()\n", + " self.edge_loss = SobelEdgeLoss()\n", + " self.lambda_L1 = lambda_L1\n", + " self.lambda_GAN = lambda_GAN\n", + " self.lambda_SSIM = lambda_SSIM\n", + " self.lambda_edge = lambda_edge\n", + "\n", + " self.device = get_compatible_device(prefer_cuda=use_cuda)\n", + " self.to(self.device)\n", + "\n", + " def forward(self, google_image):\n", + " \"\"\"Generate a Yandex-style image from a Google image.\"\"\"\n", + " return self.generator(google_image)\n", + "\n", + " def generator_step(self, google_img, real_yandex_img):\n", + " \"\"\"Compute generator losses against the paired original Yandex image.\"\"\"\n", + " fake_yandex = self.generator(google_img)\n", + " fake_pred = self.discriminator(google_img, fake_yandex)\n", + " gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN\n", + " l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1\n", + " ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM\n", + " edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge\n", + " total_loss = gan_loss + l1_loss + ssim_loss + edge_loss\n", + " return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss\n", + "\n", + " def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):\n", + " \"\"\"Compute discriminator losses for real and generated Yandex targets.\"\"\"\n", + " real_pred = self.discriminator(google_img, real_yandex_img)\n", + " real_loss = self.gan_loss(real_pred, True)\n", + " fake_pred = self.discriminator(google_img, fake_yandex_img.detach())\n", + " fake_loss = self.gan_loss(fake_pred, False)\n", + " total_loss = (real_loss + fake_loss) * 0.5\n", + " return total_loss, real_loss, fake_loss\n", + "\n", + "\n", + "def create_gan(\n", + " input_channels: int = 3,\n", + " output_channels: int = 3,\n", + " gan_mode: str = \"lsgan\",\n", + " lambda_L1: float = 150.0,\n", + " lambda_GAN: float = 0.5,\n", + " lambda_SSIM: float = 25.0,\n", + " lambda_edge: float = 20.0,\n", + " use_cuda: bool = True,\n", + ") -> ImageGAN:\n", + " \"\"\"Create a GAN model.\"\"\"\n", + " return ImageGAN(\n", + " input_channels=input_channels,\n", + " output_channels=output_channels,\n", + " gan_mode=gan_mode,\n", + " lambda_L1=lambda_L1,\n", + " lambda_GAN=lambda_GAN,\n", + " lambda_SSIM=lambda_SSIM,\n", + " lambda_edge=lambda_edge,\n", + " use_cuda=use_cuda,\n", + " )\n", + "\n", + "\n", + "def initialize_weights(model: nn.Module):\n", + " \"\"\"Initialize model weights.\"\"\"\n", + " for m in model.modules():\n", + " if isinstance(m, nn.Conv2d):\n", + " nn.init.normal_(m.weight.data, 0.0, 0.02)\n", + " elif isinstance(m, nn.BatchNorm2d):\n", + " nn.init.normal_(m.weight.data, 1.0, 0.02)\n", + " nn.init.constant_(m.bias.data, 0.0)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Training\n\n`GANTrainer` trains discriminator and generator alternately.\n\n**Training step:**\n1. Generate `fake_yandex = G(google_img)`\n2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`\n3. Train generator against discriminator and paired Yandex target\n\n**Checkpoint saving:**\n- `best.pth`\n- `epoch_N.pth`\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Trainer for GAN model.\"\"\"\n", + "\n", + "\n", + "\n", + "\n", + "class GANTrainer:\n", + " \"\"\"Simple GAN trainer.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " model: torch.nn.Module,\n", + " train_loader: DataLoader,\n", + " val_loader: DataLoader,\n", + " config: Dict[str, Any],\n", + " ):\n", + " self.model = model\n", + " self.train_loader = train_loader\n", + " self.val_loader = val_loader\n", + " self.config = config\n", + " self.device = model.device\n", + "\n", + " # Optimizers\n", + " lr = config.get(\"learning_rate\", 2e-4)\n", + " lr_d = config.get(\"discriminator_learning_rate\", lr * config.get(\"discriminator_lr_factor\", 0.5))\n", + " beta1 = config.get(\"beta1\", 0.5)\n", + " beta2 = config.get(\"beta2\", 0.999)\n", + " self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))\n", + " self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))\n", + "\n", + " # Training state\n", + " self.current_epoch = 0\n", + " self.best_val_loss = float(\"inf\")\n", + " self.g_losses = []\n", + " self.d_losses = []\n", + " self.l1_losses = []\n", + " self.ssim_losses = []\n", + " self.edge_losses = []\n", + " self.val_g_losses = []\n", + " self.val_d_losses = []\n", + " self.val_l1_losses = []\n", + " self.val_ssim_losses = []\n", + " self.val_edge_losses = []\n", + " self.val_reconstruction_losses = []\n", + "\n", + " # Output dir\n", + " self.output_dir = Path(config.get(\"output_dir\", \"runs/gan\"))\n", + " self.output_dir.mkdir(parents=True, exist_ok=True)\n", + " (self.output_dir / \"checkpoints\").mkdir(exist_ok=True)\n", + "\n", + " # Save config\n", + " with open(self.output_dir / \"config.json\", \"w\") as f:\n", + " json.dump(config, f, indent=2)\n", + "\n", + " def train_epoch(self) -> Tuple[float, float]:\n", + " \"\"\"Train for one epoch.\"\"\"\n", + " self.model.train()\n", + " total_g = total_d = 0.0\n", + " total_l1 = total_ssim = total_edge = 0.0\n", + " num_batches = len(self.train_loader)\n", + " d_update_interval = max(1, self.config.get(\"discriminator_update_interval\", 1))\n", + "\n", + " pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f\"Epoch {self.current_epoch + 1}\")\n", + " for batch_idx, batch in pbar:\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + "\n", + " # Train D\n", + " if batch_idx % d_update_interval == 0:\n", + " self.opt_D.zero_grad()\n", + " with torch.no_grad():\n", + " fake_img = self.model.generator(google_img)\n", + " d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]\n", + " d_loss.backward()\n", + " self.opt_D.step()\n", + " else:\n", + " d_loss = google_img.new_tensor(0.0)\n", + "\n", + " # Train G\n", + " self.opt_G.zero_grad()\n", + " g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)\n", + " g_loss.backward()\n", + " self.opt_G.step()\n", + "\n", + " total_g += g_loss.item()\n", + " total_d += d_loss.item()\n", + " total_l1 += l1_loss.item()\n", + " total_ssim += ssim_loss.item()\n", + " total_edge += edge_loss.item()\n", + " pbar.set_postfix({\n", + " \"g_loss\": g_loss.item(),\n", + " \"d_loss\": d_loss.item(),\n", + " \"l1\": l1_loss.item(),\n", + " \"ssim\": ssim_loss.item(),\n", + " \"edge\": edge_loss.item(),\n", + " })\n", + "\n", + " avg_g = total_g / num_batches\n", + " avg_d = total_d / num_batches\n", + " avg_l1 = total_l1 / num_batches\n", + " avg_ssim = total_ssim / num_batches\n", + " avg_edge = total_edge / num_batches\n", + " self.g_losses.append(avg_g)\n", + " self.d_losses.append(avg_d)\n", + " self.l1_losses.append(avg_l1)\n", + " self.ssim_losses.append(avg_ssim)\n", + " self.edge_losses.append(avg_edge)\n", + " return avg_g, avg_d\n", + "\n", + " @torch.no_grad()\n", + " def validate(self) -> Tuple[float, float]:\n", + " \"\"\"Validate the model.\"\"\"\n", + " self.model.eval()\n", + " total_g = total_d = 0.0\n", + " total_l1 = total_ssim = total_edge = 0.0\n", + "\n", + " for batch in tqdm(self.val_loader, desc=\"Val\"):\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " fake_img = self.model.generator(google_img)\n", + " g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)\n", + " d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]\n", + " total_g += g_loss.item()\n", + " total_d += d_loss.item()\n", + " total_l1 += l1_loss.item()\n", + " total_ssim += ssim_loss.item()\n", + " total_edge += edge_loss.item()\n", + "\n", + " avg_g = total_g / len(self.val_loader)\n", + " avg_d = total_d / len(self.val_loader)\n", + " avg_l1 = total_l1 / len(self.val_loader)\n", + " avg_ssim = total_ssim / len(self.val_loader)\n", + " avg_edge = total_edge / len(self.val_loader)\n", + " avg_reconstruction = avg_l1 + avg_ssim + avg_edge\n", + " self.val_g_losses.append(avg_g)\n", + " self.val_d_losses.append(avg_d)\n", + " self.val_l1_losses.append(avg_l1)\n", + " self.val_ssim_losses.append(avg_ssim)\n", + " self.val_edge_losses.append(avg_edge)\n", + " self.val_reconstruction_losses.append(avg_reconstruction)\n", + " return avg_g, avg_d\n", + "\n", + " def train(self, num_epochs: int):\n", + " \"\"\"Train the model.\"\"\"\n", + " print(f\"Training for {num_epochs} epochs...\")\n", + "\n", + " for epoch in range(num_epochs):\n", + " self.current_epoch = epoch\n", + "\n", + " # Train & validate\n", + " train_g, train_d = self.train_epoch()\n", + " val_g, val_d = self.validate()\n", + "\n", + " val_reconstruction = self.val_reconstruction_losses[-1]\n", + " if val_reconstruction < self.best_val_loss:\n", + " self.best_val_loss = val_reconstruction\n", + " self.save_checkpoint(\"best\")\n", + "\n", + " # Periodic checkpoint\n", + " if (epoch + 1) % self.config.get(\"save_interval\", 5) == 0:\n", + " self.save_checkpoint(f\"epoch_{epoch + 1}\")\n", + "\n", + " print(\n", + " f\"Epoch {epoch + 1}: \"\n", + " f\"train_g={train_g:.4f}, train_d={train_d:.4f}, \"\n", + " f\"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, \"\n", + " f\"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, \"\n", + " f\"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, \"\n", + " f\"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}\"\n", + " )\n", + "\n", + " # Early stopping\n", + " patience = self.config.get(\"early_stopping_patience\", 0)\n", + " if patience > 0 and len(self.val_reconstruction_losses) > patience:\n", + " recent = self.val_reconstruction_losses[-patience:]\n", + " previous_best = min(self.val_reconstruction_losses[:-patience])\n", + " if all(loss >= previous_best for loss in recent):\n", + " print(f\"Early stopping at epoch {epoch + 1}\")\n", + " break\n", + "\n", + " # Save final\n", + " self.save_checkpoint(\"final\")\n", + " print(f\"Training finished. Best val loss: {self.best_val_loss:.4f}\")\n", + "\n", + " def save_checkpoint(self, name: str):\n", + " \"\"\"Save model checkpoint.\"\"\"\n", + " path = self.output_dir / \"checkpoints\" / f\"{name}.pth\"\n", + " torch.save({\n", + " \"epoch\": self.current_epoch,\n", + " \"generator\": self.model.generator.state_dict(),\n", + " \"discriminator\": self.model.discriminator.state_dict(),\n", + " \"opt_G\": self.opt_G.state_dict(),\n", + " \"opt_D\": self.opt_D.state_dict(),\n", + " }, path)\n", + "\n", + "\n", + "def create_trainer(\n", + " model: torch.nn.Module,\n", + " train_loader: DataLoader,\n", + " val_loader: DataLoader,\n", + " config: Dict[str, Any],\n", + ") -> GANTrainer:\n", + " \"\"\"Create a trainer instance.\"\"\"\n", + " return GANTrainer(model, train_loader, val_loader, config)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Analysis\n\nVisualization helpers for generated samples and collected training metrics.\n\nTraining history plot contains:\n1. Generator loss\n2. Discriminator loss\n3. L1 loss against the paired Yandex target\n4. SSIM structure loss\n5. Sobel edge loss\n6. Best-checkpoint reconstruction score\n\nThe sample grid contains:\n1. Google input\n2. Generated Yandex\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "\n", + "def denormalize_image(tensor):\n", + " return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):\n", + " device = device or model.device\n", + " model.eval()\n", + "\n", + " batch = next(iter(data_loader))\n", + " google_img = batch[\"google_img\"][:num_samples].to(device)\n", + " yandex_img = batch[\"yandex_img\"][:num_samples].to(device)\n", + " fake_yandex = model.generator(google_img)\n", + "\n", + " output_dir = Path(output_dir)\n", + " output_dir.mkdir(parents=True, exist_ok=True)\n", + " output_path = output_dir / \"generation_samples.png\"\n", + "\n", + " fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))\n", + " if num_samples == 1:\n", + " axes = axes.reshape(1, 3)\n", + "\n", + " titles = [\"Google input\", \"Generated Yandex\", \"Yandex target\"]\n", + " for row in range(num_samples):\n", + " images = [google_img[row], fake_yandex[row], yandex_img[row]]\n", + " for col, image in enumerate(images):\n", + " axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))\n", + " axes[row, col].set_title(titles[col])\n", + " axes[row, col].axis(\"off\")\n", + "\n", + " fig.tight_layout()\n", + " fig.savefig(output_path, dpi=150)\n", + " if show:\n", + " plt.show()\n", + " plt.close(fig)\n", + " return output_path\n", + "\n", + "\n", + "def plot_training_history(trainer, output_dir, show=True):\n", + " output_dir = Path(output_dir)\n", + " output_dir.mkdir(parents=True, exist_ok=True)\n", + " output_path = output_dir / \"training_history.png\"\n", + "\n", + " epochs = range(1, len(trainer.g_losses) + 1)\n", + " fig, axes = plt.subplots(2, 3, figsize=(15, 8))\n", + " axes = axes.ravel()\n", + "\n", + " axes[0].plot(epochs, trainer.g_losses, label=\"train G\")\n", + " axes[0].plot(epochs, trainer.val_g_losses, label=\"val G\")\n", + " axes[0].set_title(\"Generator loss\")\n", + " axes[0].set_xlabel(\"Epoch\")\n", + " axes[0].legend()\n", + " axes[0].grid(True, alpha=0.3)\n", + "\n", + " axes[1].plot(epochs, trainer.d_losses, label=\"train D\")\n", + " axes[1].plot(epochs, trainer.val_d_losses, label=\"val D\")\n", + " axes[1].set_title(\"Discriminator loss\")\n", + " axes[1].set_xlabel(\"Epoch\")\n", + " axes[1].legend()\n", + " axes[1].grid(True, alpha=0.3)\n", + "\n", + " axes[2].plot(epochs, trainer.l1_losses, label=\"train L1\")\n", + " axes[2].plot(epochs, trainer.val_l1_losses, label=\"val L1\")\n", + " axes[2].set_title(\"Paired Yandex L1 loss\")\n", + " axes[2].set_xlabel(\"Epoch\")\n", + " axes[2].legend()\n", + " axes[2].grid(True, alpha=0.3)\n", + "\n", + " axes[3].plot(epochs, trainer.ssim_losses, label=\"train SSIM\")\n", + " axes[3].plot(epochs, trainer.val_ssim_losses, label=\"val SSIM\")\n", + " axes[3].set_title(\"SSIM structure loss\")\n", + " axes[3].set_xlabel(\"Epoch\")\n", + " axes[3].legend()\n", + " axes[3].grid(True, alpha=0.3)\n", + "\n", + " axes[4].plot(epochs, trainer.edge_losses, label=\"train edge\")\n", + " axes[4].plot(epochs, trainer.val_edge_losses, label=\"val edge\")\n", + " axes[4].set_title(\"Sobel edge loss\")\n", + " axes[4].set_xlabel(\"Epoch\")\n", + " axes[4].legend()\n", + " axes[4].grid(True, alpha=0.3)\n", + "\n", + " axes[5].plot(epochs, trainer.val_reconstruction_losses, label=\"val reconstruction\")\n", + " axes[5].set_title(\"Best-checkpoint score\")\n", + " axes[5].set_xlabel(\"Epoch\")\n", + " axes[5].legend()\n", + " axes[5].grid(True, alpha=0.3)\n", + "\n", + " fig.tight_layout()\n", + " fig.savefig(output_path, dpi=150)\n", + " if show:\n", + " plt.show()\n", + " plt.close(fig)\n", + " return output_path\n", + "\n", + "\n", + "def analyze_training(trainer):\n", + " return {\n", + " \"best_val_loss\": trainer.best_val_loss,\n", + " \"g_losses\": trainer.g_losses,\n", + " \"d_losses\": trainer.d_losses,\n", + " \"l1_losses\": trainer.l1_losses,\n", + " \"ssim_losses\": trainer.ssim_losses,\n", + " \"edge_losses\": trainer.edge_losses,\n", + " \"val_g_losses\": trainer.val_g_losses,\n", + " \"val_d_losses\": trainer.val_d_losses,\n", + " \"val_l1_losses\": trainer.val_l1_losses,\n", + " \"val_ssim_losses\": trainer.val_ssim_losses,\n", + " \"val_edge_losses\": trainer.val_edge_losses,\n", + " \"val_reconstruction_losses\": trainer.val_reconstruction_losses,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Main Pipeline\n\nExecutes the full GAN workflow:\n1. Create config\n2. Build paired data loaders\n3. Initialize Google -> Yandex GAN\n4. Train with validation\n5. Save checkpoints in `runs/checkpoints/`\n6. Show loss plots and generated sample grid\n\nThis block is intentionally top-level, not wrapped in `main()`, so notebook\nvariables such as `model`, `trainer`, `train_loader`, `val_loader`, and\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Executable GAN training pipeline.\n", + "\n", + "The code is intentionally top-level, mirroring the SiaN notebook style:\n", + "when this file is included in the generated notebook, variables remain\n", + "available for debugging and interactive experiments.\n", + "\"\"\"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "config = create_config()\n", + "device = get_compatible_device(prefer_cuda=config[\"prefer_cuda\"])\n", + "print(f\"Using device: {device}\")\n", + "\n", + "train_loader, val_loader = create_data_loaders(\n", + " root_dir=config[\"data_dir\"],\n", + " batch_size=config[\"batch_size\"],\n", + " train_split=config[\"train_split\"],\n", + " image_size=tuple(config[\"image_size\"]),\n", + " num_workers=config[\"num_workers\"],\n", + ")\n", + "print(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n", + "\n", + "model = create_gan(\n", + " gan_mode=config[\"gan_mode\"],\n", + " lambda_GAN=config[\"lambda_GAN\"],\n", + " lambda_L1=config[\"lambda_L1\"],\n", + " lambda_SSIM=config[\"lambda_SSIM\"],\n", + " lambda_edge=config[\"lambda_edge\"],\n", + " use_cuda=(device.type == \"cuda\"),\n", + ")\n", + "\n", + "generator_params = sum(p.numel() for p in model.generator.parameters())\n", + "discriminator_params = sum(p.numel() for p in model.discriminator.parameters())\n", + "print(f\"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}\")\n", + "\n", + "trainer = create_trainer(model, train_loader, val_loader, config)\n", + "trainer.train(config[\"epochs\"])\n", + "\n", + "training_analysis = analyze_training(trainer)\n", + "images_dir = Path(config[\"output_dir\"]) / \"images\"\n", + "history_plot_path = plot_training_history(trainer, images_dir)\n", + "generation_samples_path = visualize_generation(\n", + " model=model,\n", + " data_loader=val_loader,\n", + " output_dir=images_dir,\n", + " device=device,\n", + " num_samples=config[\"num_visual_samples\"],\n", + ")\n", + "\n", + "print(f\"Training history plot: {history_plot_path}\")\n", + "print(f\"Generation samples: {generation_samples_path}\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/models/GAN/src/__init__.py b/models/GAN/src/__init__.py new file mode 100644 index 0000000..bfa6849 --- /dev/null +++ b/models/GAN/src/__init__.py @@ -0,0 +1 @@ +"""GAN package for Google-to-Yandex map image translation.""" diff --git a/models/GAN/src/analyze.py b/models/GAN/src/analyze.py new file mode 100644 index 0000000..bbe8f30 --- /dev/null +++ b/models/GAN/src/analyze.py @@ -0,0 +1,117 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import torch + + +def denormalize_image(tensor): + return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1) + + +@torch.no_grad() +def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True): + device = device or model.device + model.eval() + + batch = next(iter(data_loader)) + google_img = batch["google_img"][:num_samples].to(device) + yandex_img = batch["yandex_img"][:num_samples].to(device) + fake_yandex = model.generator(google_img) + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "generation_samples.png" + + fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples)) + if num_samples == 1: + axes = axes.reshape(1, 3) + + titles = ["Google input", "Generated Yandex", "Yandex target"] + for row in range(num_samples): + images = [google_img[row], fake_yandex[row], yandex_img[row]] + for col, image in enumerate(images): + axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0)) + axes[row, col].set_title(titles[col]) + axes[row, col].axis("off") + + fig.tight_layout() + fig.savefig(output_path, dpi=150) + if show: + plt.show() + plt.close(fig) + return output_path + + +def plot_training_history(trainer, output_dir, show=True): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "training_history.png" + + epochs = range(1, len(trainer.g_losses) + 1) + fig, axes = plt.subplots(2, 3, figsize=(15, 8)) + axes = axes.ravel() + + axes[0].plot(epochs, trainer.g_losses, label="train G") + axes[0].plot(epochs, trainer.val_g_losses, label="val G") + axes[0].set_title("Generator loss") + axes[0].set_xlabel("Epoch") + axes[0].legend() + axes[0].grid(True, alpha=0.3) + + axes[1].plot(epochs, trainer.d_losses, label="train D") + axes[1].plot(epochs, trainer.val_d_losses, label="val D") + axes[1].set_title("Discriminator loss") + axes[1].set_xlabel("Epoch") + axes[1].legend() + axes[1].grid(True, alpha=0.3) + + axes[2].plot(epochs, trainer.l1_losses, label="train L1") + axes[2].plot(epochs, trainer.val_l1_losses, label="val L1") + axes[2].set_title("Paired Yandex L1 loss") + axes[2].set_xlabel("Epoch") + axes[2].legend() + axes[2].grid(True, alpha=0.3) + + axes[3].plot(epochs, trainer.ssim_losses, label="train SSIM") + axes[3].plot(epochs, trainer.val_ssim_losses, label="val SSIM") + axes[3].set_title("SSIM structure loss") + axes[3].set_xlabel("Epoch") + axes[3].legend() + axes[3].grid(True, alpha=0.3) + + axes[4].plot(epochs, trainer.edge_losses, label="train edge") + axes[4].plot(epochs, trainer.val_edge_losses, label="val edge") + axes[4].set_title("Sobel edge loss") + axes[4].set_xlabel("Epoch") + axes[4].legend() + axes[4].grid(True, alpha=0.3) + + axes[5].plot(epochs, trainer.val_reconstruction_losses, label="val reconstruction") + axes[5].set_title("Best-checkpoint score") + axes[5].set_xlabel("Epoch") + axes[5].legend() + axes[5].grid(True, alpha=0.3) + + fig.tight_layout() + fig.savefig(output_path, dpi=150) + if show: + plt.show() + plt.close(fig) + return output_path + + +def analyze_training(trainer): + return { + "best_val_loss": trainer.best_val_loss, + "g_losses": trainer.g_losses, + "d_losses": trainer.d_losses, + "l1_losses": trainer.l1_losses, + "ssim_losses": trainer.ssim_losses, + "edge_losses": trainer.edge_losses, + "val_g_losses": trainer.val_g_losses, + "val_d_losses": trainer.val_d_losses, + "val_l1_losses": trainer.val_l1_losses, + "val_ssim_losses": trainer.val_ssim_losses, + "val_edge_losses": trainer.val_edge_losses, + "val_reconstruction_losses": trainer.val_reconstruction_losses, + } diff --git a/models/GAN/config.py b/models/GAN/src/config.py similarity index 62% rename from models/GAN/config.py rename to models/GAN/src/config.py index b97e1b5..69fa681 100644 --- a/models/GAN/config.py +++ b/models/GAN/src/config.py @@ -6,23 +6,30 @@ def create_config(): return { # Optimizer params "learning_rate": 2e-4, + "discriminator_lr_factor": 0.5, "beta1": 0.5, "beta2": 0.999, # Training params "batch_size": 32, "epochs": 100, + "prefer_cuda": True, # GAN params - "gan_mode": "vanilla", - "lambda_L1": 100.0, + "gan_mode": "lsgan", + "lambda_GAN": 0.5, + "lambda_L1": 150.0, + "lambda_SSIM": 25.0, + "lambda_edge": 20.0, + "discriminator_update_interval": 1, # Regularization "grad_clip": 1.0, # Early stopping - "early_stopping_patience": 20, + "early_stopping_patience": 25, # Output - "output_dir": "runs/gan_training", + "output_dir": r"C:\Users\admin\Projects\autopilot\models\GAN\runs", # Logging "log_interval": 10, "save_interval": 5, + "num_visual_samples": 4, # Data "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", "image_size": [256, 256], @@ -33,4 +40,4 @@ def create_config(): if __name__ == "__main__": config = create_config() - print("Default config:", config) \ No newline at end of file + print("Default config:", config) diff --git a/models/GAN/dataloader.py b/models/GAN/src/dataloader.py similarity index 97% rename from models/GAN/dataloader.py rename to models/GAN/src/dataloader.py index 3472db1..c21ba3f 100644 --- a/models/GAN/dataloader.py +++ b/models/GAN/src/dataloader.py @@ -1,4 +1,4 @@ -"""Data loader for Yandex-to-Google image translation.""" +"""Data loader for Google-to-Yandex image translation.""" import os from typing import Dict, List, Tuple @@ -10,7 +10,7 @@ from torchvision import transforms class YaGoDataset(Dataset): - """Dataset loading pairs of Yandex and Google map images.""" + """Dataset loading paired Google and Yandex map images.""" def __init__( self, @@ -159,4 +159,4 @@ if __name__ == "__main__": batch = next(iter(train_loader)) print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}") - print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}") \ No newline at end of file + print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}") diff --git a/models/GAN/src/main.py b/models/GAN/src/main.py new file mode 100644 index 0000000..4e2966d --- /dev/null +++ b/models/GAN/src/main.py @@ -0,0 +1,60 @@ +"""Executable GAN training pipeline. + +The code is intentionally top-level, mirroring the SiaN notebook style: +when this file is included in the generated notebook, variables remain +available for debugging and interactive experiments. +""" + +from pathlib import Path + +import torch + +from analyze import analyze_training, plot_training_history, visualize_generation +from config import create_config +from dataloader import create_data_loaders +from model import create_gan, get_compatible_device +from trainer import create_trainer + + +config = create_config() +device = get_compatible_device(prefer_cuda=config["prefer_cuda"]) +print(f"Using device: {device}") + +train_loader, val_loader = create_data_loaders( + root_dir=config["data_dir"], + batch_size=config["batch_size"], + train_split=config["train_split"], + image_size=tuple(config["image_size"]), + num_workers=config["num_workers"], +) +print(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}") + +model = create_gan( + gan_mode=config["gan_mode"], + lambda_GAN=config["lambda_GAN"], + lambda_L1=config["lambda_L1"], + lambda_SSIM=config["lambda_SSIM"], + lambda_edge=config["lambda_edge"], + use_cuda=(device.type == "cuda"), +) + +generator_params = sum(p.numel() for p in model.generator.parameters()) +discriminator_params = sum(p.numel() for p in model.discriminator.parameters()) +print(f"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}") + +trainer = create_trainer(model, train_loader, val_loader, config) +trainer.train(config["epochs"]) + +training_analysis = analyze_training(trainer) +images_dir = Path(config["output_dir"]) / "images" +history_plot_path = plot_training_history(trainer, images_dir) +generation_samples_path = visualize_generation( + model=model, + data_loader=val_loader, + output_dir=images_dir, + device=device, + num_samples=config["num_visual_samples"], +) + +print(f"Training history plot: {history_plot_path}") +print(f"Generation samples: {generation_samples_path}") diff --git a/models/GAN/model.py b/models/GAN/src/model.py similarity index 56% rename from models/GAN/model.py rename to models/GAN/src/model.py index dfb0787..f250c13 100644 --- a/models/GAN/model.py +++ b/models/GAN/src/model.py @@ -1,10 +1,37 @@ -"""GAN model for image translation Yandex -> Google.""" +"""GAN model for image translation Google -> Yandex.""" import torch import torch.nn as nn import torch.nn.functional as F +def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device: + """Return CUDA only when the current PyTorch build supports the GPU arch.""" + if not prefer_cuda or not torch.cuda.is_available(): + return torch.device("cpu") + + try: + major, minor = torch.cuda.get_device_capability() + arch = f"sm_{major}{minor}" + supported_arches = set(torch.cuda.get_arch_list()) + gpu_name = torch.cuda.get_device_name() + except Exception as exc: + if verbose: + print(f"CUDA is visible but cannot be inspected ({exc}); using CPU.") + return torch.device("cpu") + + if supported_arches and arch not in supported_arches: + if verbose: + supported = ", ".join(sorted(supported_arches)) + print( + f"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build " + f"supports only: {supported}. Using CPU." + ) + return torch.device("cpu") + + return torch.device("cuda") + + class UNetDownBlock(nn.Module): """Downsampling block for U-Net.""" @@ -29,7 +56,10 @@ class UNetUpBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0): super().__init__() - self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False) + self.upconv = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + ) self.norm = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) if dropout > 0: @@ -38,7 +68,6 @@ class UNetUpBlock(nn.Module): self.dropout = None def forward(self, x, skip_input): - x = self.upconv(x) # Pad if needed to match skip connection size if x.shape != skip_input.shape: @@ -54,7 +83,7 @@ class UNetUpBlock(nn.Module): class GeneratorUNet(nn.Module): - """U-Net generator for Yandex -> Google translation.""" + """U-Net generator for Google -> Yandex translation.""" def __init__(self, in_channels: int = 3, out_channels: int = 3): super().__init__() @@ -85,7 +114,8 @@ class GeneratorUNet(nn.Module): # Final self.final = nn.Sequential( - nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d(128, out_channels, kernel_size=3, padding=1), nn.Tanh(), ) @@ -115,7 +145,7 @@ class GeneratorUNet(nn.Module): class DiscriminatorPatchGAN(nn.Module): - """PatchGAN discriminator.""" + """PatchGAN discriminator for paired source/target images.""" def __init__(self, in_channels: int = 6): super().__init__() @@ -132,7 +162,6 @@ class DiscriminatorPatchGAN(nn.Module): nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), - nn.Sigmoid(), ) def forward(self, img_A, img_B): @@ -167,15 +196,82 @@ class GANLoss(nn.Module): return -prediction.mean() if target_is_real else prediction.mean() +class SSIMLoss(nn.Module): + """Local SSIM loss for normalized image tensors in [-1, 1].""" + + def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2): + super().__init__() + self.window_size = window_size + self.c1 = c1 + self.c2 = c2 + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + pred = (pred + 1.0) * 0.5 + target = (target + 1.0) * 0.5 + padding = self.window_size // 2 + + mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding) + mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding) + mu_pred_sq = mu_pred.pow(2) + mu_target_sq = mu_target.pow(2) + mu_pred_target = mu_pred * mu_target + + sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq + sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq + sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target + + ssim_map = ( + (2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2) + ) / ( + (mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2) + ) + return (1.0 - ssim_map.clamp(0, 1)).mean() + + +class SobelEdgeLoss(nn.Module): + """L1 loss between Sobel edge maps, useful for stable keypoint structure.""" + + def __init__(self): + super().__init__() + kernel_x = torch.tensor( + [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=torch.float32, + ).view(1, 1, 3, 3) + kernel_y = torch.tensor( + [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + dtype=torch.float32, + ).view(1, 1, 3, 3) + self.register_buffer("kernel_x", kernel_x) + self.register_buffer("kernel_y", kernel_y) + + @staticmethod + def _to_gray(x: torch.Tensor) -> torch.Tensor: + x = (x + 1.0) * 0.5 + weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) + return (x * weights).sum(dim=1, keepdim=True) + + def _edges(self, x: torch.Tensor) -> torch.Tensor: + gray = self._to_gray(x) + grad_x = F.conv2d(gray, self.kernel_x, padding=1) + grad_y = F.conv2d(gray, self.kernel_y, padding=1) + return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6) + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + return F.l1_loss(self._edges(pred), self._edges(target)) + + class ImageGAN(nn.Module): - """Complete GAN model for image translation.""" + """Complete pix2pix-style GAN for Google -> Yandex image translation.""" def __init__( self, input_channels: int = 3, output_channels: int = 3, - gan_mode: str = "vanilla", - lambda_L1: float = 100.0, + gan_mode: str = "lsgan", + lambda_L1: float = 150.0, + lambda_GAN: float = 0.5, + lambda_SSIM: float = 25.0, + lambda_edge: float = 20.0, use_cuda: bool = True, ): super().__init__() @@ -183,29 +279,36 @@ class ImageGAN(nn.Module): self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels) self.gan_loss = GANLoss(gan_mode) self.l1_loss = nn.L1Loss() + self.ssim_loss = SSIMLoss() + self.edge_loss = SobelEdgeLoss() self.lambda_L1 = lambda_L1 + self.lambda_GAN = lambda_GAN + self.lambda_SSIM = lambda_SSIM + self.lambda_edge = lambda_edge - self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu") + self.device = get_compatible_device(prefer_cuda=use_cuda) self.to(self.device) - def forward(self, yandex_image): - """Generate Google image from Yandex.""" - return self.generator(yandex_image) + def forward(self, google_image): + """Generate a Yandex-style image from a Google image.""" + return self.generator(google_image) - def generator_step(self, yandex_img, real_google_img): - """Compute generator losses.""" - fake_google = self.generator(yandex_img) - fake_pred = self.discriminator(yandex_img, fake_google) - gan_loss = self.gan_loss(fake_pred, True) - l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1 - total_loss = gan_loss + l1_loss - return total_loss, gan_loss, l1_loss + def generator_step(self, google_img, real_yandex_img): + """Compute generator losses against the paired original Yandex image.""" + fake_yandex = self.generator(google_img) + fake_pred = self.discriminator(google_img, fake_yandex) + gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN + l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1 + ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM + edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge + total_loss = gan_loss + l1_loss + ssim_loss + edge_loss + return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss - def discriminator_step(self, yandex_img, real_google_img, fake_google_img): - """Compute discriminator losses.""" - real_pred = self.discriminator(yandex_img, real_google_img) + def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img): + """Compute discriminator losses for real and generated Yandex targets.""" + real_pred = self.discriminator(google_img, real_yandex_img) real_loss = self.gan_loss(real_pred, True) - fake_pred = self.discriminator(yandex_img, fake_google_img.detach()) + fake_pred = self.discriminator(google_img, fake_yandex_img.detach()) fake_loss = self.gan_loss(fake_pred, False) total_loss = (real_loss + fake_loss) * 0.5 return total_loss, real_loss, fake_loss @@ -214,8 +317,11 @@ class ImageGAN(nn.Module): def create_gan( input_channels: int = 3, output_channels: int = 3, - gan_mode: str = "vanilla", - lambda_L1: float = 100.0, + gan_mode: str = "lsgan", + lambda_L1: float = 150.0, + lambda_GAN: float = 0.5, + lambda_SSIM: float = 25.0, + lambda_edge: float = 20.0, use_cuda: bool = True, ) -> ImageGAN: """Create a GAN model.""" @@ -224,6 +330,9 @@ def create_gan( output_channels=output_channels, gan_mode=gan_mode, lambda_L1=lambda_L1, + lambda_GAN=lambda_GAN, + lambda_SSIM=lambda_SSIM, + lambda_edge=lambda_edge, use_cuda=use_cuda, ) @@ -252,4 +361,4 @@ if __name__ == "__main__": # Count parameters gen_params = sum(p.numel() for p in model.generator.parameters()) disc_params = sum(p.numel() for p in model.discriminator.parameters()) - print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params") \ No newline at end of file + print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params") diff --git a/models/GAN/src/test_dataloader.py b/models/GAN/src/test_dataloader.py new file mode 100644 index 0000000..e7b949f --- /dev/null +++ b/models/GAN/src/test_dataloader.py @@ -0,0 +1,40 @@ +from config import create_config +from dataloader import YaGoDataset, create_data_loaders + + +def get_dataset_info(): + config = create_config() + dataset = YaGoDataset( + root_dir=config["data_dir"], + image_size=tuple(config["image_size"]), + ) + sample = dataset[0] if len(dataset) else {} + return { + "size": len(dataset), + "sample_keys": list(sample.keys()), + "google_shape": tuple(sample["google_img"].shape) if sample else None, + "yandex_shape": tuple(sample["yandex_img"].shape) if sample else None, + } + + +def smoke_test_dataloader(batch_size=4): + config = create_config() + train_loader, val_loader = create_data_loaders( + root_dir=config["data_dir"], + batch_size=batch_size, + train_split=config["train_split"], + num_workers=config["num_workers"], + image_size=tuple(config["image_size"]), + ) + batch = next(iter(train_loader)) + return { + "train_size": len(train_loader.dataset), + "val_size": len(val_loader.dataset), + "google_batch_shape": tuple(batch["google_img"].shape), + "yandex_batch_shape": tuple(batch["yandex_img"].shape), + } + + +if __name__ == "__main__": + print(get_dataset_info()) + print(smoke_test_dataloader()) diff --git a/models/GAN/trainer.py b/models/GAN/src/trainer.py similarity index 58% rename from models/GAN/trainer.py rename to models/GAN/src/trainer.py index 77fb0b0..7bac290 100644 --- a/models/GAN/trainer.py +++ b/models/GAN/src/trainer.py @@ -28,18 +28,26 @@ class GANTrainer: # Optimizers lr = config.get("learning_rate", 2e-4) + lr_d = config.get("discriminator_learning_rate", lr * config.get("discriminator_lr_factor", 0.5)) beta1 = config.get("beta1", 0.5) beta2 = config.get("beta2", 0.999) self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2)) - self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)) + self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2)) # Training state self.current_epoch = 0 self.best_val_loss = float("inf") self.g_losses = [] self.d_losses = [] + self.l1_losses = [] + self.ssim_losses = [] + self.edge_losses = [] self.val_g_losses = [] self.val_d_losses = [] + self.val_l1_losses = [] + self.val_ssim_losses = [] + self.val_edge_losses = [] + self.val_reconstruction_losses = [] # Output dir self.output_dir = Path(config.get("output_dir", "runs/gan")) @@ -54,35 +62,55 @@ class GANTrainer: """Train for one epoch.""" self.model.train() total_g = total_d = 0.0 + total_l1 = total_ssim = total_edge = 0.0 num_batches = len(self.train_loader) + d_update_interval = max(1, self.config.get("discriminator_update_interval", 1)) - pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") - for batch in pbar: - yandex_img = batch["yandex_img"].to(self.device) + pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f"Epoch {self.current_epoch + 1}") + for batch_idx, batch in pbar: google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) # Train D - self.opt_D.zero_grad() - with torch.no_grad(): - fake_img = self.model.generator(yandex_img) - d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0] - d_loss.backward() - self.opt_D.step() + if batch_idx % d_update_interval == 0: + self.opt_D.zero_grad() + with torch.no_grad(): + fake_img = self.model.generator(google_img) + d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0] + d_loss.backward() + self.opt_D.step() + else: + d_loss = google_img.new_tensor(0.0) # Train G self.opt_G.zero_grad() - g_loss = self.model.generator_step(yandex_img, google_img)[0] + g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img) g_loss.backward() self.opt_G.step() total_g += g_loss.item() total_d += d_loss.item() - pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()}) + total_l1 += l1_loss.item() + total_ssim += ssim_loss.item() + total_edge += edge_loss.item() + pbar.set_postfix({ + "g_loss": g_loss.item(), + "d_loss": d_loss.item(), + "l1": l1_loss.item(), + "ssim": ssim_loss.item(), + "edge": edge_loss.item(), + }) avg_g = total_g / num_batches avg_d = total_d / num_batches + avg_l1 = total_l1 / num_batches + avg_ssim = total_ssim / num_batches + avg_edge = total_edge / num_batches self.g_losses.append(avg_g) self.d_losses.append(avg_d) + self.l1_losses.append(avg_l1) + self.ssim_losses.append(avg_ssim) + self.edge_losses.append(avg_edge) return avg_g, avg_d @torch.no_grad() @@ -90,20 +118,32 @@ class GANTrainer: """Validate the model.""" self.model.eval() total_g = total_d = 0.0 + total_l1 = total_ssim = total_edge = 0.0 for batch in tqdm(self.val_loader, desc="Val"): - yandex_img = batch["yandex_img"].to(self.device) google_img = batch["google_img"].to(self.device) - fake_img = self.model.generator(yandex_img) - g_loss = self.model.generator_step(yandex_img, google_img)[0] - d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0] + yandex_img = batch["yandex_img"].to(self.device) + fake_img = self.model.generator(google_img) + g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img) + d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0] total_g += g_loss.item() total_d += d_loss.item() + total_l1 += l1_loss.item() + total_ssim += ssim_loss.item() + total_edge += edge_loss.item() avg_g = total_g / len(self.val_loader) avg_d = total_d / len(self.val_loader) + avg_l1 = total_l1 / len(self.val_loader) + avg_ssim = total_ssim / len(self.val_loader) + avg_edge = total_edge / len(self.val_loader) + avg_reconstruction = avg_l1 + avg_ssim + avg_edge self.val_g_losses.append(avg_g) self.val_d_losses.append(avg_d) + self.val_l1_losses.append(avg_l1) + self.val_ssim_losses.append(avg_ssim) + self.val_edge_losses.append(avg_edge) + self.val_reconstruction_losses.append(avg_reconstruction) return avg_g, avg_d def train(self, num_epochs: int): @@ -117,23 +157,30 @@ class GANTrainer: train_g, train_d = self.train_epoch() val_g, val_d = self.validate() - # Save best checkpoint - val_total = val_g + val_d - if val_total < self.best_val_loss: - self.best_val_loss = val_total + val_reconstruction = self.val_reconstruction_losses[-1] + if val_reconstruction < self.best_val_loss: + self.best_val_loss = val_reconstruction self.save_checkpoint("best") # Periodic checkpoint if (epoch + 1) % self.config.get("save_interval", 5) == 0: self.save_checkpoint(f"epoch_{epoch + 1}") - print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}") + print( + f"Epoch {epoch + 1}: " + f"train_g={train_g:.4f}, train_d={train_d:.4f}, " + f"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, " + f"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, " + f"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, " + f"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}" + ) # Early stopping patience = self.config.get("early_stopping_patience", 0) - if patience > 0 and len(self.val_g_losses) > patience: - recent = self.val_g_losses[-patience:] - if all(l >= min(self.val_g_losses[:-patience]) for l in recent): + if patience > 0 and len(self.val_reconstruction_losses) > patience: + recent = self.val_reconstruction_losses[-patience:] + previous_best = min(self.val_reconstruction_losses[:-patience]) + if all(loss >= previous_best for loss in recent): print(f"Early stopping at epoch {epoch + 1}") break @@ -188,4 +235,4 @@ if __name__ == "__main__": g_loss, d_loss = trainer.train_epoch() print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}") except Exception as e: - print(f"Training step failed: {e}") \ No newline at end of file + print(f"Training step failed: {e}") diff --git a/practice/Indiv_zadanie_Сагитов.docx b/practice/Indiv_zadanie_Сагитов.docx new file mode 100644 index 0000000..81166db Binary files /dev/null and b/practice/Indiv_zadanie_Сагитов.docx differ diff --git a/practice/otchet_po_praktike.docx b/practice/otchet_po_praktike.docx new file mode 100644 index 0000000..6d895b4 Binary files /dev/null and b/practice/otchet_po_praktike.docx differ diff --git a/practice/Дневник практики.docx b/practice/Дневник практики.docx new file mode 100644 index 0000000..b543ed6 Binary files /dev/null and b/practice/Дневник практики.docx differ diff --git a/sian_similarity.py b/sian_similarity.py new file mode 100644 index 0000000..2ca2622 --- /dev/null +++ b/sian_similarity.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import importlib.util +import os +from pathlib import Path +from typing import Optional, Sequence + +import numpy as np +import torch +from PIL import Image + +from vision_chunk import VisionChunk + + +ROOT_DIR = Path(__file__).resolve().parent +MODEL_FILE = ROOT_DIR / "models" / "SiaN-similarity" / "model.py" +DEFAULT_CHECKPOINT_PATH = ( + ROOT_DIR + / "models" + / "SiaN-similarity" + / "runs" + / "gan_training" + / "checkpoints" + / "best_model.pt" +) + +IMAGE_SIZE = (256, 256) +DEFAULT_THRESHOLD = 0.5 +CHECKPOINT_ENV = "SIAN_SIMILARITY_CHECKPOINT" +THRESHOLD_ENV = "SIAN_SIMILARITY_THRESHOLD" + +_model: Optional[torch.nn.Module] = None +_device: Optional[torch.device] = None + + +def _load_similarity_class(): + spec = importlib.util.spec_from_file_location("sian_similarity_model", MODEL_FILE) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load similarity model from {MODEL_FILE}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.SimilarityCNN + + +def _get_checkpoint_path() -> Path: + checkpoint_path = os.getenv(CHECKPOINT_ENV) + if checkpoint_path: + return Path(checkpoint_path).expanduser().resolve() + return DEFAULT_CHECKPOINT_PATH + + +def get_threshold() -> float: + threshold = os.getenv(THRESHOLD_ENV) + if threshold is None: + return DEFAULT_THRESHOLD + return float(threshold) + + +def _get_device() -> torch.device: + global _device + + if _device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + return _device + + +def _get_model() -> torch.nn.Module: + global _model + + if _model is not None: + return _model + + checkpoint_path = _get_checkpoint_path() + if not checkpoint_path.exists(): + raise FileNotFoundError( + f"SiaN similarity checkpoint not found: {checkpoint_path}. " + f"Set {CHECKPOINT_ENV} to another .pt file if needed." + ) + + SimilarityCNN = _load_similarity_class() + device = _get_device() + + model = SimilarityCNN( + input_channels=3, + backbone_name="resnet18", + pretrained=False, + dropout_rate=0.3, + use_batch_norm=True, + ).to(device) + + checkpoint = torch.load(checkpoint_path, map_location=device) + state_dict = checkpoint.get("model_state_dict", checkpoint) + model.load_state_dict(state_dict) + model.eval() + + _model = model + return _model + + +def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor: + image = chunk.image.convert("RGB").resize(IMAGE_SIZE, Image.BILINEAR) + array = np.asarray(image, dtype=np.float32) / 255.0 + tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0) + return tensor.to(_get_device()) + + +def get_similarity_score(chunk1: VisionChunk, chunk2: VisionChunk) -> float: + if chunk1 is None or chunk2 is None: + return 0.0 + + model = _get_model() + img1 = _chunk_to_tensor(chunk1) + img2 = _chunk_to_tensor(chunk2) + + with torch.inference_mode(): + similarity = model(img1, img2) + + return float(similarity.squeeze().item()) + + +def get_similarity_scores(chunk: VisionChunk, candidates: Sequence[VisionChunk]) -> list[float]: + if chunk is None or not candidates: + return [] + + model = _get_model() + img = _chunk_to_tensor(chunk) + candidate_images = torch.cat([_chunk_to_tensor(candidate) for candidate in candidates], dim=0) + repeated_img = img.expand(candidate_images.shape[0], -1, -1, -1) + + with torch.inference_mode(): + similarities = model(repeated_img, candidate_images) + + return [float(score) for score in similarities.squeeze(1).detach().cpu().tolist()] + + +def is_similar( + chunk1: VisionChunk, + chunk2: VisionChunk, + threshold: Optional[float] = None, +) -> bool: + if threshold is None: + threshold = get_threshold() + + return get_similarity_score(chunk1, chunk2) >= threshold