feat: add GAN

This commit is contained in:
2026-05-30 14:49:40 +03:00
parent 6477ce0776
commit 72e1950127
29 changed files with 2670 additions and 361 deletions

21
.vscode/c_cpp_properties.json vendored Normal file
View File

@@ -0,0 +1,21 @@
{
"configurations": [
{
"name": "Win32",
"includePath": [
"${workspaceFolder}/**"
],
"defines": [
"_DEBUG",
"UNICODE",
"_UNICODE"
],
"windowsSdkVersion": "10.0.26100.0",
"compilerPath": "cl.exe",
"cStandard": "c17",
"cppStandard": "c++17",
"intelliSenseMode": "windows-msvc-x64"
}
],
"version": 4
}

View File

@@ -8,6 +8,7 @@ import cv2
import numpy as np
from PIL import Image
import sian_similarity
from timer import Timer
from vision_chunk import VisionChunk
@@ -58,8 +59,16 @@ class AutoPilot(Pilot):
# Положение на основе ориентира
reserved_pos: Position | None
proccessing_time: float
use_sian_similarity: bool
def __init__(self, points = [], chunks = [], viz_manager=None, pixel_ratio: float = 1.):
def __init__(
self,
points = [],
chunks = [],
viz_manager=None,
pixel_ratio: float = 1.,
use_sian_similarity: bool = False,
):
self.prev_chunk = None
self.pos = Position(0, 0, 1, 0, 0, 0)
self.chunks = chunks
@@ -67,6 +76,7 @@ class AutoPilot(Pilot):
self.vis_manager = viz_manager # Менеджер визуализации
self.reserved_pos = None
self.pixel_ratio = pixel_ratio
self.use_sian_similarity = use_sian_similarity
# Пороговые значения качества сопоставления/гомографии
self.min_inliers: int = 12
@@ -153,17 +163,33 @@ class AutoPilot(Pilot):
landmark_timer = Timer()
landmark_timer.start()
cur_pos = np.array([self.pos.x, self.pos.y])
closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
current_chunk = self.prev_chunk
landmark_chunk = self.chunks[closest_chunk_idx]
if current_chunk is None or not self.chunks:
return None
if self.use_sian_similarity:
similarity_scores = sian_similarity.get_similarity_scores(current_chunk, self.chunks)
best_chunk_idx = int(np.argmax(similarity_scores))
best_similarity_score = similarity_scores[best_chunk_idx]
print(f"[LANDMARK]: best similarity={best_similarity_score:.4f}, idx={best_chunk_idx}")
if best_similarity_score < sian_similarity.get_threshold():
print("[LANDMARK]: not similar")
return None
print("[LANDMARK]: similar")
landmark_chunk = self.chunks[best_chunk_idx]
else:
cur_pos = np.array([self.pos.x, self.pos.y])
closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
landmark_chunk = self.chunks[closest_chunk_idx]
if constants.DEBUG_FPS:
print(f"[LANDMARK]: Closest chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
print(f"[LANDMARK]: Landmark chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
# Краевой случай: отсутствие чанков
if current_chunk is None or landmark_chunk is None:
if landmark_chunk is None:
return None
landmark_timer.start()
@@ -312,10 +338,10 @@ class AutoPilot(Pilot):
# Пытаемся найти ориентир на картинке:
self.prev_chunk = current_chunk
# Для улучшения среднего FPS
# if self.frame_count % 5 == 0:
# pos_by_chunk = self.get_position_by_chunk()
# if pos_by_chunk is not None:
# self.pos = pos_by_chunk
if self.frame_count % 5 == 0:
pos_by_chunk = self.get_position_by_chunk()
if pos_by_chunk is not None:
self.pos = pos_by_chunk
command = self.make_command()
self.timer.reset()

View File

@@ -0,0 +1,118 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1800" height="1000" viewBox="0 0 1800 1000">
<defs>
<marker id="arrow" markerWidth="16" markerHeight="16" refX="14" refY="8" orient="auto">
<path d="M0,0 L16,8 L0,16 Z" fill="#1f2937"/>
</marker>
<style>
.bg { fill: #fafafa; }
.title { font: 700 40px Arial, sans-serif; fill: #111827; }
.subtitle { font: 400 21px Arial, sans-serif; fill: #4b5563; }
.box { stroke-width: 3; rx: 18; }
.blue { fill: #eaf3ff; stroke: #3b82f6; }
.orange { fill: #fff1df; stroke: #f59e0b; }
.violet { fill: #f0ebff; stroke: #7c3aed; }
.green { fill: #eaf8ef; stroke: #22a06b; }
.red { fill: #fff1f2; stroke: #ef4444; }
.label { font: 700 27px Arial, sans-serif; fill: #111827; text-anchor: middle; }
.desc { font: 400 18px Arial, sans-serif; fill: #4b5563; text-anchor: middle; }
.small { font: 400 16px Arial, sans-serif; fill: #6b7280; text-anchor: middle; }
.arrow { stroke: #1f2937; stroke-width: 4; fill: none; marker-end: url(#arrow); }
.branch { stroke: #6b7280; stroke-width: 3.5; fill: none; marker-end: url(#arrow); }
.icon-line { stroke: #111827; stroke-width: 5; fill: none; stroke-linecap: round; stroke-linejoin: round; }
.patch-grid { stroke: #7c3aed; stroke-width: 3; fill: none; }
.p1 { fill: #bbf7d0; }
.p2 { fill: #fde68a; }
.p3 { fill: #fecaca; }
.stage { font: 700 20px Arial, sans-serif; fill: #374151; text-anchor: middle; }
</style>
</defs>
<rect class="bg" width="1800" height="1000"/>
<text class="title" x="90" y="80">Дискриминатор PatchGAN</text>
<text class="subtitle" x="90" y="118">Проверяет пару изображений локальными патчами: настоящая ли это пара Google + Яндекс или результат генератора</text>
<text class="stage" x="245" y="195">Входные пары</text>
<text class="stage" x="690" y="195">Объединение</text>
<text class="stage" x="1125" y="195">PatchGAN</text>
<text class="stage" x="1530" y="195">Оценка</text>
<g transform="translate(90 240)">
<rect class="box green" width="310" height="205"/>
<text class="label" x="155" y="54">Real pair</text>
<text class="desc" x="155" y="85">Google + настоящий Яндекс</text>
<text class="small" x="155" y="116">целевая метка: 1</text>
<path class="icon-line" d="M78 158 C112 130, 150 176, 190 145 C211 128, 230 124, 250 122"/>
</g>
<g transform="translate(90 555)">
<rect class="box red" width="310" height="205"/>
<text class="label" x="155" y="54">Fake pair</text>
<text class="desc" x="155" y="85">Google + Generated Яндекс</text>
<text class="small" x="155" y="116">целевая метка: 0</text>
<path class="icon-line" d="M78 158 C112 130, 150 176, 190 145 C211 128, 230 124, 250 122"/>
</g>
<path class="branch" d="M420 342 H485 V470 H545"/>
<path class="branch" d="M420 657 H485 V520 H545"/>
<g transform="translate(565 390)">
<rect class="box blue" width="255" height="210"/>
<text class="label" x="128" y="58">Concat</text>
<text class="desc" x="128" y="92">6 каналов</text>
<text class="small" x="128" y="124">RGB Google + RGB Yandex</text>
<text class="small" x="128" y="154">B x 6 x 256 x 256</text>
</g>
<path class="arrow" d="M840 495 H930"/>
<g transform="translate(950 295)">
<rect class="box violet" width="350" height="400"/>
<text class="label" x="175" y="58">Сверточные блоки</text>
<text class="desc" x="175" y="92">Conv + BatchNorm + LeakyReLU</text>
<path class="icon-line" d="M78 170 H272"/>
<path class="icon-line" d="M78 230 H272"/>
<path class="icon-line" d="M78 290 H272"/>
<circle cx="95" cy="170" r="12" fill="#bfdbfe" stroke="#111827" stroke-width="4"/>
<circle cx="155" cy="170" r="12" fill="#bfdbfe" stroke="#111827" stroke-width="4"/>
<circle cx="215" cy="170" r="12" fill="#bfdbfe" stroke="#111827" stroke-width="4"/>
<circle cx="120" cy="230" r="12" fill="#c4b5fd" stroke="#111827" stroke-width="4"/>
<circle cx="180" cy="230" r="12" fill="#c4b5fd" stroke="#111827" stroke-width="4"/>
<circle cx="240" cy="230" r="12" fill="#c4b5fd" stroke="#111827" stroke-width="4"/>
<circle cx="145" cy="290" r="12" fill="#fed7aa" stroke="#111827" stroke-width="4"/>
<circle cx="205" cy="290" r="12" fill="#fed7aa" stroke="#111827" stroke-width="4"/>
<text class="small" x="175" y="352">64 -> 128 -> 256 -> 512</text>
</g>
<path class="arrow" d="M1320 495 H1410"/>
<g transform="translate(1430 295)">
<rect class="box orange" width="280" height="400"/>
<text class="label" x="140" y="58">Patch map</text>
<text class="desc" x="140" y="92">оценка real/fake</text>
<g transform="translate(78 135)">
<rect class="patch-grid" width="124" height="124" rx="8"/>
<rect class="p1" x="10" y="10" width="28" height="28"/>
<rect class="p2" x="48" y="10" width="28" height="28"/>
<rect class="p1" x="86" y="10" width="28" height="28"/>
<rect class="p2" x="10" y="48" width="28" height="28"/>
<rect class="p1" x="48" y="48" width="28" height="28"/>
<rect class="p3" x="86" y="48" width="28" height="28"/>
<rect class="p1" x="10" y="86" width="28" height="28"/>
<rect class="p2" x="48" y="86" width="28" height="28"/>
<rect class="p1" x="86" y="86" width="28" height="28"/>
</g>
<text class="small" x="140" y="310">каждая ячейка</text>
<text class="small" x="140" y="336">соответствует локальной</text>
<text class="small" x="140" y="362">области изображения</text>
</g>
<path class="arrow" d="M1570 720 V755 H1095 V785"/>
<g transform="translate(855 790)">
<rect class="box green" width="480" height="130"/>
<text class="label" x="240" y="50">Функция потерь</text>
<text class="desc" x="240" y="82">реальная пара -> 1, сгенерированная пара -> 0</text>
</g>
<text class="small" x="900" y="960">Главная идея PatchGAN: проверять не всю карту одним числом, а локальные признаки стиля и структуры.</text>
</svg>

After

Width:  |  Height:  |  Size: 6.3 KiB

View File

@@ -0,0 +1,119 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1800" height="1000" viewBox="0 0 1800 1000">
<defs>
<marker id="arrow" markerWidth="16" markerHeight="16" refX="14" refY="8" orient="auto">
<path d="M0,0 L16,8 L0,16 Z" fill="#1f2937"/>
</marker>
<style>
.bg { fill: #fafafa; }
.title { font: 700 40px Arial, sans-serif; fill: #111827; }
.subtitle { font: 400 21px Arial, sans-serif; fill: #4b5563; }
.stage { font: 700 20px Arial, sans-serif; fill: #374151; text-anchor: middle; }
.box { stroke-width: 3; rx: 18; }
.blue { fill: #eaf3ff; stroke: #3b82f6; }
.orange { fill: #fff1df; stroke: #f59e0b; }
.green { fill: #eaf8ef; stroke: #22a06b; }
.violet { fill: #f0ebff; stroke: #7c3aed; }
.main { font: 700 28px Arial, sans-serif; fill: #111827; text-anchor: middle; }
.desc { font: 400 18px Arial, sans-serif; fill: #4b5563; text-anchor: middle; }
.small { font: 400 16px Arial, sans-serif; fill: #6b7280; text-anchor: middle; }
.arrow { stroke: #1f2937; stroke-width: 4; fill: none; marker-end: url(#arrow); }
.thin { stroke: #6b7280; stroke-width: 3; fill: none; marker-end: url(#arrow); }
.dash { stroke: #6b7280; stroke-width: 3; fill: none; stroke-dasharray: 10 9; marker-end: url(#arrow); }
.icon-line { stroke: #111827; stroke-width: 5; fill: none; stroke-linecap: round; stroke-linejoin: round; }
.icon-fill { fill: #ffffff; stroke: #111827; stroke-width: 4; }
.kp { fill: #16a34a; stroke: #ffffff; stroke-width: 3; }
.bad { fill: #ef4444; stroke: #ffffff; stroke-width: 3; }
</style>
</defs>
<rect class="bg" width="1800" height="1000"/>
<text class="title" x="90" y="80">Применение GAN для приведения карт к единому домену</text>
<text class="subtitle" x="90" y="118">Нейросеть меняет стиль карты, а дальнейшее сопоставление выполняется классическими геометрическими методами</text>
<text class="stage" x="255" y="190">Исходные данные</text>
<text class="stage" x="720" y="190">Доменная адаптация</text>
<text class="stage" x="1185" y="190">Сопоставление</text>
<text class="stage" x="1545" y="190">Результат</text>
<g transform="translate(90 235)">
<rect class="box blue" width="330" height="210"/>
<text class="main" x="165" y="54">Google Maps</text>
<text class="desc" x="165" y="86">фрагмент области полета</text>
<path class="icon-line" d="M85 145 C122 118, 164 169, 210 132 C230 116, 248 111, 272 110"/>
<path class="icon-line" d="M82 170 L280 118"/>
<circle class="kp" cx="116" cy="142" r="9"/>
<circle class="kp" cx="208" cy="132" r="9"/>
</g>
<g transform="translate(90 555)">
<rect class="box orange" width="330" height="210"/>
<text class="main" x="165" y="54">Яндекс.Карты</text>
<text class="desc" x="165" y="86">эталон с ориентирами</text>
<path class="icon-line" d="M85 145 C122 118, 164 169, 210 132 C230 116, 248 111, 272 110"/>
<path class="icon-line" d="M82 170 L280 118"/>
<circle class="kp" cx="116" cy="142" r="9"/>
<circle class="kp" cx="208" cy="132" r="9"/>
</g>
<path class="arrow" d="M440 340 H555"/>
<g transform="translate(575 300)">
<rect class="box orange" width="290" height="260"/>
<text class="main" x="145" y="58">GAN</text>
<text class="desc" x="145" y="91">Google -> Yandex</text>
<path class="icon-line" d="M80 170 L115 128 L150 170 L185 128 L220 170"/>
<text class="small" x="145" y="218">сохраняет геометрию</text>
<text class="small" x="145" y="242">меняет визуальный стиль</text>
</g>
<path class="arrow" d="M885 430 H1000"/>
<g transform="translate(1020 300)">
<rect class="box orange" width="330" height="260"/>
<text class="main" x="165" y="58">Сгенерированный кадр</text>
<text class="desc" x="165" y="91">карта в целевом домене</text>
<path class="icon-line" d="M85 166 C122 139, 164 190, 210 153 C230 137, 248 132, 272 131"/>
<path class="icon-line" d="M82 191 L280 139"/>
<circle class="kp" cx="116" cy="163" r="9"/>
<circle class="kp" cx="208" cy="153" r="9"/>
<circle class="kp" cx="250" cy="139" r="9"/>
</g>
<path class="dash" d="M420 660 C635 775, 940 775, 1040 580"/>
<path class="thin" d="M1185 580 V650"/>
<g transform="translate(1020 650)">
<rect class="box green" width="330" height="190"/>
<text class="main" x="165" y="52">Ключевые точки</text>
<text class="desc" x="165" y="84">ORB / SIFT / AKAZE</text>
<circle class="kp" cx="92" cy="132" r="10"/>
<circle class="kp" cx="142" cy="116" r="10"/>
<circle class="kp" cx="193" cy="139" r="10"/>
<circle class="kp" cx="244" cy="111" r="10"/>
<path class="icon-line" d="M90 132 L142 116 L193 139 L244 111"/>
</g>
<path class="arrow" d="M1370 745 H1445"/>
<g transform="translate(1465 650)">
<rect class="box green" width="245" height="190"/>
<text class="main" x="122" y="52">RANSAC</text>
<text class="desc" x="122" y="84">отбор совпадений</text>
<circle class="kp" cx="83" cy="130" r="10"/>
<circle class="kp" cx="124" cy="114" r="10"/>
<circle class="bad" cx="164" cy="137" r="10"/>
</g>
<path class="arrow" d="M1590 630 V530"/>
<g transform="translate(1425 300)">
<rect class="box violet" width="285" height="260"/>
<text class="main" x="142" y="58">Гомография H</text>
<text class="desc" x="142" y="91">связь координат</text>
<path class="icon-fill" d="M95 135 L172 118 L196 184 L112 201 Z"/>
<path class="icon-line" d="M103 212 L197 110"/>
<text class="small" x="142" y="238">коррекция позиции БПЛА</text>
</g>
<text class="small" x="900" y="930">Итог: изображения становятся визуально ближе, поэтому классические методы получают больше устойчивых совпадений.</text>
</svg>

After

Width:  |  Height:  |  Size: 6.1 KiB

View File

@@ -0,0 +1,104 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1700" height="920" viewBox="0 0 1700 920">
<defs>
<marker id="arrow" markerWidth="12" markerHeight="12" refX="10" refY="6" orient="auto">
<path d="M0,0 L12,6 L0,12 Z" fill="#344054"/>
</marker>
<style>
.bg { fill: #fbfcfe; }
.title { font: 700 34px Arial, sans-serif; fill: #172033; }
.subtitle { font: 400 18px Arial, sans-serif; fill: #526070; }
.block { stroke: #344054; stroke-width: 2; rx: 10; }
.down { fill: #d9ecff; }
.bottle { fill: #fff0cf; }
.up { fill: #e6f5dc; }
.out { fill: #eee7ff; }
.label { font: 700 17px Arial, sans-serif; fill: #172033; text-anchor: middle; }
.small { font: 400 14px Arial, sans-serif; fill: #4b5563; text-anchor: middle; }
.axis { stroke: #344054; stroke-width: 3; fill: none; marker-end: url(#arrow); }
.skip { stroke: #d47f20; stroke-width: 3; fill: none; stroke-dasharray: 9 7; marker-end: url(#arrow); }
.arrow { stroke: #344054; stroke-width: 3; fill: none; marker-end: url(#arrow); }
.note { font: 400 16px Arial, sans-serif; fill: #667085; }
</style>
</defs>
<rect class="bg" width="1700" height="920"/>
<text class="title" x="80" y="70">Архитектура генератора U-Net: Google Maps -&gt; Яндекс.Карты</text>
<text class="subtitle" x="80" y="104">Сжатие изображения до признакового представления и восстановление в целевом стиле с сохранением геометрии через skip-соединения</text>
<g id="encoder">
<rect class="block down" x="80" y="230" width="120" height="360"/>
<text class="label" x="140" y="255">Input</text>
<text class="small" x="140" y="280">3 x 256 x 256</text>
<text class="small" x="140" y="550">Google RGB</text>
<rect class="block down" x="240" y="265" width="120" height="290"/>
<text class="label" x="300" y="292">Down 1</text>
<text class="small" x="300" y="318">64</text>
<text class="small" x="300" y="532">128 x 128</text>
<rect class="block down" x="400" y="295" width="120" height="230"/>
<text class="label" x="460" y="322">Down 2</text>
<text class="small" x="460" y="348">128</text>
<text class="small" x="460" y="502">64 x 64</text>
<rect class="block down" x="560" y="325" width="120" height="170"/>
<text class="label" x="620" y="352">Down 3</text>
<text class="small" x="620" y="378">256</text>
<text class="small" x="620" y="472">32 x 32</text>
<rect class="block down" x="720" y="350" width="120" height="120"/>
<text class="label" x="780" y="377">Down 4</text>
<text class="small" x="780" y="403">512</text>
<text class="small" x="780" y="448">16 x 16</text>
</g>
<rect class="block bottle" x="880" y="372" width="120" height="76"/>
<text class="label" x="940" y="399">Bottleneck</text>
<text class="small" x="940" y="425">512</text>
<g id="decoder">
<rect class="block up" x="1040" y="350" width="120" height="120"/>
<text class="label" x="1100" y="377">Up 4</text>
<text class="small" x="1100" y="403">512</text>
<text class="small" x="1100" y="448">16 x 16</text>
<rect class="block up" x="1200" y="325" width="120" height="170"/>
<text class="label" x="1260" y="352">Up 5</text>
<text class="small" x="1260" y="378">256</text>
<text class="small" x="1260" y="472">32 x 32</text>
<rect class="block up" x="1360" y="295" width="120" height="230"/>
<text class="label" x="1420" y="322">Up 6</text>
<text class="small" x="1420" y="348">128</text>
<text class="small" x="1420" y="502">64 x 64</text>
<rect class="block up" x="1520" y="265" width="120" height="290"/>
<text class="label" x="1580" y="292">Up 7</text>
<text class="small" x="1580" y="318">64</text>
<text class="small" x="1580" y="532">128 x 128</text>
</g>
<rect class="block out" x="1370" y="650" width="250" height="95"/>
<text class="label" x="1495" y="684">Final Conv + Tanh</text>
<text class="small" x="1495" y="711">3 x 256 x 256</text>
<text class="small" x="1495" y="734">Generated Yandex RGB</text>
<path class="arrow" d="M205 410 L235 410"/>
<path class="arrow" d="M365 410 L395 410"/>
<path class="arrow" d="M525 410 L555 410"/>
<path class="arrow" d="M685 410 L715 410"/>
<path class="arrow" d="M845 410 L875 410"/>
<path class="arrow" d="M1005 410 L1035 410"/>
<path class="arrow" d="M1165 410 L1195 410"/>
<path class="arrow" d="M1325 410 L1355 410"/>
<path class="arrow" d="M1485 410 L1515 410"/>
<path class="arrow" d="M1580 560 C1580 610, 1530 620, 1495 645"/>
<path class="skip" d="M300 255 C300 145, 1580 145, 1580 260"/>
<path class="skip" d="M460 290 C460 185, 1420 185, 1420 290"/>
<path class="skip" d="M620 320 C620 225, 1260 225, 1260 320"/>
<path class="skip" d="M780 345 C780 265, 1100 265, 1100 345"/>
<text class="note" x="90" y="805">Down-блок: Conv2d, BatchNorm, LeakyReLU. Up-блок: Upsample, Conv2d, BatchNorm, ReLU, затем конкатенация со skip-признаками.</text>
<text class="note" x="90" y="835">Назначение skip-соединений: сохранить дороги, перекрестки и контуры объектов при изменении визуального стиля карты.</text>
</svg>

After

Width:  |  Height:  |  Size: 5.3 KiB

View File

@@ -1,3 +1,39 @@
2.3.3 Применение архитектуры сиамских близнецов для вычисления матрицы гомографии между двумя кадрами
2.3.3 Генеративно-состязательная сеть для приведения карт к единому домену
При сопоставлении кадров БПЛА с эталонной картой возникает не только геометрическое, но и доменное различие изображений. Даже если два фрагмента описывают один и тот же участок местности, снимки из разных источников могут иметь разные цвета, толщину дорог, подписи, условные обозначения, контраст зданий и набор отображаемых ориентиров. Например, один и тот же район на карте Google и на карте Яндекс визуально отличается настолько, что классические алгоритмы поиска ключевых точек могут находить мало устойчивых совпадений или формировать большое число ложных соответствий.
Одним из способов уменьшить этот доменный разрыв является предварительное преобразование изображения из одного картографического домена в другой. В данной работе рассматривается генеративно-состязательная сеть (Generative Adversarial Network, GAN), которая переводит фрагмент карты Google в визуальный стиль карты Яндекс. После такого преобразования исходный фрагмент Google становится ближе к эталонному домену Яндекс, а значит для дальнейшей локализации можно применять классические методы выделения и сопоставления ключевых точек: ORB, SIFT, AKAZE, BRISK и последующую оценку матрицы гомографии при помощи RANSAC.
В отличие от модели сиамских близнецов, которая напрямую оценивает схожесть пары изображений, GAN решает вспомогательную задачу нормализации домена. То есть нейросеть не заменяет классический алгоритм сопоставления, а подготавливает данные так, чтобы классический алгоритм работал в более благоприятных условиях.
![Схема применения GAN в задаче навигации БПЛА](../../_media/gan_pipeline.svg)
Рисунок 8 Применение GAN для приведения картографических изображений к единому домену
Архитектура модели построена по принципу pix2pix, так как для обучения доступны парные изображения одного и того же участка местности в двух доменах: Google Maps и Яндекс.Карты. На вход генератора подается изображение Google размером \left(B,3,256,256\right), где B размер пакета данных. Генератор формирует изображение \hat{Y}, визуально соответствующее стилю Яндекс.Карт. Дискриминатор получает на вход пару изображений и должен определить, является ли пара реальной \left(G,Y\right) или сгенерированной \left(G,\hat{Y}\right), где G исходный фрагмент Google, Y настоящий фрагмент Яндекс, \hat{Y} результат работы генератора.
Генератор реализован в виде U-Net. Энкодер последовательно уменьшает пространственное разрешение изображения и извлекает признаки высокого уровня: структуру дорог, контуры кварталов, границы зданий, водные объекты и другие устойчивые элементы карты. Декодер восстанавливает изображение в целевом стиле. Между симметричными уровнями энкодера и декодера используются skip-соединения, которые передают локальную геометрию напрямую из ранних слоев в поздние. Это важно для навигационной задачи: модель должна изменить стиль карты, но не должна смещать дороги, перекрестки и контуры объектов, так как именно они затем используются как ориентиры.
![Архитектура генератора U-Net для преобразования Google в Яндекс](../../_media/gan_unet_generator.svg)
Рисунок 9 Архитектура генератора U-Net
Дискриминатор реализован как PatchGAN. В отличие от обычного дискриминатора, который выдает одно значение для всего изображения, PatchGAN оценивает реалистичность локальных областей. На вход дискриминатора подается конкатенация исходного изображения Google и изображения Яндекс по каналам, поэтому входной тензор имеет размер \left(B,6,256,256\right). Далее изображение проходит через сверточные блоки с постепенным уменьшением разрешения, а выходом является карта оценок для отдельных фрагментов. Такой подход подходит для картографических изображений, потому что локальные признаки ширина дорог, стиль подписей, границы объектов, цветовые переходы важнее глобальной художественной реалистичности.
![Архитектура дискриминатора PatchGAN](../../_media/gan_patchgan_discriminator.svg)
Рисунок 10 Архитектура дискриминатора PatchGAN
Обучение модели является состязательным. Генератор стремится сформировать такое изображение \hat{Y}, чтобы дискриминатор считал пару \left(G,\hat{Y}\right) реальной. Дискриминатор, наоборот, учится отличать настоящие пары \left(G,Y\right) от сгенерированных. Для сохранения геометрии карты одной только состязательной функции потерь недостаточно, поэтому итоговая функция потерь генератора включает несколько компонентов:
L_G = \lambda_{GAN}L_{GAN}\left(D\left(G,\hat{Y}\right),1\right)+\lambda_{L1}\left\|\hat{Y}-Y\right\|_1+\lambda_{SSIM}L_{SSIM}\left(\hat{Y},Y\right)+\lambda_{edge}L_{edge}\left(\hat{Y},Y\right) (1)
где L_{GAN} состязательная функция потерь, L1 попиксельная ошибка между сгенерированным и настоящим изображением Яндекс, L_{SSIM} структурная ошибка, сохраняющая сходство локальной структуры, L_{edge} ошибка по картам границ, вычисленным оператором Собеля. Коэффициенты \lambda_{GAN}, \lambda_{L1}, \lambda_{SSIM} и \lambda_{edge} задают вклад каждого компонента. В реализованной модели используются значения \lambda_{GAN}=0.5, \lambda_{L1}=150, \lambda_{SSIM}=25, \lambda_{edge}=20. Усиленные L1, SSIM и edge-компоненты делают модель менее «творческой», но лучше сохраняют контуры дорог и объектов, что важнее для последующего поиска ключевых точек.
Функция потерь дискриминатора имеет вид:
L_D = \frac{1}{2}\left(L_{GAN}\left(D\left(G,Y\right),1\right)+L_{GAN}\left(D\left(G,\hat{Y}\right),0\right)\right) (2)
После обучения GAN может использоваться в навигационном пайплайне следующим образом. Сначала для области предполагаемого положения БПЛА загружается или выбирается фрагмент Google Maps. Затем генератор переводит этот фрагмент в стиль Яндекс.Карт. На полученном изображении и на эталонном фрагменте Яндекс.Карт выделяются ключевые точки и дескрипторы. Далее дескрипторы сопоставляются, ложные соответствия отбрасываются при помощи RANSAC, а по оставшимся точкам оценивается матрица гомографии. Полученная матрица позволяет связать координаты текущего изображения с координатами эталонной карты и уточнить положение БПЛА.
Таким образом, GAN выступает как промежуточный модуль доменной адаптации. Его применение особенно полезно в ситуации, когда источник доступной карты и источник эталонных ориентиров различаются. В рассматриваемой задаче это позволяет перевести изображения Google в представление, близкое к Яндекс.Картам, где ориентиры визуально согласованы между собой. Благодаря этому классические методы компьютерного зрения получают более похожие изображения и могут устойчивее находить ключевые точки, не требуя полного отказа от интерпретируемого геометрического пайплайна.

View File

@@ -6,3 +6,4 @@
|-----------|----------|------|
| 2.3.1 | Применение архитектуры сиамских близнецов для сопоставления кадров из различных доменов | 2.3.1_siamese_match.md |
| 2.3.2 | Применение архитектуры сиамских близнецов для вычисления матрицы гомографии | 2.3.2_siamese_homography.md |
| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3.3_additional.md |

View File

@@ -9,5 +9,6 @@
| 2.3 | Методы глубокого обучения | 2.3_deep_learning/ |
| 2.3.1 | Сиамские близнецы для сопоставления кадров из различных доменов | 2.3_deep_learning/2.3.1_siamese_match.md |
| 2.3.2 | Сиамские близнецы для вычисления матрицы гомографии | 2.3_deep_learning/2.3.2_siamese_homography.md |
| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3_deep_learning/2.3.3_additional.md |
| 2.4 | Датасет | 2.4_dataset/ |
| 2.5 | Обучение моделей глубокого обучения | 2.5_training/ |

View File

@@ -0,0 +1,41 @@
google-city-1 --use-sian-similarity
MSE: 1006.2278793057307
RMSE: 31.721095178220608
Average FPS: 2.1113510203551025
google-city-2 --use-sian-similarity
MSE: 0.018811735940268307
RMSE: 0.13715588190182842
Average FPS: 4.383734633135686
google-city-3 --use-sian-similarity
MSE: 0.06464743825147634
RMSE: 0.2542586050686905
Average FPS: 4.047345287474747
google-city-4 --use-sian-similarity
MSE: 0.3089718382086463
RMSE: 0.5558523528857697
Average FPS: 2.9425785696773636
google-city-1
MSE: 716.9078171684257
RMSE: 26.775134307196776
Average FPS: 38.766897877580504
google-city-2
MSE: 0.06237516673056363
RMSE: 0.2497502086697099
Average FPS: 25.821371544698586
google-city-3
MSE: 0.07598762118336233
RMSE: 0.2756585227838282
Average FPS: 26.37680633915316
google-city-3
MSE: 0.7199724273833397
RMSE: 0.8485118899481254
Average FPS: 29.534802136097262

18
main.py
View File

@@ -120,7 +120,7 @@ def build(name: str, map_name: str, lat: float, lon: float):
sleep(15)
online_map.destroy()
def run(name: str, map_name: str, ref_min_distance: float):
def run(name: str, map_name: str, ref_min_distance: float, use_sian_similarity: bool = False):
dir = Path('trajectories')
assert dir.exists()
dir /= name
@@ -158,7 +158,13 @@ def run(name: str, map_name: str, ref_min_distance: float):
print("READ POINTS:", points)
vis_manager = VisualizationManager()
pilot = autopilot.AutoPilot(points, chunks, vis_manager, online_map.pixel_ratio)
pilot = autopilot.AutoPilot(
points,
chunks,
vis_manager,
online_map.pixel_ratio,
use_sian_similarity=use_sian_similarity,
)
simulator = Simulator(online_map)
pilot.target_idx = 0
@@ -300,6 +306,12 @@ def parse_args():
help='Включить отладку эталонов'
)
parser.add_argument(
'--use-sian-similarity',
action='store_true',
help='Выбирать ориентир через SiaN similarity вместо ближайшего по текущей позиции'
)
# Парсим аргументы
args = parser.parse_args()
@@ -327,4 +339,4 @@ if __name__ == "__main__":
build(name, ref, lat, lon)
if mode == 'run' or mode == 'standalone':
run(name, sim, rmd)
run(name, sim, rmd, args.use_sian_similarity)

BIN
map.jpg

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.9 MiB

After

Width:  |  Height:  |  Size: 424 KiB

View File

@@ -1,254 +1,111 @@
# GAN Trainer для преобразования изображений Yandex → Google
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
## Структура проекта
```
autopilot/models/GAN/
├── gan.py # Основная реализация GAN модели
├── trainer.py # Тренер для обучения GAN
├── test_trainer.py # Тесты для тренера
├── train_example.py # Пример использования тренера
── README.md # Этот файл
```
## Модель GAN
Модель состоит из двух основных компонентов:
### 1. Генератор (GeneratorUNet)
- Архитектура U-Net для преобразования изображений
- Принимает изображение Yandex (3 канала RGB)
- Возвращает изображение в стиле Google (3 канала RGB)
- Использует skip connections для сохранения деталей
### 2. Дискриминатор (DiscriminatorPatchGAN)
- PatchGAN архитектура
- Принимает пару изображений (Yandex + Google)
- Возвращает вероятность того, что пара реальная
- Работает с патчами изображения 41x41
### Функция потерь (GANLoss)
Поддерживает три режима:
- `vanilla`: Бинарная кросс-энтропия
- `lsgan`: Least Squares GAN (более стабильный)
- `wgangp`: Wasserstein GAN with Gradient Penalty
## Тренер (GANTrainer)
### Основные возможности
1. **Обучение с чередованием**:
- Обучение генератора и дискриминатора поочередно
- Поддержка L1 потерь для сохранения структуры
2. **Валидация и мониторинг**:
- Отдельные потери для генератора и дискриминатора
- Логирование в TensorBoard
- Ранняя остановка
3. **Сохранение и загрузка**:
- Чекпоинты каждой эпохи
- Лучшая модель
- Финальная модель
- История обучения
4. **Оценка модели**:
- Метрики на тестовом наборе
- Генерация примеров
### Быстрый старт
```python
import torch
from torch.utils.data import DataLoader
from models.GAN.gan import create_image_gan
from models.GAN.trainer import GANTrainer
# Конфигурация
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 4,
"output_dir": "runs/gan_training",
"gan_mode": "vanilla",
"lambda_L1": 100.0,
"early_stopping_patience": 20,
}
# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Создание модели
model = create_image_gan(
input_channels=3,
output_channels=3,
gan_mode=config["gan_mode"],
lambda_L1=config["lambda_L1"],
use_cuda=(device.type == "cuda"),
)
# Создание даталоадеров (замените на свои данные)
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
# Создание тренера
trainer = GANTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
)
# Обучение
trainer.train(num_epochs=100)
# Оценка
metrics = trainer.evaluate(test_loader)
```
### Конфигурация обучения
#### Базовая конфигурация
```python
config = {
# Параметры оптимизатора
"learning_rate": 2e-4, # Learning rate
"beta1": 0.5, # Adam beta1
"beta2": 0.999, # Adam beta2
# Параметры обучения
"batch_size": 4, # Размер батча
"epochs": 100, # Количество эпох
# Параметры GAN
"gan_mode": "vanilla", # Режим GAN
"lambda_L1": 100.0, # Вес L1 потерь
# Регуляризация
"grad_clip": 1.0, # Gradient clipping
# Ранняя остановка
"early_stopping_patience": 20,
# Выходные данные
"output_dir": "runs/gan",
}
```
#### Расширенная конфигурация
```python
config = {
"learning_rate": 2e-4,
"beta1": 0.5,
"beta2": 0.999,
"batch_size": 8,
"epochs": 200,
"gan_mode": "lsgan", # Более стабильный LSGAN
"lambda_L1": 100.0,
"grad_clip": 1.0,
"weight_decay": 1e-4, # Weight decay
"early_stopping_patience": 30,
"early_stopping_min_delta": 1e-4,
"output_dir": "runs/gan_advanced",
}
```
### Методы тренера
#### Основные методы
- `train_epoch()`: Обучение на одной эпохе
- `validate()`: Валидация модели
- `train(num_epochs)`: Полное обучение
- `evaluate(test_loader)`: Оценка на тестовых данных
#### Управление чекпоинтами
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
### Выходные файлы
После обучения создаются следующие файлы:
```
runs/gan_training/
├── config.json # Конфигурация обучения
├── training_history.json # История потерь
├── model_best.pth # Лучшая модель
├── model_final.pth # Финальная модель
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
├── checkpoint_epoch_2.pth
├── ...
└── tensorboard/ # Логи TensorBoard
├── events.out.tfevents...
└── ...
```
### TensorBoard
Для визуализации обучения используйте TensorBoard:
```bash
tensorboard --logdir runs/gan_training/tensorboard
```
Доступные метрики:
- `train/batch_g_loss`: Потери генератора на батче
- `train/batch_d_loss`: Потери дискриминатора на батче
- `train/batch_g_l1_loss`: L1 потери генератора
- `train/epoch_g_loss`: Потери генератора на эпохе
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
- `val/epoch_g_loss`: Валидационные потери генератора
- `val/epoch_d_loss`: Валидационные потери дискриминатора
### Тестирование
Запустите тесты для проверки работоспособности:
```bash
python models/GAN/test_trainer.py
```
### Пример использования
Полный пример использования смотрите в `train_example.py`.
### Советы по обучению
1. **Начальные значения**:
- Используйте `gan_mode="lsgan"` для более стабильного обучения
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
2. **Мониторинг**:
- Следите за балансом потерь генератора и дискриминатора
- Если потери дискриминатора близки к 0, генератор не обучается
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
3. **Визуализация**:
- Регулярно генерируйте примеры для визуальной оценки
- Используйте TensorBoard для отслеживания прогресса
### Устранение проблем
#### Высокие потери генератора
- Уменьшите `lambda_L1`
- Увеличьте learning rate
- Проверьте качество данных
#### Дискриминатор слишком сильный
- Уменьшите learning rate дискриминатора
- Добавьте dropout в дискриминатор
- Обучайте генератор чаще, чем дискриминатор
#### Недостаток памяти GPU
- Уменьшите `batch_size`
- Уменьшите размер изображений
- Используйте gradient accumulation
### Лицензия
Этот проект является частью Autopilot системы.
# GAN для преобразования Google -> Yandex
Модуль реализует pix2pix-подобный GAN для парных изображений карт. Генератор получает изображение Google из пары и пытается сгенерировать изображение в стиле Yandex. Сгенерированная картинка сравнивается со вторым изображением пары, то есть с оригинальным `*_yandex.png`.
## Структура
```text
models/GAN/
├── build.py # генератор notebook.gen.ipynb по схеме
├── _schema.py # структура генерируемого ноутбука
├── _schema.md # описание формата схемы
├── notebook.gen.ipynb
── src/
│ ├── config.py
│ ├── dataloader.py
│ ├── model.py
│ ├── trainer.py
│ ├── analyze.py
│ └── main.py
└── README.md
```
## Архитектура
`GeneratorUNet` принимает `google_img` с 3 RGB-каналами и возвращает `fake_yandex`.
`DiscriminatorPatchGAN` получает пару `(google_img, yandex_img)` и отличает настоящую пару от `(google_img, fake_yandex)`.
`ImageGAN.generator_step()` считает:
- adversarial loss: дискриминатор должен принять `(google_img, fake_yandex)` за реальную пару;
- L1 loss: `fake_yandex` сравнивается с оригинальным `yandex_img` из той же пары.
- SSIM loss: штрафует структурные отличия, чтобы карта не расплывалась;
- Sobel edge loss: сохраняет контуры дорог и объектов для последующего поиска ключевых точек.
Итоговая функция потерь генератора:
```text
G_loss =
lambda_GAN * GAN_loss(D(google_img, fake_yandex), real)
+ lambda_L1 * L1(fake_yandex, yandex_img)
+ lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)
+ lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)
```
По умолчанию используется `lsgan`, усиленная реконструкция и более медленный
дискриминатор. Это менее “креативно”, зато обычно даёт более чистые контуры.
## Запуск
Из папки `models/GAN`:
```bash
python src/main.py
```
`src/main.py` специально написан без функции `main()`: при генерации ноутбука
все переменные остаются доступны после запуска ячейки, как в SiaN.
Чтобы пересобрать ноутбук из схемы:
```bash
python build.py
```
По умолчанию датасет берется из:
```text
C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images
```
Путь, размер изображений, batch size и число эпох меняются в `src/config.py`.
Если CUDA видна, но установленный PyTorch не поддерживает архитектуру GPU
например Tesla P100 `sm_60`, код автоматически переключится на CPU. Это
управляется параметром `prefer_cuda` в `src/config.py`.
## Ожидаемый формат данных
В директории датасета должны лежать пары:
```text
0000_google.png
0000_yandex.png
0001_google.png
0001_yandex.png
...
```
Даталоадер возвращает:
- `google_img`: вход генератора;
- `yandex_img`: целевое изображение для сравнения;
- `idx`: номер пары.
## Чекпоинты
Тренер сохраняет чекпоинты в:
```text
models/GAN/runs/checkpoints
```
Сохраняются `best.pth`, периодические `epoch_N.pth` и `final.pth`.
После обучения `src/main.py` также строит:
- графики `G/D/L1/SSIM/edge loss`;
- сетку `Google input -> Generated Yandex -> Yandex target`.
Файлы сохраняются в `models/GAN/runs/images`.

32
models/GAN/_schema.md Normal file
View File

@@ -0,0 +1,32 @@
# GAN Schema
Notebook structure definition for the Google -> Yandex GAN model.
---
## Format
```text
# === IMPORTS ===
<all imports>
# markdown
"""Description"""
# code: ./src/file.py
# # shell:
<shell commands>
```
## Directives
| Directive | Description |
|----------------|------------------------------------|
| `# code:` | Include file from `src/` |
| `# markdown` | Add markdown cell |
| `# # shell:` | Add notebook shell command cell |
---
`build.py` generates `notebook.gen.ipynb` from `_schema.py`.

121
models/GAN/_schema.py Normal file
View File

@@ -0,0 +1,121 @@
# _schema.py
# === IMPORTS ===
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from tqdm import tqdm
# markdown
"""# Configuration
Global settings for the Google -> Yandex GAN:
- Dataset path and image size
- Optimizer and training hyperparameters
- Device preference with safe CUDA compatibility checks
- GAN, L1, SSIM and edge reconstruction weights
- Output directories for checkpoints and generated samples"""
# code: ./src/config.py
# markdown
"""## Dataset
Google/Yandex paired image loader.
**Direction:**
- `google_img` is the generator input
- `yandex_img` is the target image from the same pair
**Returns:**
- Batch dict with `google_img`, `yandex_img`, `idx`"""
# code: ./src/dataloader.py
# code: ./src/test_dataloader.py
# markdown
"""## Model
Pix2pix-style GAN for Google -> Yandex map translation.
**Generator:**
- `GeneratorUNet`
- Input: Google image `(B, 3, H, W)`
- Output: generated Yandex image `(B, 3, H, W)`
**Discriminator:**
- `DiscriminatorPatchGAN`
- Input pair: `(google_img, yandex_img)`
- Learns to distinguish real pairs from `(google_img, fake_yandex)`
**Generator loss:**
- adversarial loss
- `lambda_L1 * L1(fake_yandex, yandex_img)`
- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`
- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`
The generator uses bilinear upsampling followed by convolution to avoid
checkerboard artifacts from transposed convolutions."""
# code: ./src/model.py
# markdown
"""## Training
`GANTrainer` trains discriminator and generator alternately.
**Training step:**
1. Generate `fake_yandex = G(google_img)`
2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`
3. Train generator against discriminator and paired Yandex target
**Checkpoint saving:**
- `best.pth`
- `epoch_N.pth`
- `final.pth`"""
# code: ./src/trainer.py
# markdown
"""## Analysis
Visualization helpers for generated samples and collected training metrics.
Training history plot contains:
1. Generator loss
2. Discriminator loss
3. L1 loss against the paired Yandex target
4. SSIM structure loss
5. Sobel edge loss
6. Best-checkpoint reconstruction score
The sample grid contains:
1. Google input
2. Generated Yandex
3. Original Yandex target"""
# code: ./src/analyze.py
# markdown
"""## Main Pipeline
Executes the full GAN workflow:
1. Create config
2. Build paired data loaders
3. Initialize Google -> Yandex GAN
4. Train with validation
5. Save checkpoints in `runs/checkpoints/`
6. Show loss plots and generated sample grid
This block is intentionally top-level, not wrapped in `main()`, so notebook
variables such as `model`, `trainer`, `train_loader`, `val_loader`, and
`training_analysis` remain available for debugging."""
# code: ./src/main.py
# # shell:
# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png

260
models/GAN/build.py Normal file
View File

@@ -0,0 +1,260 @@
import json
import os
import re
def parse_schema(schema_path):
with open(schema_path, 'r', encoding='utf-8') as f:
content = f.read()
imports = []
items = []
in_imports_section = False
lines = content.split('\n')
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped == '# === IMPORTS ===':
in_imports_section = True
i += 1
continue
elif stripped.startswith('# ==='):
in_imports_section = False
i += 1
continue
if in_imports_section and stripped and not stripped.startswith('#'):
imports.append(stripped)
i += 1
continue
if in_imports_section and stripped.startswith('#'):
in_imports_section = False
if stripped.startswith('# code:'):
match = re.search(r'# code:\s*(.+)', stripped)
if match:
items.append(('code', match.group(1).strip()))
i += 1
continue
if stripped.startswith('# inline:') or stripped.startswith('# # shell:'):
directive = 'inline' if stripped.startswith('# inline:') else 'shell'
i += 1
code_lines = []
while i < len(lines):
stripped = lines[i].strip()
if directive == 'shell':
if stripped.startswith('# ') or stripped.startswith('# # '):
line_content = lines[i].strip()
if line_content.startswith('# # '):
line_content = line_content[3:].lstrip()
else:
line_content = line_content[2:]
code_lines.append(line_content)
i += 1
continue
if stripped.startswith('#'):
if stripped.startswith('# code:') or stripped.startswith('# inline:') or stripped.startswith('# # shell:') or stripped == '# markdown' or stripped.startswith('# ==='):
break
i += 1
continue
if stripped == '':
i += 1
continue
code_lines.append(lines[i])
i += 1
if code_lines:
items.append((directive, '\n'.join(code_lines).rstrip()))
continue
if stripped == '# markdown':
i += 1
if i < len(lines):
next_line = lines[i].strip()
if next_line.startswith('"""'):
if next_line.endswith('"""') and len(next_line) > 3:
md_content = next_line[3:-3].strip()
i += 1
else:
end_idx = None
for j in range(i + 1, len(lines)):
if '"""' in lines[j]:
end_idx = j
break
if end_idx:
md_content = '\n'.join(lines[i:end_idx])
md_content = md_content.strip('"""').strip()
i = end_idx + 1
else:
md_content = ""
i += 1
items.append(('markdown', md_content))
continue
i += 1
return imports, items
def strip_all_imports(content):
lines = content.split('\n')
result_lines = []
skip_block = False
if_block_indent = 0
for line in lines:
stripped = line.strip()
if stripped.startswith('if __name__') or stripped.startswith('if __name__ =='):
skip_block = True
if_block_indent = len(line) - len(line.lstrip())
continue
if skip_block:
current_indent = len(line) - len(line.lstrip())
if line.strip() == '':
continue
if current_indent < if_block_indent:
skip_block = False
elif current_indent == if_block_indent and stripped.startswith('if '):
skip_block = True
if_block_indent = current_indent
continue
else:
continue
if stripped.startswith('import ') or stripped.startswith('from '):
continue
result_lines.append(line)
while result_lines and result_lines[-1].strip() == '':
result_lines.pop()
return '\n'.join(result_lines)
def read_src_file(ref, base_dir):
ref_path = ref.replace('./src/', '').replace('src/', '')
full_path = os.path.join(base_dir, 'src', ref_path.replace('./src/', '').lstrip('/'))
if not full_path.endswith('.py'):
full_path += '.py'
with open(full_path, 'r', encoding='utf-8') as f:
return f.read()
def build_notebook(schema_path, output_path=None):
schema_dir = os.path.dirname(os.path.abspath(schema_path))
if output_path is None:
output_path = os.path.join(schema_dir, 'notebook.gen.ipynb')
imports, items = parse_schema(schema_path)
cells = []
if imports:
imports_cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [imp + "\n" for imp in imports]
}
cells.append(imports_cell)
for item_type, item_content in items:
if item_type == 'markdown':
md_cell = {
"cell_type": "markdown",
"metadata": {},
"source": item_content + "\n"
}
cells.append(md_cell)
elif item_type == 'inline':
lines = item_content.split('\n')
while lines and lines[-1].strip() == '':
lines.pop()
if lines:
lines.append('')
cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [line + "\n" for line in lines]
}
cells.append(cell)
elif item_type == 'shell':
lines = item_content.split('\n')
while lines and lines[-1].strip() == '':
lines.pop()
if lines:
lines.append('')
cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [line + "\n" for line in lines]
}
cells.append(cell)
elif item_type == 'code':
try:
content = read_src_file(item_content, schema_dir)
content = strip_all_imports(content)
lines = content.split('\n')
while lines and lines[-1].strip() == '':
lines.pop()
if lines:
lines.append('')
cell = {
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": [line + "\n" for line in lines]
}
cells.append(cell)
except FileNotFoundError:
print(f"Warning: Could not find: {item_content}")
notebook = {
"cells": cells,
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(notebook, f, indent=1, ensure_ascii=False)
print(f"Notebook generated: {output_path}")
if __name__ == "__main__":
script_dir = os.path.dirname(os.path.abspath(__file__))
schema_path = os.path.join(script_dir, '_schema.py')
if os.path.exists(schema_path):
build_notebook(schema_path)
else:
print(f"Error: _schema.py not found at {schema_path}")

View File

@@ -1,30 +0,0 @@
"""Main entry point for GAN training."""
from config import create_config
from dataloader import create_data_loaders
from model import create_gan
from trainer import create_trainer
def main():
"""Run training pipeline."""
config = create_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create components
model = create_gan(use_cuda=False) # Set to True to use GPU
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
image_size=tuple(config["image_size"]),
num_workers=config["num_workers"],
)
trainer = create_trainer(model, train_loader, val_loader, config)
# Train
trainer.train(config["epochs"])
if __name__ == "__main__":
import torch
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
"""GAN package for Google-to-Yandex map image translation."""

117
models/GAN/src/analyze.py Normal file
View File

@@ -0,0 +1,117 @@
from pathlib import Path
import matplotlib.pyplot as plt
import torch
def denormalize_image(tensor):
return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)
@torch.no_grad()
def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):
device = device or model.device
model.eval()
batch = next(iter(data_loader))
google_img = batch["google_img"][:num_samples].to(device)
yandex_img = batch["yandex_img"][:num_samples].to(device)
fake_yandex = model.generator(google_img)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "generation_samples.png"
fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))
if num_samples == 1:
axes = axes.reshape(1, 3)
titles = ["Google input", "Generated Yandex", "Yandex target"]
for row in range(num_samples):
images = [google_img[row], fake_yandex[row], yandex_img[row]]
for col, image in enumerate(images):
axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))
axes[row, col].set_title(titles[col])
axes[row, col].axis("off")
fig.tight_layout()
fig.savefig(output_path, dpi=150)
if show:
plt.show()
plt.close(fig)
return output_path
def plot_training_history(trainer, output_dir, show=True):
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "training_history.png"
epochs = range(1, len(trainer.g_losses) + 1)
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.ravel()
axes[0].plot(epochs, trainer.g_losses, label="train G")
axes[0].plot(epochs, trainer.val_g_losses, label="val G")
axes[0].set_title("Generator loss")
axes[0].set_xlabel("Epoch")
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[1].plot(epochs, trainer.d_losses, label="train D")
axes[1].plot(epochs, trainer.val_d_losses, label="val D")
axes[1].set_title("Discriminator loss")
axes[1].set_xlabel("Epoch")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[2].plot(epochs, trainer.l1_losses, label="train L1")
axes[2].plot(epochs, trainer.val_l1_losses, label="val L1")
axes[2].set_title("Paired Yandex L1 loss")
axes[2].set_xlabel("Epoch")
axes[2].legend()
axes[2].grid(True, alpha=0.3)
axes[3].plot(epochs, trainer.ssim_losses, label="train SSIM")
axes[3].plot(epochs, trainer.val_ssim_losses, label="val SSIM")
axes[3].set_title("SSIM structure loss")
axes[3].set_xlabel("Epoch")
axes[3].legend()
axes[3].grid(True, alpha=0.3)
axes[4].plot(epochs, trainer.edge_losses, label="train edge")
axes[4].plot(epochs, trainer.val_edge_losses, label="val edge")
axes[4].set_title("Sobel edge loss")
axes[4].set_xlabel("Epoch")
axes[4].legend()
axes[4].grid(True, alpha=0.3)
axes[5].plot(epochs, trainer.val_reconstruction_losses, label="val reconstruction")
axes[5].set_title("Best-checkpoint score")
axes[5].set_xlabel("Epoch")
axes[5].legend()
axes[5].grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(output_path, dpi=150)
if show:
plt.show()
plt.close(fig)
return output_path
def analyze_training(trainer):
return {
"best_val_loss": trainer.best_val_loss,
"g_losses": trainer.g_losses,
"d_losses": trainer.d_losses,
"l1_losses": trainer.l1_losses,
"ssim_losses": trainer.ssim_losses,
"edge_losses": trainer.edge_losses,
"val_g_losses": trainer.val_g_losses,
"val_d_losses": trainer.val_d_losses,
"val_l1_losses": trainer.val_l1_losses,
"val_ssim_losses": trainer.val_ssim_losses,
"val_edge_losses": trainer.val_edge_losses,
"val_reconstruction_losses": trainer.val_reconstruction_losses,
}

View File

@@ -6,23 +6,30 @@ def create_config():
return {
# Optimizer params
"learning_rate": 2e-4,
"discriminator_lr_factor": 0.5,
"beta1": 0.5,
"beta2": 0.999,
# Training params
"batch_size": 32,
"epochs": 100,
"prefer_cuda": True,
# GAN params
"gan_mode": "vanilla",
"lambda_L1": 100.0,
"gan_mode": "lsgan",
"lambda_GAN": 0.5,
"lambda_L1": 150.0,
"lambda_SSIM": 25.0,
"lambda_edge": 20.0,
"discriminator_update_interval": 1,
# Regularization
"grad_clip": 1.0,
# Early stopping
"early_stopping_patience": 20,
"early_stopping_patience": 25,
# Output
"output_dir": "runs/gan_training",
"output_dir": r"C:\Users\admin\Projects\autopilot\models\GAN\runs",
# Logging
"log_interval": 10,
"save_interval": 5,
"num_visual_samples": 4,
# Data
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
"image_size": [256, 256],
@@ -33,4 +40,4 @@ def create_config():
if __name__ == "__main__":
config = create_config()
print("Default config:", config)
print("Default config:", config)

View File

@@ -1,4 +1,4 @@
"""Data loader for Yandex-to-Google image translation."""
"""Data loader for Google-to-Yandex image translation."""
import os
from typing import Dict, List, Tuple
@@ -10,7 +10,7 @@ from torchvision import transforms
class YaGoDataset(Dataset):
"""Dataset loading pairs of Yandex and Google map images."""
"""Dataset loading paired Google and Yandex map images."""
def __init__(
self,
@@ -159,4 +159,4 @@ if __name__ == "__main__":
batch = next(iter(train_loader))
print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

60
models/GAN/src/main.py Normal file
View File

@@ -0,0 +1,60 @@
"""Executable GAN training pipeline.
The code is intentionally top-level, mirroring the SiaN notebook style:
when this file is included in the generated notebook, variables remain
available for debugging and interactive experiments.
"""
from pathlib import Path
import torch
from analyze import analyze_training, plot_training_history, visualize_generation
from config import create_config
from dataloader import create_data_loaders
from model import create_gan, get_compatible_device
from trainer import create_trainer
config = create_config()
device = get_compatible_device(prefer_cuda=config["prefer_cuda"])
print(f"Using device: {device}")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
image_size=tuple(config["image_size"]),
num_workers=config["num_workers"],
)
print(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
model = create_gan(
gan_mode=config["gan_mode"],
lambda_GAN=config["lambda_GAN"],
lambda_L1=config["lambda_L1"],
lambda_SSIM=config["lambda_SSIM"],
lambda_edge=config["lambda_edge"],
use_cuda=(device.type == "cuda"),
)
generator_params = sum(p.numel() for p in model.generator.parameters())
discriminator_params = sum(p.numel() for p in model.discriminator.parameters())
print(f"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}")
trainer = create_trainer(model, train_loader, val_loader, config)
trainer.train(config["epochs"])
training_analysis = analyze_training(trainer)
images_dir = Path(config["output_dir"]) / "images"
history_plot_path = plot_training_history(trainer, images_dir)
generation_samples_path = visualize_generation(
model=model,
data_loader=val_loader,
output_dir=images_dir,
device=device,
num_samples=config["num_visual_samples"],
)
print(f"Training history plot: {history_plot_path}")
print(f"Generation samples: {generation_samples_path}")

View File

@@ -1,10 +1,37 @@
"""GAN model for image translation Yandex -> Google."""
"""GAN model for image translation Google -> Yandex."""
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:
"""Return CUDA only when the current PyTorch build supports the GPU arch."""
if not prefer_cuda or not torch.cuda.is_available():
return torch.device("cpu")
try:
major, minor = torch.cuda.get_device_capability()
arch = f"sm_{major}{minor}"
supported_arches = set(torch.cuda.get_arch_list())
gpu_name = torch.cuda.get_device_name()
except Exception as exc:
if verbose:
print(f"CUDA is visible but cannot be inspected ({exc}); using CPU.")
return torch.device("cpu")
if supported_arches and arch not in supported_arches:
if verbose:
supported = ", ".join(sorted(supported_arches))
print(
f"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build "
f"supports only: {supported}. Using CPU."
)
return torch.device("cpu")
return torch.device("cuda")
class UNetDownBlock(nn.Module):
"""Downsampling block for U-Net."""
@@ -29,7 +56,10 @@ class UNetUpBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
super().__init__()
self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
self.upconv = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
)
self.norm = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if dropout > 0:
@@ -38,7 +68,6 @@ class UNetUpBlock(nn.Module):
self.dropout = None
def forward(self, x, skip_input):
x = self.upconv(x)
# Pad if needed to match skip connection size
if x.shape != skip_input.shape:
@@ -54,7 +83,7 @@ class UNetUpBlock(nn.Module):
class GeneratorUNet(nn.Module):
"""U-Net generator for Yandex -> Google translation."""
"""U-Net generator for Google -> Yandex translation."""
def __init__(self, in_channels: int = 3, out_channels: int = 3):
super().__init__()
@@ -85,7 +114,8 @@ class GeneratorUNet(nn.Module):
# Final
self.final = nn.Sequential(
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
nn.Conv2d(128, out_channels, kernel_size=3, padding=1),
nn.Tanh(),
)
@@ -115,7 +145,7 @@ class GeneratorUNet(nn.Module):
class DiscriminatorPatchGAN(nn.Module):
"""PatchGAN discriminator."""
"""PatchGAN discriminator for paired source/target images."""
def __init__(self, in_channels: int = 6):
super().__init__()
@@ -132,7 +162,6 @@ class DiscriminatorPatchGAN(nn.Module):
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
nn.Sigmoid(),
)
def forward(self, img_A, img_B):
@@ -167,15 +196,82 @@ class GANLoss(nn.Module):
return -prediction.mean() if target_is_real else prediction.mean()
class SSIMLoss(nn.Module):
"""Local SSIM loss for normalized image tensors in [-1, 1]."""
def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):
super().__init__()
self.window_size = window_size
self.c1 = c1
self.c2 = c2
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
pred = (pred + 1.0) * 0.5
target = (target + 1.0) * 0.5
padding = self.window_size // 2
mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding)
mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding)
mu_pred_sq = mu_pred.pow(2)
mu_target_sq = mu_target.pow(2)
mu_pred_target = mu_pred * mu_target
sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq
sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq
sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target
ssim_map = (
(2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)
) / (
(mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)
)
return (1.0 - ssim_map.clamp(0, 1)).mean()
class SobelEdgeLoss(nn.Module):
"""L1 loss between Sobel edge maps, useful for stable keypoint structure."""
def __init__(self):
super().__init__()
kernel_x = torch.tensor(
[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
dtype=torch.float32,
).view(1, 1, 3, 3)
kernel_y = torch.tensor(
[[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
dtype=torch.float32,
).view(1, 1, 3, 3)
self.register_buffer("kernel_x", kernel_x)
self.register_buffer("kernel_y", kernel_y)
@staticmethod
def _to_gray(x: torch.Tensor) -> torch.Tensor:
x = (x + 1.0) * 0.5
weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
return (x * weights).sum(dim=1, keepdim=True)
def _edges(self, x: torch.Tensor) -> torch.Tensor:
gray = self._to_gray(x)
grad_x = F.conv2d(gray, self.kernel_x, padding=1)
grad_y = F.conv2d(gray, self.kernel_y, padding=1)
return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.l1_loss(self._edges(pred), self._edges(target))
class ImageGAN(nn.Module):
"""Complete GAN model for image translation."""
"""Complete pix2pix-style GAN for Google -> Yandex image translation."""
def __init__(
self,
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
gan_mode: str = "lsgan",
lambda_L1: float = 150.0,
lambda_GAN: float = 0.5,
lambda_SSIM: float = 25.0,
lambda_edge: float = 20.0,
use_cuda: bool = True,
):
super().__init__()
@@ -183,29 +279,36 @@ class ImageGAN(nn.Module):
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
self.gan_loss = GANLoss(gan_mode)
self.l1_loss = nn.L1Loss()
self.ssim_loss = SSIMLoss()
self.edge_loss = SobelEdgeLoss()
self.lambda_L1 = lambda_L1
self.lambda_GAN = lambda_GAN
self.lambda_SSIM = lambda_SSIM
self.lambda_edge = lambda_edge
self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
self.device = get_compatible_device(prefer_cuda=use_cuda)
self.to(self.device)
def forward(self, yandex_image):
"""Generate Google image from Yandex."""
return self.generator(yandex_image)
def forward(self, google_image):
"""Generate a Yandex-style image from a Google image."""
return self.generator(google_image)
def generator_step(self, yandex_img, real_google_img):
"""Compute generator losses."""
fake_google = self.generator(yandex_img)
fake_pred = self.discriminator(yandex_img, fake_google)
gan_loss = self.gan_loss(fake_pred, True)
l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1
total_loss = gan_loss + l1_loss
return total_loss, gan_loss, l1_loss
def generator_step(self, google_img, real_yandex_img):
"""Compute generator losses against the paired original Yandex image."""
fake_yandex = self.generator(google_img)
fake_pred = self.discriminator(google_img, fake_yandex)
gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN
l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1
ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM
edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge
total_loss = gan_loss + l1_loss + ssim_loss + edge_loss
return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss
def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
"""Compute discriminator losses."""
real_pred = self.discriminator(yandex_img, real_google_img)
def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):
"""Compute discriminator losses for real and generated Yandex targets."""
real_pred = self.discriminator(google_img, real_yandex_img)
real_loss = self.gan_loss(real_pred, True)
fake_pred = self.discriminator(yandex_img, fake_google_img.detach())
fake_pred = self.discriminator(google_img, fake_yandex_img.detach())
fake_loss = self.gan_loss(fake_pred, False)
total_loss = (real_loss + fake_loss) * 0.5
return total_loss, real_loss, fake_loss
@@ -214,8 +317,11 @@ class ImageGAN(nn.Module):
def create_gan(
input_channels: int = 3,
output_channels: int = 3,
gan_mode: str = "vanilla",
lambda_L1: float = 100.0,
gan_mode: str = "lsgan",
lambda_L1: float = 150.0,
lambda_GAN: float = 0.5,
lambda_SSIM: float = 25.0,
lambda_edge: float = 20.0,
use_cuda: bool = True,
) -> ImageGAN:
"""Create a GAN model."""
@@ -224,6 +330,9 @@ def create_gan(
output_channels=output_channels,
gan_mode=gan_mode,
lambda_L1=lambda_L1,
lambda_GAN=lambda_GAN,
lambda_SSIM=lambda_SSIM,
lambda_edge=lambda_edge,
use_cuda=use_cuda,
)
@@ -252,4 +361,4 @@ if __name__ == "__main__":
# Count parameters
gen_params = sum(p.numel() for p in model.generator.parameters())
disc_params = sum(p.numel() for p in model.discriminator.parameters())
print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")
print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")

View File

@@ -0,0 +1,40 @@
from config import create_config
from dataloader import YaGoDataset, create_data_loaders
def get_dataset_info():
config = create_config()
dataset = YaGoDataset(
root_dir=config["data_dir"],
image_size=tuple(config["image_size"]),
)
sample = dataset[0] if len(dataset) else {}
return {
"size": len(dataset),
"sample_keys": list(sample.keys()),
"google_shape": tuple(sample["google_img"].shape) if sample else None,
"yandex_shape": tuple(sample["yandex_img"].shape) if sample else None,
}
def smoke_test_dataloader(batch_size=4):
config = create_config()
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=batch_size,
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=tuple(config["image_size"]),
)
batch = next(iter(train_loader))
return {
"train_size": len(train_loader.dataset),
"val_size": len(val_loader.dataset),
"google_batch_shape": tuple(batch["google_img"].shape),
"yandex_batch_shape": tuple(batch["yandex_img"].shape),
}
if __name__ == "__main__":
print(get_dataset_info())
print(smoke_test_dataloader())

View File

@@ -28,18 +28,26 @@ class GANTrainer:
# Optimizers
lr = config.get("learning_rate", 2e-4)
lr_d = config.get("discriminator_learning_rate", lr * config.get("discriminator_lr_factor", 0.5))
beta1 = config.get("beta1", 0.5)
beta2 = config.get("beta2", 0.999)
self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))
self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2))
self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
# Training state
self.current_epoch = 0
self.best_val_loss = float("inf")
self.g_losses = []
self.d_losses = []
self.l1_losses = []
self.ssim_losses = []
self.edge_losses = []
self.val_g_losses = []
self.val_d_losses = []
self.val_l1_losses = []
self.val_ssim_losses = []
self.val_edge_losses = []
self.val_reconstruction_losses = []
# Output dir
self.output_dir = Path(config.get("output_dir", "runs/gan"))
@@ -54,35 +62,55 @@ class GANTrainer:
"""Train for one epoch."""
self.model.train()
total_g = total_d = 0.0
total_l1 = total_ssim = total_edge = 0.0
num_batches = len(self.train_loader)
d_update_interval = max(1, self.config.get("discriminator_update_interval", 1))
pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
for batch in pbar:
yandex_img = batch["yandex_img"].to(self.device)
pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f"Epoch {self.current_epoch + 1}")
for batch_idx, batch in pbar:
google_img = batch["google_img"].to(self.device)
yandex_img = batch["yandex_img"].to(self.device)
# Train D
self.opt_D.zero_grad()
with torch.no_grad():
fake_img = self.model.generator(yandex_img)
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
d_loss.backward()
self.opt_D.step()
if batch_idx % d_update_interval == 0:
self.opt_D.zero_grad()
with torch.no_grad():
fake_img = self.model.generator(google_img)
d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
d_loss.backward()
self.opt_D.step()
else:
d_loss = google_img.new_tensor(0.0)
# Train G
self.opt_G.zero_grad()
g_loss = self.model.generator_step(yandex_img, google_img)[0]
g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
g_loss.backward()
self.opt_G.step()
total_g += g_loss.item()
total_d += d_loss.item()
pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
total_l1 += l1_loss.item()
total_ssim += ssim_loss.item()
total_edge += edge_loss.item()
pbar.set_postfix({
"g_loss": g_loss.item(),
"d_loss": d_loss.item(),
"l1": l1_loss.item(),
"ssim": ssim_loss.item(),
"edge": edge_loss.item(),
})
avg_g = total_g / num_batches
avg_d = total_d / num_batches
avg_l1 = total_l1 / num_batches
avg_ssim = total_ssim / num_batches
avg_edge = total_edge / num_batches
self.g_losses.append(avg_g)
self.d_losses.append(avg_d)
self.l1_losses.append(avg_l1)
self.ssim_losses.append(avg_ssim)
self.edge_losses.append(avg_edge)
return avg_g, avg_d
@torch.no_grad()
@@ -90,20 +118,32 @@ class GANTrainer:
"""Validate the model."""
self.model.eval()
total_g = total_d = 0.0
total_l1 = total_ssim = total_edge = 0.0
for batch in tqdm(self.val_loader, desc="Val"):
yandex_img = batch["yandex_img"].to(self.device)
google_img = batch["google_img"].to(self.device)
fake_img = self.model.generator(yandex_img)
g_loss = self.model.generator_step(yandex_img, google_img)[0]
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
yandex_img = batch["yandex_img"].to(self.device)
fake_img = self.model.generator(google_img)
g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
total_g += g_loss.item()
total_d += d_loss.item()
total_l1 += l1_loss.item()
total_ssim += ssim_loss.item()
total_edge += edge_loss.item()
avg_g = total_g / len(self.val_loader)
avg_d = total_d / len(self.val_loader)
avg_l1 = total_l1 / len(self.val_loader)
avg_ssim = total_ssim / len(self.val_loader)
avg_edge = total_edge / len(self.val_loader)
avg_reconstruction = avg_l1 + avg_ssim + avg_edge
self.val_g_losses.append(avg_g)
self.val_d_losses.append(avg_d)
self.val_l1_losses.append(avg_l1)
self.val_ssim_losses.append(avg_ssim)
self.val_edge_losses.append(avg_edge)
self.val_reconstruction_losses.append(avg_reconstruction)
return avg_g, avg_d
def train(self, num_epochs: int):
@@ -117,23 +157,30 @@ class GANTrainer:
train_g, train_d = self.train_epoch()
val_g, val_d = self.validate()
# Save best checkpoint
val_total = val_g + val_d
if val_total < self.best_val_loss:
self.best_val_loss = val_total
val_reconstruction = self.val_reconstruction_losses[-1]
if val_reconstruction < self.best_val_loss:
self.best_val_loss = val_reconstruction
self.save_checkpoint("best")
# Periodic checkpoint
if (epoch + 1) % self.config.get("save_interval", 5) == 0:
self.save_checkpoint(f"epoch_{epoch + 1}")
print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")
print(
f"Epoch {epoch + 1}: "
f"train_g={train_g:.4f}, train_d={train_d:.4f}, "
f"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, "
f"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, "
f"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, "
f"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}"
)
# Early stopping
patience = self.config.get("early_stopping_patience", 0)
if patience > 0 and len(self.val_g_losses) > patience:
recent = self.val_g_losses[-patience:]
if all(l >= min(self.val_g_losses[:-patience]) for l in recent):
if patience > 0 and len(self.val_reconstruction_losses) > patience:
recent = self.val_reconstruction_losses[-patience:]
previous_best = min(self.val_reconstruction_losses[:-patience])
if all(loss >= previous_best for loss in recent):
print(f"Early stopping at epoch {epoch + 1}")
break
@@ -188,4 +235,4 @@ if __name__ == "__main__":
g_loss, d_loss = trainer.train_epoch()
print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")
except Exception as e:
print(f"Training step failed: {e}")
print(f"Training step failed: {e}")

Binary file not shown.

Binary file not shown.

Binary file not shown.

145
sian_similarity.py Normal file
View File

@@ -0,0 +1,145 @@
from __future__ import annotations
import importlib.util
import os
from pathlib import Path
from typing import Optional, Sequence
import numpy as np
import torch
from PIL import Image
from vision_chunk import VisionChunk
ROOT_DIR = Path(__file__).resolve().parent
MODEL_FILE = ROOT_DIR / "models" / "SiaN-similarity" / "model.py"
DEFAULT_CHECKPOINT_PATH = (
ROOT_DIR
/ "models"
/ "SiaN-similarity"
/ "runs"
/ "gan_training"
/ "checkpoints"
/ "best_model.pt"
)
IMAGE_SIZE = (256, 256)
DEFAULT_THRESHOLD = 0.5
CHECKPOINT_ENV = "SIAN_SIMILARITY_CHECKPOINT"
THRESHOLD_ENV = "SIAN_SIMILARITY_THRESHOLD"
_model: Optional[torch.nn.Module] = None
_device: Optional[torch.device] = None
def _load_similarity_class():
spec = importlib.util.spec_from_file_location("sian_similarity_model", MODEL_FILE)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load similarity model from {MODEL_FILE}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.SimilarityCNN
def _get_checkpoint_path() -> Path:
checkpoint_path = os.getenv(CHECKPOINT_ENV)
if checkpoint_path:
return Path(checkpoint_path).expanduser().resolve()
return DEFAULT_CHECKPOINT_PATH
def get_threshold() -> float:
threshold = os.getenv(THRESHOLD_ENV)
if threshold is None:
return DEFAULT_THRESHOLD
return float(threshold)
def _get_device() -> torch.device:
global _device
if _device is None:
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
return _device
def _get_model() -> torch.nn.Module:
global _model
if _model is not None:
return _model
checkpoint_path = _get_checkpoint_path()
if not checkpoint_path.exists():
raise FileNotFoundError(
f"SiaN similarity checkpoint not found: {checkpoint_path}. "
f"Set {CHECKPOINT_ENV} to another .pt file if needed."
)
SimilarityCNN = _load_similarity_class()
device = _get_device()
model = SimilarityCNN(
input_channels=3,
backbone_name="resnet18",
pretrained=False,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint.get("model_state_dict", checkpoint)
model.load_state_dict(state_dict)
model.eval()
_model = model
return _model
def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor:
image = chunk.image.convert("RGB").resize(IMAGE_SIZE, Image.BILINEAR)
array = np.asarray(image, dtype=np.float32) / 255.0
tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)
return tensor.to(_get_device())
def get_similarity_score(chunk1: VisionChunk, chunk2: VisionChunk) -> float:
if chunk1 is None or chunk2 is None:
return 0.0
model = _get_model()
img1 = _chunk_to_tensor(chunk1)
img2 = _chunk_to_tensor(chunk2)
with torch.inference_mode():
similarity = model(img1, img2)
return float(similarity.squeeze().item())
def get_similarity_scores(chunk: VisionChunk, candidates: Sequence[VisionChunk]) -> list[float]:
if chunk is None or not candidates:
return []
model = _get_model()
img = _chunk_to_tensor(chunk)
candidate_images = torch.cat([_chunk_to_tensor(candidate) for candidate in candidates], dim=0)
repeated_img = img.expand(candidate_images.shape[0], -1, -1, -1)
with torch.inference_mode():
similarities = model(repeated_img, candidate_images)
return [float(score) for score in similarities.squeeze(1).detach().cpu().tolist()]
def is_similar(
chunk1: VisionChunk,
chunk2: VisionChunk,
threshold: Optional[float] = None,
) -> bool:
if threshold is None:
threshold = get_threshold()
return get_similarity_score(chunk1, chunk2) >= threshold