feat: add GAN
This commit is contained in:
21
.vscode/c_cpp_properties.json
vendored
Normal file
21
.vscode/c_cpp_properties.json
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Win32",
|
||||
"includePath": [
|
||||
"${workspaceFolder}/**"
|
||||
],
|
||||
"defines": [
|
||||
"_DEBUG",
|
||||
"UNICODE",
|
||||
"_UNICODE"
|
||||
],
|
||||
"windowsSdkVersion": "10.0.26100.0",
|
||||
"compilerPath": "cl.exe",
|
||||
"cStandard": "c17",
|
||||
"cppStandard": "c++17",
|
||||
"intelliSenseMode": "windows-msvc-x64"
|
||||
}
|
||||
],
|
||||
"version": 4
|
||||
}
|
||||
48
autopilot.py
48
autopilot.py
@@ -8,6 +8,7 @@ import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import sian_similarity
|
||||
from timer import Timer
|
||||
|
||||
from vision_chunk import VisionChunk
|
||||
@@ -58,8 +59,16 @@ class AutoPilot(Pilot):
|
||||
# Положение на основе ориентира
|
||||
reserved_pos: Position | None
|
||||
proccessing_time: float
|
||||
use_sian_similarity: bool
|
||||
|
||||
def __init__(self, points = [], chunks = [], viz_manager=None, pixel_ratio: float = 1.):
|
||||
def __init__(
|
||||
self,
|
||||
points = [],
|
||||
chunks = [],
|
||||
viz_manager=None,
|
||||
pixel_ratio: float = 1.,
|
||||
use_sian_similarity: bool = False,
|
||||
):
|
||||
self.prev_chunk = None
|
||||
self.pos = Position(0, 0, 1, 0, 0, 0)
|
||||
self.chunks = chunks
|
||||
@@ -67,6 +76,7 @@ class AutoPilot(Pilot):
|
||||
self.vis_manager = viz_manager # Менеджер визуализации
|
||||
self.reserved_pos = None
|
||||
self.pixel_ratio = pixel_ratio
|
||||
self.use_sian_similarity = use_sian_similarity
|
||||
|
||||
# Пороговые значения качества сопоставления/гомографии
|
||||
self.min_inliers: int = 12
|
||||
@@ -153,17 +163,33 @@ class AutoPilot(Pilot):
|
||||
landmark_timer = Timer()
|
||||
landmark_timer.start()
|
||||
|
||||
cur_pos = np.array([self.pos.x, self.pos.y])
|
||||
closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
|
||||
|
||||
current_chunk = self.prev_chunk
|
||||
landmark_chunk = self.chunks[closest_chunk_idx]
|
||||
if current_chunk is None or not self.chunks:
|
||||
return None
|
||||
|
||||
if self.use_sian_similarity:
|
||||
similarity_scores = sian_similarity.get_similarity_scores(current_chunk, self.chunks)
|
||||
best_chunk_idx = int(np.argmax(similarity_scores))
|
||||
best_similarity_score = similarity_scores[best_chunk_idx]
|
||||
|
||||
print(f"[LANDMARK]: best similarity={best_similarity_score:.4f}, idx={best_chunk_idx}")
|
||||
|
||||
if best_similarity_score < sian_similarity.get_threshold():
|
||||
print("[LANDMARK]: not similar")
|
||||
return None
|
||||
|
||||
print("[LANDMARK]: similar")
|
||||
landmark_chunk = self.chunks[best_chunk_idx]
|
||||
else:
|
||||
cur_pos = np.array([self.pos.x, self.pos.y])
|
||||
closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
|
||||
landmark_chunk = self.chunks[closest_chunk_idx]
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[LANDMARK]: Closest chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
|
||||
print(f"[LANDMARK]: Landmark chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
|
||||
|
||||
# Краевой случай: отсутствие чанков
|
||||
if current_chunk is None or landmark_chunk is None:
|
||||
if landmark_chunk is None:
|
||||
return None
|
||||
|
||||
landmark_timer.start()
|
||||
@@ -312,10 +338,10 @@ class AutoPilot(Pilot):
|
||||
# Пытаемся найти ориентир на картинке:
|
||||
self.prev_chunk = current_chunk
|
||||
# Для улучшения среднего FPS
|
||||
# if self.frame_count % 5 == 0:
|
||||
# pos_by_chunk = self.get_position_by_chunk()
|
||||
# if pos_by_chunk is not None:
|
||||
# self.pos = pos_by_chunk
|
||||
if self.frame_count % 5 == 0:
|
||||
pos_by_chunk = self.get_position_by_chunk()
|
||||
if pos_by_chunk is not None:
|
||||
self.pos = pos_by_chunk
|
||||
|
||||
command = self.make_command()
|
||||
self.timer.reset()
|
||||
|
||||
118
dissertation/_media/gan_patchgan_discriminator.svg
Normal file
118
dissertation/_media/gan_patchgan_discriminator.svg
Normal file
@@ -0,0 +1,118 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1800" height="1000" viewBox="0 0 1800 1000">
|
||||
<defs>
|
||||
<marker id="arrow" markerWidth="16" markerHeight="16" refX="14" refY="8" orient="auto">
|
||||
<path d="M0,0 L16,8 L0,16 Z" fill="#1f2937"/>
|
||||
</marker>
|
||||
<style>
|
||||
.bg { fill: #fafafa; }
|
||||
.title { font: 700 40px Arial, sans-serif; fill: #111827; }
|
||||
.subtitle { font: 400 21px Arial, sans-serif; fill: #4b5563; }
|
||||
.box { stroke-width: 3; rx: 18; }
|
||||
.blue { fill: #eaf3ff; stroke: #3b82f6; }
|
||||
.orange { fill: #fff1df; stroke: #f59e0b; }
|
||||
.violet { fill: #f0ebff; stroke: #7c3aed; }
|
||||
.green { fill: #eaf8ef; stroke: #22a06b; }
|
||||
.red { fill: #fff1f2; stroke: #ef4444; }
|
||||
.label { font: 700 27px Arial, sans-serif; fill: #111827; text-anchor: middle; }
|
||||
.desc { font: 400 18px Arial, sans-serif; fill: #4b5563; text-anchor: middle; }
|
||||
.small { font: 400 16px Arial, sans-serif; fill: #6b7280; text-anchor: middle; }
|
||||
.arrow { stroke: #1f2937; stroke-width: 4; fill: none; marker-end: url(#arrow); }
|
||||
.branch { stroke: #6b7280; stroke-width: 3.5; fill: none; marker-end: url(#arrow); }
|
||||
.icon-line { stroke: #111827; stroke-width: 5; fill: none; stroke-linecap: round; stroke-linejoin: round; }
|
||||
.patch-grid { stroke: #7c3aed; stroke-width: 3; fill: none; }
|
||||
.p1 { fill: #bbf7d0; }
|
||||
.p2 { fill: #fde68a; }
|
||||
.p3 { fill: #fecaca; }
|
||||
.stage { font: 700 20px Arial, sans-serif; fill: #374151; text-anchor: middle; }
|
||||
</style>
|
||||
</defs>
|
||||
|
||||
<rect class="bg" width="1800" height="1000"/>
|
||||
<text class="title" x="90" y="80">Дискриминатор PatchGAN</text>
|
||||
<text class="subtitle" x="90" y="118">Проверяет пару изображений локальными патчами: настоящая ли это пара Google + Яндекс или результат генератора</text>
|
||||
|
||||
<text class="stage" x="245" y="195">Входные пары</text>
|
||||
<text class="stage" x="690" y="195">Объединение</text>
|
||||
<text class="stage" x="1125" y="195">PatchGAN</text>
|
||||
<text class="stage" x="1530" y="195">Оценка</text>
|
||||
|
||||
<g transform="translate(90 240)">
|
||||
<rect class="box green" width="310" height="205"/>
|
||||
<text class="label" x="155" y="54">Real pair</text>
|
||||
<text class="desc" x="155" y="85">Google + настоящий Яндекс</text>
|
||||
<text class="small" x="155" y="116">целевая метка: 1</text>
|
||||
<path class="icon-line" d="M78 158 C112 130, 150 176, 190 145 C211 128, 230 124, 250 122"/>
|
||||
</g>
|
||||
|
||||
<g transform="translate(90 555)">
|
||||
<rect class="box red" width="310" height="205"/>
|
||||
<text class="label" x="155" y="54">Fake pair</text>
|
||||
<text class="desc" x="155" y="85">Google + Generated Яндекс</text>
|
||||
<text class="small" x="155" y="116">целевая метка: 0</text>
|
||||
<path class="icon-line" d="M78 158 C112 130, 150 176, 190 145 C211 128, 230 124, 250 122"/>
|
||||
</g>
|
||||
|
||||
<path class="branch" d="M420 342 H485 V470 H545"/>
|
||||
<path class="branch" d="M420 657 H485 V520 H545"/>
|
||||
|
||||
<g transform="translate(565 390)">
|
||||
<rect class="box blue" width="255" height="210"/>
|
||||
<text class="label" x="128" y="58">Concat</text>
|
||||
<text class="desc" x="128" y="92">6 каналов</text>
|
||||
<text class="small" x="128" y="124">RGB Google + RGB Yandex</text>
|
||||
<text class="small" x="128" y="154">B x 6 x 256 x 256</text>
|
||||
</g>
|
||||
|
||||
<path class="arrow" d="M840 495 H930"/>
|
||||
|
||||
<g transform="translate(950 295)">
|
||||
<rect class="box violet" width="350" height="400"/>
|
||||
<text class="label" x="175" y="58">Сверточные блоки</text>
|
||||
<text class="desc" x="175" y="92">Conv + BatchNorm + LeakyReLU</text>
|
||||
<path class="icon-line" d="M78 170 H272"/>
|
||||
<path class="icon-line" d="M78 230 H272"/>
|
||||
<path class="icon-line" d="M78 290 H272"/>
|
||||
<circle cx="95" cy="170" r="12" fill="#bfdbfe" stroke="#111827" stroke-width="4"/>
|
||||
<circle cx="155" cy="170" r="12" fill="#bfdbfe" stroke="#111827" stroke-width="4"/>
|
||||
<circle cx="215" cy="170" r="12" fill="#bfdbfe" stroke="#111827" stroke-width="4"/>
|
||||
<circle cx="120" cy="230" r="12" fill="#c4b5fd" stroke="#111827" stroke-width="4"/>
|
||||
<circle cx="180" cy="230" r="12" fill="#c4b5fd" stroke="#111827" stroke-width="4"/>
|
||||
<circle cx="240" cy="230" r="12" fill="#c4b5fd" stroke="#111827" stroke-width="4"/>
|
||||
<circle cx="145" cy="290" r="12" fill="#fed7aa" stroke="#111827" stroke-width="4"/>
|
||||
<circle cx="205" cy="290" r="12" fill="#fed7aa" stroke="#111827" stroke-width="4"/>
|
||||
<text class="small" x="175" y="352">64 -> 128 -> 256 -> 512</text>
|
||||
</g>
|
||||
|
||||
<path class="arrow" d="M1320 495 H1410"/>
|
||||
|
||||
<g transform="translate(1430 295)">
|
||||
<rect class="box orange" width="280" height="400"/>
|
||||
<text class="label" x="140" y="58">Patch map</text>
|
||||
<text class="desc" x="140" y="92">оценка real/fake</text>
|
||||
<g transform="translate(78 135)">
|
||||
<rect class="patch-grid" width="124" height="124" rx="8"/>
|
||||
<rect class="p1" x="10" y="10" width="28" height="28"/>
|
||||
<rect class="p2" x="48" y="10" width="28" height="28"/>
|
||||
<rect class="p1" x="86" y="10" width="28" height="28"/>
|
||||
<rect class="p2" x="10" y="48" width="28" height="28"/>
|
||||
<rect class="p1" x="48" y="48" width="28" height="28"/>
|
||||
<rect class="p3" x="86" y="48" width="28" height="28"/>
|
||||
<rect class="p1" x="10" y="86" width="28" height="28"/>
|
||||
<rect class="p2" x="48" y="86" width="28" height="28"/>
|
||||
<rect class="p1" x="86" y="86" width="28" height="28"/>
|
||||
</g>
|
||||
<text class="small" x="140" y="310">каждая ячейка</text>
|
||||
<text class="small" x="140" y="336">соответствует локальной</text>
|
||||
<text class="small" x="140" y="362">области изображения</text>
|
||||
</g>
|
||||
|
||||
<path class="arrow" d="M1570 720 V755 H1095 V785"/>
|
||||
|
||||
<g transform="translate(855 790)">
|
||||
<rect class="box green" width="480" height="130"/>
|
||||
<text class="label" x="240" y="50">Функция потерь</text>
|
||||
<text class="desc" x="240" y="82">реальная пара -> 1, сгенерированная пара -> 0</text>
|
||||
</g>
|
||||
|
||||
<text class="small" x="900" y="960">Главная идея PatchGAN: проверять не всю карту одним числом, а локальные признаки стиля и структуры.</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 6.3 KiB |
119
dissertation/_media/gan_pipeline.svg
Normal file
119
dissertation/_media/gan_pipeline.svg
Normal file
@@ -0,0 +1,119 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1800" height="1000" viewBox="0 0 1800 1000">
|
||||
<defs>
|
||||
<marker id="arrow" markerWidth="16" markerHeight="16" refX="14" refY="8" orient="auto">
|
||||
<path d="M0,0 L16,8 L0,16 Z" fill="#1f2937"/>
|
||||
</marker>
|
||||
<style>
|
||||
.bg { fill: #fafafa; }
|
||||
.title { font: 700 40px Arial, sans-serif; fill: #111827; }
|
||||
.subtitle { font: 400 21px Arial, sans-serif; fill: #4b5563; }
|
||||
.stage { font: 700 20px Arial, sans-serif; fill: #374151; text-anchor: middle; }
|
||||
.box { stroke-width: 3; rx: 18; }
|
||||
.blue { fill: #eaf3ff; stroke: #3b82f6; }
|
||||
.orange { fill: #fff1df; stroke: #f59e0b; }
|
||||
.green { fill: #eaf8ef; stroke: #22a06b; }
|
||||
.violet { fill: #f0ebff; stroke: #7c3aed; }
|
||||
.main { font: 700 28px Arial, sans-serif; fill: #111827; text-anchor: middle; }
|
||||
.desc { font: 400 18px Arial, sans-serif; fill: #4b5563; text-anchor: middle; }
|
||||
.small { font: 400 16px Arial, sans-serif; fill: #6b7280; text-anchor: middle; }
|
||||
.arrow { stroke: #1f2937; stroke-width: 4; fill: none; marker-end: url(#arrow); }
|
||||
.thin { stroke: #6b7280; stroke-width: 3; fill: none; marker-end: url(#arrow); }
|
||||
.dash { stroke: #6b7280; stroke-width: 3; fill: none; stroke-dasharray: 10 9; marker-end: url(#arrow); }
|
||||
.icon-line { stroke: #111827; stroke-width: 5; fill: none; stroke-linecap: round; stroke-linejoin: round; }
|
||||
.icon-fill { fill: #ffffff; stroke: #111827; stroke-width: 4; }
|
||||
.kp { fill: #16a34a; stroke: #ffffff; stroke-width: 3; }
|
||||
.bad { fill: #ef4444; stroke: #ffffff; stroke-width: 3; }
|
||||
</style>
|
||||
</defs>
|
||||
|
||||
<rect class="bg" width="1800" height="1000"/>
|
||||
<text class="title" x="90" y="80">Применение GAN для приведения карт к единому домену</text>
|
||||
<text class="subtitle" x="90" y="118">Нейросеть меняет стиль карты, а дальнейшее сопоставление выполняется классическими геометрическими методами</text>
|
||||
|
||||
<text class="stage" x="255" y="190">Исходные данные</text>
|
||||
<text class="stage" x="720" y="190">Доменная адаптация</text>
|
||||
<text class="stage" x="1185" y="190">Сопоставление</text>
|
||||
<text class="stage" x="1545" y="190">Результат</text>
|
||||
|
||||
<g transform="translate(90 235)">
|
||||
<rect class="box blue" width="330" height="210"/>
|
||||
<text class="main" x="165" y="54">Google Maps</text>
|
||||
<text class="desc" x="165" y="86">фрагмент области полета</text>
|
||||
<path class="icon-line" d="M85 145 C122 118, 164 169, 210 132 C230 116, 248 111, 272 110"/>
|
||||
<path class="icon-line" d="M82 170 L280 118"/>
|
||||
<circle class="kp" cx="116" cy="142" r="9"/>
|
||||
<circle class="kp" cx="208" cy="132" r="9"/>
|
||||
</g>
|
||||
|
||||
<g transform="translate(90 555)">
|
||||
<rect class="box orange" width="330" height="210"/>
|
||||
<text class="main" x="165" y="54">Яндекс.Карты</text>
|
||||
<text class="desc" x="165" y="86">эталон с ориентирами</text>
|
||||
<path class="icon-line" d="M85 145 C122 118, 164 169, 210 132 C230 116, 248 111, 272 110"/>
|
||||
<path class="icon-line" d="M82 170 L280 118"/>
|
||||
<circle class="kp" cx="116" cy="142" r="9"/>
|
||||
<circle class="kp" cx="208" cy="132" r="9"/>
|
||||
</g>
|
||||
|
||||
<path class="arrow" d="M440 340 H555"/>
|
||||
|
||||
<g transform="translate(575 300)">
|
||||
<rect class="box orange" width="290" height="260"/>
|
||||
<text class="main" x="145" y="58">GAN</text>
|
||||
<text class="desc" x="145" y="91">Google -> Yandex</text>
|
||||
<path class="icon-line" d="M80 170 L115 128 L150 170 L185 128 L220 170"/>
|
||||
<text class="small" x="145" y="218">сохраняет геометрию</text>
|
||||
<text class="small" x="145" y="242">меняет визуальный стиль</text>
|
||||
</g>
|
||||
|
||||
<path class="arrow" d="M885 430 H1000"/>
|
||||
|
||||
<g transform="translate(1020 300)">
|
||||
<rect class="box orange" width="330" height="260"/>
|
||||
<text class="main" x="165" y="58">Сгенерированный кадр</text>
|
||||
<text class="desc" x="165" y="91">карта в целевом домене</text>
|
||||
<path class="icon-line" d="M85 166 C122 139, 164 190, 210 153 C230 137, 248 132, 272 131"/>
|
||||
<path class="icon-line" d="M82 191 L280 139"/>
|
||||
<circle class="kp" cx="116" cy="163" r="9"/>
|
||||
<circle class="kp" cx="208" cy="153" r="9"/>
|
||||
<circle class="kp" cx="250" cy="139" r="9"/>
|
||||
</g>
|
||||
|
||||
<path class="dash" d="M420 660 C635 775, 940 775, 1040 580"/>
|
||||
<path class="thin" d="M1185 580 V650"/>
|
||||
|
||||
<g transform="translate(1020 650)">
|
||||
<rect class="box green" width="330" height="190"/>
|
||||
<text class="main" x="165" y="52">Ключевые точки</text>
|
||||
<text class="desc" x="165" y="84">ORB / SIFT / AKAZE</text>
|
||||
<circle class="kp" cx="92" cy="132" r="10"/>
|
||||
<circle class="kp" cx="142" cy="116" r="10"/>
|
||||
<circle class="kp" cx="193" cy="139" r="10"/>
|
||||
<circle class="kp" cx="244" cy="111" r="10"/>
|
||||
<path class="icon-line" d="M90 132 L142 116 L193 139 L244 111"/>
|
||||
</g>
|
||||
|
||||
<path class="arrow" d="M1370 745 H1445"/>
|
||||
|
||||
<g transform="translate(1465 650)">
|
||||
<rect class="box green" width="245" height="190"/>
|
||||
<text class="main" x="122" y="52">RANSAC</text>
|
||||
<text class="desc" x="122" y="84">отбор совпадений</text>
|
||||
<circle class="kp" cx="83" cy="130" r="10"/>
|
||||
<circle class="kp" cx="124" cy="114" r="10"/>
|
||||
<circle class="bad" cx="164" cy="137" r="10"/>
|
||||
</g>
|
||||
|
||||
<path class="arrow" d="M1590 630 V530"/>
|
||||
|
||||
<g transform="translate(1425 300)">
|
||||
<rect class="box violet" width="285" height="260"/>
|
||||
<text class="main" x="142" y="58">Гомография H</text>
|
||||
<text class="desc" x="142" y="91">связь координат</text>
|
||||
<path class="icon-fill" d="M95 135 L172 118 L196 184 L112 201 Z"/>
|
||||
<path class="icon-line" d="M103 212 L197 110"/>
|
||||
<text class="small" x="142" y="238">коррекция позиции БПЛА</text>
|
||||
</g>
|
||||
|
||||
<text class="small" x="900" y="930">Итог: изображения становятся визуально ближе, поэтому классические методы получают больше устойчивых совпадений.</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 6.1 KiB |
104
dissertation/_media/gan_unet_generator.svg
Normal file
104
dissertation/_media/gan_unet_generator.svg
Normal file
@@ -0,0 +1,104 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1700" height="920" viewBox="0 0 1700 920">
|
||||
<defs>
|
||||
<marker id="arrow" markerWidth="12" markerHeight="12" refX="10" refY="6" orient="auto">
|
||||
<path d="M0,0 L12,6 L0,12 Z" fill="#344054"/>
|
||||
</marker>
|
||||
<style>
|
||||
.bg { fill: #fbfcfe; }
|
||||
.title { font: 700 34px Arial, sans-serif; fill: #172033; }
|
||||
.subtitle { font: 400 18px Arial, sans-serif; fill: #526070; }
|
||||
.block { stroke: #344054; stroke-width: 2; rx: 10; }
|
||||
.down { fill: #d9ecff; }
|
||||
.bottle { fill: #fff0cf; }
|
||||
.up { fill: #e6f5dc; }
|
||||
.out { fill: #eee7ff; }
|
||||
.label { font: 700 17px Arial, sans-serif; fill: #172033; text-anchor: middle; }
|
||||
.small { font: 400 14px Arial, sans-serif; fill: #4b5563; text-anchor: middle; }
|
||||
.axis { stroke: #344054; stroke-width: 3; fill: none; marker-end: url(#arrow); }
|
||||
.skip { stroke: #d47f20; stroke-width: 3; fill: none; stroke-dasharray: 9 7; marker-end: url(#arrow); }
|
||||
.arrow { stroke: #344054; stroke-width: 3; fill: none; marker-end: url(#arrow); }
|
||||
.note { font: 400 16px Arial, sans-serif; fill: #667085; }
|
||||
</style>
|
||||
</defs>
|
||||
|
||||
<rect class="bg" width="1700" height="920"/>
|
||||
<text class="title" x="80" y="70">Архитектура генератора U-Net: Google Maps -> Яндекс.Карты</text>
|
||||
<text class="subtitle" x="80" y="104">Сжатие изображения до признакового представления и восстановление в целевом стиле с сохранением геометрии через skip-соединения</text>
|
||||
|
||||
<g id="encoder">
|
||||
<rect class="block down" x="80" y="230" width="120" height="360"/>
|
||||
<text class="label" x="140" y="255">Input</text>
|
||||
<text class="small" x="140" y="280">3 x 256 x 256</text>
|
||||
<text class="small" x="140" y="550">Google RGB</text>
|
||||
|
||||
<rect class="block down" x="240" y="265" width="120" height="290"/>
|
||||
<text class="label" x="300" y="292">Down 1</text>
|
||||
<text class="small" x="300" y="318">64</text>
|
||||
<text class="small" x="300" y="532">128 x 128</text>
|
||||
|
||||
<rect class="block down" x="400" y="295" width="120" height="230"/>
|
||||
<text class="label" x="460" y="322">Down 2</text>
|
||||
<text class="small" x="460" y="348">128</text>
|
||||
<text class="small" x="460" y="502">64 x 64</text>
|
||||
|
||||
<rect class="block down" x="560" y="325" width="120" height="170"/>
|
||||
<text class="label" x="620" y="352">Down 3</text>
|
||||
<text class="small" x="620" y="378">256</text>
|
||||
<text class="small" x="620" y="472">32 x 32</text>
|
||||
|
||||
<rect class="block down" x="720" y="350" width="120" height="120"/>
|
||||
<text class="label" x="780" y="377">Down 4</text>
|
||||
<text class="small" x="780" y="403">512</text>
|
||||
<text class="small" x="780" y="448">16 x 16</text>
|
||||
</g>
|
||||
|
||||
<rect class="block bottle" x="880" y="372" width="120" height="76"/>
|
||||
<text class="label" x="940" y="399">Bottleneck</text>
|
||||
<text class="small" x="940" y="425">512</text>
|
||||
|
||||
<g id="decoder">
|
||||
<rect class="block up" x="1040" y="350" width="120" height="120"/>
|
||||
<text class="label" x="1100" y="377">Up 4</text>
|
||||
<text class="small" x="1100" y="403">512</text>
|
||||
<text class="small" x="1100" y="448">16 x 16</text>
|
||||
|
||||
<rect class="block up" x="1200" y="325" width="120" height="170"/>
|
||||
<text class="label" x="1260" y="352">Up 5</text>
|
||||
<text class="small" x="1260" y="378">256</text>
|
||||
<text class="small" x="1260" y="472">32 x 32</text>
|
||||
|
||||
<rect class="block up" x="1360" y="295" width="120" height="230"/>
|
||||
<text class="label" x="1420" y="322">Up 6</text>
|
||||
<text class="small" x="1420" y="348">128</text>
|
||||
<text class="small" x="1420" y="502">64 x 64</text>
|
||||
|
||||
<rect class="block up" x="1520" y="265" width="120" height="290"/>
|
||||
<text class="label" x="1580" y="292">Up 7</text>
|
||||
<text class="small" x="1580" y="318">64</text>
|
||||
<text class="small" x="1580" y="532">128 x 128</text>
|
||||
</g>
|
||||
|
||||
<rect class="block out" x="1370" y="650" width="250" height="95"/>
|
||||
<text class="label" x="1495" y="684">Final Conv + Tanh</text>
|
||||
<text class="small" x="1495" y="711">3 x 256 x 256</text>
|
||||
<text class="small" x="1495" y="734">Generated Yandex RGB</text>
|
||||
|
||||
<path class="arrow" d="M205 410 L235 410"/>
|
||||
<path class="arrow" d="M365 410 L395 410"/>
|
||||
<path class="arrow" d="M525 410 L555 410"/>
|
||||
<path class="arrow" d="M685 410 L715 410"/>
|
||||
<path class="arrow" d="M845 410 L875 410"/>
|
||||
<path class="arrow" d="M1005 410 L1035 410"/>
|
||||
<path class="arrow" d="M1165 410 L1195 410"/>
|
||||
<path class="arrow" d="M1325 410 L1355 410"/>
|
||||
<path class="arrow" d="M1485 410 L1515 410"/>
|
||||
<path class="arrow" d="M1580 560 C1580 610, 1530 620, 1495 645"/>
|
||||
|
||||
<path class="skip" d="M300 255 C300 145, 1580 145, 1580 260"/>
|
||||
<path class="skip" d="M460 290 C460 185, 1420 185, 1420 290"/>
|
||||
<path class="skip" d="M620 320 C620 225, 1260 225, 1260 320"/>
|
||||
<path class="skip" d="M780 345 C780 265, 1100 265, 1100 345"/>
|
||||
|
||||
<text class="note" x="90" y="805">Down-блок: Conv2d, BatchNorm, LeakyReLU. Up-блок: Upsample, Conv2d, BatchNorm, ReLU, затем конкатенация со skip-признаками.</text>
|
||||
<text class="note" x="90" y="835">Назначение skip-соединений: сохранить дороги, перекрестки и контуры объектов при изменении визуального стиля карты.</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 5.3 KiB |
@@ -1,3 +1,39 @@
|
||||
2.3.3 Применение архитектуры сиамских близнецов для вычисления матрицы гомографии между двумя кадрами
|
||||
2.3.3 Генеративно-состязательная сеть для приведения карт к единому домену
|
||||
|
||||
При сопоставлении кадров БПЛА с эталонной картой возникает не только геометрическое, но и доменное различие изображений. Даже если два фрагмента описывают один и тот же участок местности, снимки из разных источников могут иметь разные цвета, толщину дорог, подписи, условные обозначения, контраст зданий и набор отображаемых ориентиров. Например, один и тот же район на карте Google и на карте Яндекс визуально отличается настолько, что классические алгоритмы поиска ключевых точек могут находить мало устойчивых совпадений или формировать большое число ложных соответствий.
|
||||
|
||||
Одним из способов уменьшить этот доменный разрыв является предварительное преобразование изображения из одного картографического домена в другой. В данной работе рассматривается генеративно-состязательная сеть (Generative Adversarial Network, GAN), которая переводит фрагмент карты Google в визуальный стиль карты Яндекс. После такого преобразования исходный фрагмент Google становится ближе к эталонному домену Яндекс, а значит для дальнейшей локализации можно применять классические методы выделения и сопоставления ключевых точек: ORB, SIFT, AKAZE, BRISK и последующую оценку матрицы гомографии при помощи RANSAC.
|
||||
|
||||
В отличие от модели сиамских близнецов, которая напрямую оценивает схожесть пары изображений, GAN решает вспомогательную задачу нормализации домена. То есть нейросеть не заменяет классический алгоритм сопоставления, а подготавливает данные так, чтобы классический алгоритм работал в более благоприятных условиях.
|
||||
|
||||

|
||||
|
||||
Рисунок 8 – Применение GAN для приведения картографических изображений к единому домену
|
||||
|
||||
Архитектура модели построена по принципу pix2pix, так как для обучения доступны парные изображения одного и того же участка местности в двух доменах: Google Maps и Яндекс.Карты. На вход генератора подается изображение Google размером \left(B,3,256,256\right), где B – размер пакета данных. Генератор формирует изображение \hat{Y}, визуально соответствующее стилю Яндекс.Карт. Дискриминатор получает на вход пару изображений и должен определить, является ли пара реальной \left(G,Y\right) или сгенерированной \left(G,\hat{Y}\right), где G – исходный фрагмент Google, Y – настоящий фрагмент Яндекс, \hat{Y} – результат работы генератора.
|
||||
|
||||
Генератор реализован в виде U-Net. Энкодер последовательно уменьшает пространственное разрешение изображения и извлекает признаки высокого уровня: структуру дорог, контуры кварталов, границы зданий, водные объекты и другие устойчивые элементы карты. Декодер восстанавливает изображение в целевом стиле. Между симметричными уровнями энкодера и декодера используются skip-соединения, которые передают локальную геометрию напрямую из ранних слоев в поздние. Это важно для навигационной задачи: модель должна изменить стиль карты, но не должна смещать дороги, перекрестки и контуры объектов, так как именно они затем используются как ориентиры.
|
||||
|
||||

|
||||
|
||||
Рисунок 9 – Архитектура генератора U-Net
|
||||
|
||||
Дискриминатор реализован как PatchGAN. В отличие от обычного дискриминатора, который выдает одно значение для всего изображения, PatchGAN оценивает реалистичность локальных областей. На вход дискриминатора подается конкатенация исходного изображения Google и изображения Яндекс по каналам, поэтому входной тензор имеет размер \left(B,6,256,256\right). Далее изображение проходит через сверточные блоки с постепенным уменьшением разрешения, а выходом является карта оценок для отдельных фрагментов. Такой подход подходит для картографических изображений, потому что локальные признаки – ширина дорог, стиль подписей, границы объектов, цветовые переходы – важнее глобальной художественной реалистичности.
|
||||
|
||||

|
||||
|
||||
Рисунок 10 – Архитектура дискриминатора PatchGAN
|
||||
|
||||
Обучение модели является состязательным. Генератор стремится сформировать такое изображение \hat{Y}, чтобы дискриминатор считал пару \left(G,\hat{Y}\right) реальной. Дискриминатор, наоборот, учится отличать настоящие пары \left(G,Y\right) от сгенерированных. Для сохранения геометрии карты одной только состязательной функции потерь недостаточно, поэтому итоговая функция потерь генератора включает несколько компонентов:
|
||||
|
||||
L_G = \lambda_{GAN}L_{GAN}\left(D\left(G,\hat{Y}\right),1\right)+\lambda_{L1}\left\|\hat{Y}-Y\right\|_1+\lambda_{SSIM}L_{SSIM}\left(\hat{Y},Y\right)+\lambda_{edge}L_{edge}\left(\hat{Y},Y\right) (1)
|
||||
|
||||
где L_{GAN} – состязательная функция потерь, L1 – попиксельная ошибка между сгенерированным и настоящим изображением Яндекс, L_{SSIM} – структурная ошибка, сохраняющая сходство локальной структуры, L_{edge} – ошибка по картам границ, вычисленным оператором Собеля. Коэффициенты \lambda_{GAN}, \lambda_{L1}, \lambda_{SSIM} и \lambda_{edge} задают вклад каждого компонента. В реализованной модели используются значения \lambda_{GAN}=0.5, \lambda_{L1}=150, \lambda_{SSIM}=25, \lambda_{edge}=20. Усиленные L1, SSIM и edge-компоненты делают модель менее «творческой», но лучше сохраняют контуры дорог и объектов, что важнее для последующего поиска ключевых точек.
|
||||
|
||||
Функция потерь дискриминатора имеет вид:
|
||||
|
||||
L_D = \frac{1}{2}\left(L_{GAN}\left(D\left(G,Y\right),1\right)+L_{GAN}\left(D\left(G,\hat{Y}\right),0\right)\right) (2)
|
||||
|
||||
После обучения GAN может использоваться в навигационном пайплайне следующим образом. Сначала для области предполагаемого положения БПЛА загружается или выбирается фрагмент Google Maps. Затем генератор переводит этот фрагмент в стиль Яндекс.Карт. На полученном изображении и на эталонном фрагменте Яндекс.Карт выделяются ключевые точки и дескрипторы. Далее дескрипторы сопоставляются, ложные соответствия отбрасываются при помощи RANSAC, а по оставшимся точкам оценивается матрица гомографии. Полученная матрица позволяет связать координаты текущего изображения с координатами эталонной карты и уточнить положение БПЛА.
|
||||
|
||||
Таким образом, GAN выступает как промежуточный модуль доменной адаптации. Его применение особенно полезно в ситуации, когда источник доступной карты и источник эталонных ориентиров различаются. В рассматриваемой задаче это позволяет перевести изображения Google в представление, близкое к Яндекс.Картам, где ориентиры визуально согласованы между собой. Благодаря этому классические методы компьютерного зрения получают более похожие изображения и могут устойчивее находить ключевые точки, не требуя полного отказа от интерпретируемого геометрического пайплайна.
|
||||
|
||||
@@ -6,3 +6,4 @@
|
||||
|-----------|----------|------|
|
||||
| 2.3.1 | Применение архитектуры сиамских близнецов для сопоставления кадров из различных доменов | 2.3.1_siamese_match.md |
|
||||
| 2.3.2 | Применение архитектуры сиамских близнецов для вычисления матрицы гомографии | 2.3.2_siamese_homography.md |
|
||||
| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3.3_additional.md |
|
||||
|
||||
@@ -9,5 +9,6 @@
|
||||
| 2.3 | Методы глубокого обучения | 2.3_deep_learning/ |
|
||||
| 2.3.1 | Сиамские близнецы для сопоставления кадров из различных доменов | 2.3_deep_learning/2.3.1_siamese_match.md |
|
||||
| 2.3.2 | Сиамские близнецы для вычисления матрицы гомографии | 2.3_deep_learning/2.3.2_siamese_homography.md |
|
||||
| 2.3.3 | Генеративно-состязательная сеть для приведения карт к единому домену | 2.3_deep_learning/2.3.3_additional.md |
|
||||
| 2.4 | Датасет | 2.4_dataset/ |
|
||||
| 2.5 | Обучение моделей глубокого обучения | 2.5_training/ |
|
||||
|
||||
41
docs/google-city-1 --use-sian-similarity.txt
Normal file
41
docs/google-city-1 --use-sian-similarity.txt
Normal file
@@ -0,0 +1,41 @@
|
||||
google-city-1 --use-sian-similarity
|
||||
MSE: 1006.2278793057307
|
||||
RMSE: 31.721095178220608
|
||||
Average FPS: 2.1113510203551025
|
||||
|
||||
google-city-2 --use-sian-similarity
|
||||
MSE: 0.018811735940268307
|
||||
RMSE: 0.13715588190182842
|
||||
Average FPS: 4.383734633135686
|
||||
|
||||
google-city-3 --use-sian-similarity
|
||||
MSE: 0.06464743825147634
|
||||
RMSE: 0.2542586050686905
|
||||
Average FPS: 4.047345287474747
|
||||
|
||||
google-city-4 --use-sian-similarity
|
||||
MSE: 0.3089718382086463
|
||||
RMSE: 0.5558523528857697
|
||||
Average FPS: 2.9425785696773636
|
||||
|
||||
|
||||
|
||||
google-city-1
|
||||
MSE: 716.9078171684257
|
||||
RMSE: 26.775134307196776
|
||||
Average FPS: 38.766897877580504
|
||||
|
||||
google-city-2
|
||||
MSE: 0.06237516673056363
|
||||
RMSE: 0.2497502086697099
|
||||
Average FPS: 25.821371544698586
|
||||
|
||||
google-city-3
|
||||
MSE: 0.07598762118336233
|
||||
RMSE: 0.2756585227838282
|
||||
Average FPS: 26.37680633915316
|
||||
|
||||
google-city-3
|
||||
MSE: 0.7199724273833397
|
||||
RMSE: 0.8485118899481254
|
||||
Average FPS: 29.534802136097262
|
||||
18
main.py
18
main.py
@@ -120,7 +120,7 @@ def build(name: str, map_name: str, lat: float, lon: float):
|
||||
sleep(15)
|
||||
online_map.destroy()
|
||||
|
||||
def run(name: str, map_name: str, ref_min_distance: float):
|
||||
def run(name: str, map_name: str, ref_min_distance: float, use_sian_similarity: bool = False):
|
||||
dir = Path('trajectories')
|
||||
assert dir.exists()
|
||||
dir /= name
|
||||
@@ -158,7 +158,13 @@ def run(name: str, map_name: str, ref_min_distance: float):
|
||||
print("READ POINTS:", points)
|
||||
|
||||
vis_manager = VisualizationManager()
|
||||
pilot = autopilot.AutoPilot(points, chunks, vis_manager, online_map.pixel_ratio)
|
||||
pilot = autopilot.AutoPilot(
|
||||
points,
|
||||
chunks,
|
||||
vis_manager,
|
||||
online_map.pixel_ratio,
|
||||
use_sian_similarity=use_sian_similarity,
|
||||
)
|
||||
simulator = Simulator(online_map)
|
||||
pilot.target_idx = 0
|
||||
|
||||
@@ -300,6 +306,12 @@ def parse_args():
|
||||
help='Включить отладку эталонов'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--use-sian-similarity',
|
||||
action='store_true',
|
||||
help='Выбирать ориентир через SiaN similarity вместо ближайшего по текущей позиции'
|
||||
)
|
||||
|
||||
# Парсим аргументы
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -327,4 +339,4 @@ if __name__ == "__main__":
|
||||
build(name, ref, lat, lon)
|
||||
|
||||
if mode == 'run' or mode == 'standalone':
|
||||
run(name, sim, rmd)
|
||||
run(name, sim, rmd, args.use_sian_similarity)
|
||||
|
||||
BIN
map.jpg
BIN
map.jpg
Binary file not shown.
|
Before Width: | Height: | Size: 3.9 MiB After Width: | Height: | Size: 424 KiB |
@@ -1,254 +1,111 @@
|
||||
# GAN Trainer для преобразования изображений Yandex → Google
|
||||
|
||||
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
|
||||
|
||||
## Структура проекта
|
||||
|
||||
```
|
||||
autopilot/models/GAN/
|
||||
├── gan.py # Основная реализация GAN модели
|
||||
├── trainer.py # Тренер для обучения GAN
|
||||
├── test_trainer.py # Тесты для тренера
|
||||
├── train_example.py # Пример использования тренера
|
||||
└── README.md # Этот файл
|
||||
```
|
||||
|
||||
## Модель GAN
|
||||
|
||||
Модель состоит из двух основных компонентов:
|
||||
|
||||
### 1. Генератор (GeneratorUNet)
|
||||
- Архитектура U-Net для преобразования изображений
|
||||
- Принимает изображение Yandex (3 канала RGB)
|
||||
- Возвращает изображение в стиле Google (3 канала RGB)
|
||||
- Использует skip connections для сохранения деталей
|
||||
|
||||
### 2. Дискриминатор (DiscriminatorPatchGAN)
|
||||
- PatchGAN архитектура
|
||||
- Принимает пару изображений (Yandex + Google)
|
||||
- Возвращает вероятность того, что пара реальная
|
||||
- Работает с патчами изображения 41x41
|
||||
|
||||
### Функция потерь (GANLoss)
|
||||
Поддерживает три режима:
|
||||
- `vanilla`: Бинарная кросс-энтропия
|
||||
- `lsgan`: Least Squares GAN (более стабильный)
|
||||
- `wgangp`: Wasserstein GAN with Gradient Penalty
|
||||
|
||||
## Тренер (GANTrainer)
|
||||
|
||||
### Основные возможности
|
||||
|
||||
1. **Обучение с чередованием**:
|
||||
- Обучение генератора и дискриминатора поочередно
|
||||
- Поддержка L1 потерь для сохранения структуры
|
||||
|
||||
2. **Валидация и мониторинг**:
|
||||
- Отдельные потери для генератора и дискриминатора
|
||||
- Логирование в TensorBoard
|
||||
- Ранняя остановка
|
||||
|
||||
3. **Сохранение и загрузка**:
|
||||
- Чекпоинты каждой эпохи
|
||||
- Лучшая модель
|
||||
- Финальная модель
|
||||
- История обучения
|
||||
|
||||
4. **Оценка модели**:
|
||||
- Метрики на тестовом наборе
|
||||
- Генерация примеров
|
||||
|
||||
### Быстрый старт
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
# Конфигурация
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"batch_size": 4,
|
||||
"output_dir": "runs/gan_training",
|
||||
"gan_mode": "vanilla",
|
||||
"lambda_L1": 100.0,
|
||||
"early_stopping_patience": 20,
|
||||
}
|
||||
|
||||
# Устройство
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Создание модели
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode=config["gan_mode"],
|
||||
lambda_L1=config["lambda_L1"],
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
# Создание даталоадеров (замените на свои данные)
|
||||
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
|
||||
|
||||
# Создание тренера
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Обучение
|
||||
trainer.train(num_epochs=100)
|
||||
|
||||
# Оценка
|
||||
metrics = trainer.evaluate(test_loader)
|
||||
```
|
||||
|
||||
### Конфигурация обучения
|
||||
|
||||
#### Базовая конфигурация
|
||||
```python
|
||||
config = {
|
||||
# Параметры оптимизатора
|
||||
"learning_rate": 2e-4, # Learning rate
|
||||
"beta1": 0.5, # Adam beta1
|
||||
"beta2": 0.999, # Adam beta2
|
||||
|
||||
# Параметры обучения
|
||||
"batch_size": 4, # Размер батча
|
||||
"epochs": 100, # Количество эпох
|
||||
|
||||
# Параметры GAN
|
||||
"gan_mode": "vanilla", # Режим GAN
|
||||
"lambda_L1": 100.0, # Вес L1 потерь
|
||||
|
||||
# Регуляризация
|
||||
"grad_clip": 1.0, # Gradient clipping
|
||||
|
||||
# Ранняя остановка
|
||||
"early_stopping_patience": 20,
|
||||
|
||||
# Выходные данные
|
||||
"output_dir": "runs/gan",
|
||||
}
|
||||
```
|
||||
|
||||
#### Расширенная конфигурация
|
||||
```python
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"batch_size": 8,
|
||||
"epochs": 200,
|
||||
"gan_mode": "lsgan", # Более стабильный LSGAN
|
||||
"lambda_L1": 100.0,
|
||||
"grad_clip": 1.0,
|
||||
"weight_decay": 1e-4, # Weight decay
|
||||
"early_stopping_patience": 30,
|
||||
"early_stopping_min_delta": 1e-4,
|
||||
"output_dir": "runs/gan_advanced",
|
||||
}
|
||||
```
|
||||
|
||||
### Методы тренера
|
||||
|
||||
#### Основные методы
|
||||
- `train_epoch()`: Обучение на одной эпохе
|
||||
- `validate()`: Валидация модели
|
||||
- `train(num_epochs)`: Полное обучение
|
||||
- `evaluate(test_loader)`: Оценка на тестовых данных
|
||||
|
||||
#### Управление чекпоинтами
|
||||
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
|
||||
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
|
||||
|
||||
### Выходные файлы
|
||||
|
||||
После обучения создаются следующие файлы:
|
||||
|
||||
```
|
||||
runs/gan_training/
|
||||
├── config.json # Конфигурация обучения
|
||||
├── training_history.json # История потерь
|
||||
├── model_best.pth # Лучшая модель
|
||||
├── model_final.pth # Финальная модель
|
||||
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
|
||||
├── checkpoint_epoch_2.pth
|
||||
├── ...
|
||||
└── tensorboard/ # Логи TensorBoard
|
||||
├── events.out.tfevents...
|
||||
└── ...
|
||||
```
|
||||
|
||||
### TensorBoard
|
||||
|
||||
Для визуализации обучения используйте TensorBoard:
|
||||
|
||||
```bash
|
||||
tensorboard --logdir runs/gan_training/tensorboard
|
||||
```
|
||||
|
||||
Доступные метрики:
|
||||
- `train/batch_g_loss`: Потери генератора на батче
|
||||
- `train/batch_d_loss`: Потери дискриминатора на батче
|
||||
- `train/batch_g_l1_loss`: L1 потери генератора
|
||||
- `train/epoch_g_loss`: Потери генератора на эпохе
|
||||
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
|
||||
- `val/epoch_g_loss`: Валидационные потери генератора
|
||||
- `val/epoch_d_loss`: Валидационные потери дискриминатора
|
||||
|
||||
### Тестирование
|
||||
|
||||
Запустите тесты для проверки работоспособности:
|
||||
|
||||
```bash
|
||||
python models/GAN/test_trainer.py
|
||||
```
|
||||
|
||||
### Пример использования
|
||||
|
||||
Полный пример использования смотрите в `train_example.py`.
|
||||
|
||||
### Советы по обучению
|
||||
|
||||
1. **Начальные значения**:
|
||||
- Используйте `gan_mode="lsgan"` для более стабильного обучения
|
||||
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
|
||||
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
|
||||
|
||||
2. **Мониторинг**:
|
||||
- Следите за балансом потерь генератора и дискриминатора
|
||||
- Если потери дискриминатора близки к 0, генератор не обучается
|
||||
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
|
||||
|
||||
3. **Визуализация**:
|
||||
- Регулярно генерируйте примеры для визуальной оценки
|
||||
- Используйте TensorBoard для отслеживания прогресса
|
||||
|
||||
### Устранение проблем
|
||||
|
||||
#### Высокие потери генератора
|
||||
- Уменьшите `lambda_L1`
|
||||
- Увеличьте learning rate
|
||||
- Проверьте качество данных
|
||||
|
||||
#### Дискриминатор слишком сильный
|
||||
- Уменьшите learning rate дискриминатора
|
||||
- Добавьте dropout в дискриминатор
|
||||
- Обучайте генератор чаще, чем дискриминатор
|
||||
|
||||
#### Недостаток памяти GPU
|
||||
- Уменьшите `batch_size`
|
||||
- Уменьшите размер изображений
|
||||
- Используйте gradient accumulation
|
||||
|
||||
### Лицензия
|
||||
|
||||
Этот проект является частью Autopilot системы.
|
||||
# GAN для преобразования Google -> Yandex
|
||||
|
||||
Модуль реализует pix2pix-подобный GAN для парных изображений карт. Генератор получает изображение Google из пары и пытается сгенерировать изображение в стиле Yandex. Сгенерированная картинка сравнивается со вторым изображением пары, то есть с оригинальным `*_yandex.png`.
|
||||
|
||||
## Структура
|
||||
|
||||
```text
|
||||
models/GAN/
|
||||
├── build.py # генератор notebook.gen.ipynb по схеме
|
||||
├── _schema.py # структура генерируемого ноутбука
|
||||
├── _schema.md # описание формата схемы
|
||||
├── notebook.gen.ipynb
|
||||
├── src/
|
||||
│ ├── config.py
|
||||
│ ├── dataloader.py
|
||||
│ ├── model.py
|
||||
│ ├── trainer.py
|
||||
│ ├── analyze.py
|
||||
│ └── main.py
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Архитектура
|
||||
|
||||
`GeneratorUNet` принимает `google_img` с 3 RGB-каналами и возвращает `fake_yandex`.
|
||||
|
||||
`DiscriminatorPatchGAN` получает пару `(google_img, yandex_img)` и отличает настоящую пару от `(google_img, fake_yandex)`.
|
||||
|
||||
`ImageGAN.generator_step()` считает:
|
||||
|
||||
- adversarial loss: дискриминатор должен принять `(google_img, fake_yandex)` за реальную пару;
|
||||
- L1 loss: `fake_yandex` сравнивается с оригинальным `yandex_img` из той же пары.
|
||||
- SSIM loss: штрафует структурные отличия, чтобы карта не расплывалась;
|
||||
- Sobel edge loss: сохраняет контуры дорог и объектов для последующего поиска ключевых точек.
|
||||
|
||||
Итоговая функция потерь генератора:
|
||||
|
||||
```text
|
||||
G_loss =
|
||||
lambda_GAN * GAN_loss(D(google_img, fake_yandex), real)
|
||||
+ lambda_L1 * L1(fake_yandex, yandex_img)
|
||||
+ lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)
|
||||
+ lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)
|
||||
```
|
||||
|
||||
По умолчанию используется `lsgan`, усиленная реконструкция и более медленный
|
||||
дискриминатор. Это менее “креативно”, зато обычно даёт более чистые контуры.
|
||||
|
||||
## Запуск
|
||||
|
||||
Из папки `models/GAN`:
|
||||
|
||||
```bash
|
||||
python src/main.py
|
||||
```
|
||||
|
||||
`src/main.py` специально написан без функции `main()`: при генерации ноутбука
|
||||
все переменные остаются доступны после запуска ячейки, как в SiaN.
|
||||
|
||||
Чтобы пересобрать ноутбук из схемы:
|
||||
|
||||
```bash
|
||||
python build.py
|
||||
```
|
||||
|
||||
По умолчанию датасет берется из:
|
||||
|
||||
```text
|
||||
C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images
|
||||
```
|
||||
|
||||
Путь, размер изображений, batch size и число эпох меняются в `src/config.py`.
|
||||
|
||||
Если CUDA видна, но установленный PyTorch не поддерживает архитектуру GPU
|
||||
например Tesla P100 `sm_60`, код автоматически переключится на CPU. Это
|
||||
управляется параметром `prefer_cuda` в `src/config.py`.
|
||||
|
||||
## Ожидаемый формат данных
|
||||
|
||||
В директории датасета должны лежать пары:
|
||||
|
||||
```text
|
||||
0000_google.png
|
||||
0000_yandex.png
|
||||
0001_google.png
|
||||
0001_yandex.png
|
||||
...
|
||||
```
|
||||
|
||||
Даталоадер возвращает:
|
||||
|
||||
- `google_img`: вход генератора;
|
||||
- `yandex_img`: целевое изображение для сравнения;
|
||||
- `idx`: номер пары.
|
||||
|
||||
## Чекпоинты
|
||||
|
||||
Тренер сохраняет чекпоинты в:
|
||||
|
||||
```text
|
||||
models/GAN/runs/checkpoints
|
||||
```
|
||||
|
||||
Сохраняются `best.pth`, периодические `epoch_N.pth` и `final.pth`.
|
||||
|
||||
После обучения `src/main.py` также строит:
|
||||
|
||||
- графики `G/D/L1/SSIM/edge loss`;
|
||||
- сетку `Google input -> Generated Yandex -> Yandex target`.
|
||||
|
||||
Файлы сохраняются в `models/GAN/runs/images`.
|
||||
|
||||
32
models/GAN/_schema.md
Normal file
32
models/GAN/_schema.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# GAN Schema
|
||||
|
||||
Notebook structure definition for the Google -> Yandex GAN model.
|
||||
|
||||
---
|
||||
|
||||
## Format
|
||||
|
||||
```text
|
||||
# === IMPORTS ===
|
||||
<all imports>
|
||||
|
||||
# markdown
|
||||
"""Description"""
|
||||
|
||||
# code: ./src/file.py
|
||||
|
||||
# # shell:
|
||||
<shell commands>
|
||||
```
|
||||
|
||||
## Directives
|
||||
|
||||
| Directive | Description |
|
||||
|----------------|------------------------------------|
|
||||
| `# code:` | Include file from `src/` |
|
||||
| `# markdown` | Add markdown cell |
|
||||
| `# # shell:` | Add notebook shell command cell |
|
||||
|
||||
---
|
||||
|
||||
`build.py` generates `notebook.gen.ipynb` from `_schema.py`.
|
||||
121
models/GAN/_schema.py
Normal file
121
models/GAN/_schema.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# _schema.py
|
||||
|
||||
# === IMPORTS ===
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset, Subset
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
# markdown
|
||||
"""# Configuration
|
||||
|
||||
Global settings for the Google -> Yandex GAN:
|
||||
- Dataset path and image size
|
||||
- Optimizer and training hyperparameters
|
||||
- Device preference with safe CUDA compatibility checks
|
||||
- GAN, L1, SSIM and edge reconstruction weights
|
||||
- Output directories for checkpoints and generated samples"""
|
||||
# code: ./src/config.py
|
||||
|
||||
# markdown
|
||||
"""## Dataset
|
||||
|
||||
Google/Yandex paired image loader.
|
||||
|
||||
**Direction:**
|
||||
- `google_img` is the generator input
|
||||
- `yandex_img` is the target image from the same pair
|
||||
|
||||
**Returns:**
|
||||
- Batch dict with `google_img`, `yandex_img`, `idx`"""
|
||||
# code: ./src/dataloader.py
|
||||
|
||||
# code: ./src/test_dataloader.py
|
||||
|
||||
# markdown
|
||||
"""## Model
|
||||
|
||||
Pix2pix-style GAN for Google -> Yandex map translation.
|
||||
|
||||
**Generator:**
|
||||
- `GeneratorUNet`
|
||||
- Input: Google image `(B, 3, H, W)`
|
||||
- Output: generated Yandex image `(B, 3, H, W)`
|
||||
|
||||
**Discriminator:**
|
||||
- `DiscriminatorPatchGAN`
|
||||
- Input pair: `(google_img, yandex_img)`
|
||||
- Learns to distinguish real pairs from `(google_img, fake_yandex)`
|
||||
|
||||
**Generator loss:**
|
||||
- adversarial loss
|
||||
- `lambda_L1 * L1(fake_yandex, yandex_img)`
|
||||
- `lambda_SSIM * SSIMLoss(fake_yandex, yandex_img)`
|
||||
- `lambda_edge * SobelEdgeLoss(fake_yandex, yandex_img)`
|
||||
|
||||
The generator uses bilinear upsampling followed by convolution to avoid
|
||||
checkerboard artifacts from transposed convolutions."""
|
||||
# code: ./src/model.py
|
||||
|
||||
# markdown
|
||||
"""## Training
|
||||
|
||||
`GANTrainer` trains discriminator and generator alternately.
|
||||
|
||||
**Training step:**
|
||||
1. Generate `fake_yandex = G(google_img)`
|
||||
2. Train discriminator on real pair `(google_img, yandex_img)` and fake pair `(google_img, fake_yandex)`
|
||||
3. Train generator against discriminator and paired Yandex target
|
||||
|
||||
**Checkpoint saving:**
|
||||
- `best.pth`
|
||||
- `epoch_N.pth`
|
||||
- `final.pth`"""
|
||||
# code: ./src/trainer.py
|
||||
|
||||
# markdown
|
||||
"""## Analysis
|
||||
|
||||
Visualization helpers for generated samples and collected training metrics.
|
||||
|
||||
Training history plot contains:
|
||||
1. Generator loss
|
||||
2. Discriminator loss
|
||||
3. L1 loss against the paired Yandex target
|
||||
4. SSIM structure loss
|
||||
5. Sobel edge loss
|
||||
6. Best-checkpoint reconstruction score
|
||||
|
||||
The sample grid contains:
|
||||
1. Google input
|
||||
2. Generated Yandex
|
||||
3. Original Yandex target"""
|
||||
# code: ./src/analyze.py
|
||||
|
||||
# markdown
|
||||
"""## Main Pipeline
|
||||
|
||||
Executes the full GAN workflow:
|
||||
1. Create config
|
||||
2. Build paired data loaders
|
||||
3. Initialize Google -> Yandex GAN
|
||||
4. Train with validation
|
||||
5. Save checkpoints in `runs/checkpoints/`
|
||||
6. Show loss plots and generated sample grid
|
||||
|
||||
This block is intentionally top-level, not wrapped in `main()`, so notebook
|
||||
variables such as `model`, `trainer`, `train_loader`, `val_loader`, and
|
||||
`training_analysis` remain available for debugging."""
|
||||
# code: ./src/main.py
|
||||
|
||||
# # shell:
|
||||
# !zip artefacts.zip runs/checkpoints/best.pth runs/images/training_history.png runs/images/generation_samples.png
|
||||
260
models/GAN/build.py
Normal file
260
models/GAN/build.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def parse_schema(schema_path):
|
||||
with open(schema_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
imports = []
|
||||
items = []
|
||||
|
||||
in_imports_section = False
|
||||
lines = content.split('\n')
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped == '# === IMPORTS ===':
|
||||
in_imports_section = True
|
||||
i += 1
|
||||
continue
|
||||
elif stripped.startswith('# ==='):
|
||||
in_imports_section = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_imports_section and stripped and not stripped.startswith('#'):
|
||||
imports.append(stripped)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_imports_section and stripped.startswith('#'):
|
||||
in_imports_section = False
|
||||
|
||||
if stripped.startswith('# code:'):
|
||||
match = re.search(r'# code:\s*(.+)', stripped)
|
||||
if match:
|
||||
items.append(('code', match.group(1).strip()))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if stripped.startswith('# inline:') or stripped.startswith('# # shell:'):
|
||||
directive = 'inline' if stripped.startswith('# inline:') else 'shell'
|
||||
i += 1
|
||||
code_lines = []
|
||||
while i < len(lines):
|
||||
stripped = lines[i].strip()
|
||||
if directive == 'shell':
|
||||
if stripped.startswith('# ') or stripped.startswith('# # '):
|
||||
line_content = lines[i].strip()
|
||||
if line_content.startswith('# # '):
|
||||
line_content = line_content[3:].lstrip()
|
||||
else:
|
||||
line_content = line_content[2:]
|
||||
code_lines.append(line_content)
|
||||
i += 1
|
||||
continue
|
||||
if stripped.startswith('#'):
|
||||
if stripped.startswith('# code:') or stripped.startswith('# inline:') or stripped.startswith('# # shell:') or stripped == '# markdown' or stripped.startswith('# ==='):
|
||||
break
|
||||
i += 1
|
||||
continue
|
||||
if stripped == '':
|
||||
i += 1
|
||||
continue
|
||||
code_lines.append(lines[i])
|
||||
i += 1
|
||||
if code_lines:
|
||||
items.append((directive, '\n'.join(code_lines).rstrip()))
|
||||
continue
|
||||
|
||||
if stripped == '# markdown':
|
||||
i += 1
|
||||
if i < len(lines):
|
||||
next_line = lines[i].strip()
|
||||
if next_line.startswith('"""'):
|
||||
if next_line.endswith('"""') and len(next_line) > 3:
|
||||
md_content = next_line[3:-3].strip()
|
||||
i += 1
|
||||
else:
|
||||
end_idx = None
|
||||
for j in range(i + 1, len(lines)):
|
||||
if '"""' in lines[j]:
|
||||
end_idx = j
|
||||
break
|
||||
if end_idx:
|
||||
md_content = '\n'.join(lines[i:end_idx])
|
||||
md_content = md_content.strip('"""').strip()
|
||||
i = end_idx + 1
|
||||
else:
|
||||
md_content = ""
|
||||
i += 1
|
||||
items.append(('markdown', md_content))
|
||||
continue
|
||||
|
||||
i += 1
|
||||
|
||||
return imports, items
|
||||
|
||||
|
||||
def strip_all_imports(content):
|
||||
lines = content.split('\n')
|
||||
result_lines = []
|
||||
skip_block = False
|
||||
if_block_indent = 0
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped.startswith('if __name__') or stripped.startswith('if __name__ =='):
|
||||
skip_block = True
|
||||
if_block_indent = len(line) - len(line.lstrip())
|
||||
continue
|
||||
|
||||
if skip_block:
|
||||
current_indent = len(line) - len(line.lstrip())
|
||||
if line.strip() == '':
|
||||
continue
|
||||
if current_indent < if_block_indent:
|
||||
skip_block = False
|
||||
elif current_indent == if_block_indent and stripped.startswith('if '):
|
||||
skip_block = True
|
||||
if_block_indent = current_indent
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
if stripped.startswith('import ') or stripped.startswith('from '):
|
||||
continue
|
||||
|
||||
result_lines.append(line)
|
||||
|
||||
while result_lines and result_lines[-1].strip() == '':
|
||||
result_lines.pop()
|
||||
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def read_src_file(ref, base_dir):
|
||||
ref_path = ref.replace('./src/', '').replace('src/', '')
|
||||
full_path = os.path.join(base_dir, 'src', ref_path.replace('./src/', '').lstrip('/'))
|
||||
|
||||
if not full_path.endswith('.py'):
|
||||
full_path += '.py'
|
||||
|
||||
with open(full_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def build_notebook(schema_path, output_path=None):
|
||||
schema_dir = os.path.dirname(os.path.abspath(schema_path))
|
||||
if output_path is None:
|
||||
output_path = os.path.join(schema_dir, 'notebook.gen.ipynb')
|
||||
|
||||
imports, items = parse_schema(schema_path)
|
||||
|
||||
cells = []
|
||||
|
||||
if imports:
|
||||
imports_cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [imp + "\n" for imp in imports]
|
||||
}
|
||||
cells.append(imports_cell)
|
||||
|
||||
for item_type, item_content in items:
|
||||
if item_type == 'markdown':
|
||||
md_cell = {
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": item_content + "\n"
|
||||
}
|
||||
cells.append(md_cell)
|
||||
elif item_type == 'inline':
|
||||
lines = item_content.split('\n')
|
||||
while lines and lines[-1].strip() == '':
|
||||
lines.pop()
|
||||
if lines:
|
||||
lines.append('')
|
||||
cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [line + "\n" for line in lines]
|
||||
}
|
||||
cells.append(cell)
|
||||
elif item_type == 'shell':
|
||||
lines = item_content.split('\n')
|
||||
while lines and lines[-1].strip() == '':
|
||||
lines.pop()
|
||||
if lines:
|
||||
lines.append('')
|
||||
cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [line + "\n" for line in lines]
|
||||
}
|
||||
cells.append(cell)
|
||||
elif item_type == 'code':
|
||||
try:
|
||||
content = read_src_file(item_content, schema_dir)
|
||||
content = strip_all_imports(content)
|
||||
|
||||
lines = content.split('\n')
|
||||
while lines and lines[-1].strip() == '':
|
||||
lines.pop()
|
||||
if lines:
|
||||
lines.append('')
|
||||
|
||||
cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [line + "\n" for line in lines]
|
||||
}
|
||||
cells.append(cell)
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Could not find: {item_content}")
|
||||
|
||||
notebook = {
|
||||
"cells": cells,
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.11.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(notebook, f, indent=1, ensure_ascii=False)
|
||||
|
||||
print(f"Notebook generated: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
schema_path = os.path.join(script_dir, '_schema.py')
|
||||
|
||||
if os.path.exists(schema_path):
|
||||
build_notebook(schema_path)
|
||||
else:
|
||||
print(f"Error: _schema.py not found at {schema_path}")
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Main entry point for GAN training."""
|
||||
|
||||
from config import create_config
|
||||
from dataloader import create_data_loaders
|
||||
from model import create_gan
|
||||
from trainer import create_trainer
|
||||
|
||||
|
||||
def main():
|
||||
"""Run training pipeline."""
|
||||
config = create_config()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Create components
|
||||
model = create_gan(use_cuda=False) # Set to True to use GPU
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
image_size=tuple(config["image_size"]),
|
||||
num_workers=config["num_workers"],
|
||||
)
|
||||
trainer = create_trainer(model, train_loader, val_loader, config)
|
||||
|
||||
# Train
|
||||
trainer.train(config["epochs"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch
|
||||
main()
|
||||
1064
models/GAN/notebook.gen.ipynb
Normal file
1064
models/GAN/notebook.gen.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
1
models/GAN/src/__init__.py
Normal file
1
models/GAN/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""GAN package for Google-to-Yandex map image translation."""
|
||||
117
models/GAN/src/analyze.py
Normal file
117
models/GAN/src/analyze.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
||||
|
||||
def denormalize_image(tensor):
|
||||
return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def visualize_generation(model, data_loader, output_dir, device=None, num_samples=4, show=True):
|
||||
device = device or model.device
|
||||
model.eval()
|
||||
|
||||
batch = next(iter(data_loader))
|
||||
google_img = batch["google_img"][:num_samples].to(device)
|
||||
yandex_img = batch["yandex_img"][:num_samples].to(device)
|
||||
fake_yandex = model.generator(google_img)
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_dir / "generation_samples.png"
|
||||
|
||||
fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples))
|
||||
if num_samples == 1:
|
||||
axes = axes.reshape(1, 3)
|
||||
|
||||
titles = ["Google input", "Generated Yandex", "Yandex target"]
|
||||
for row in range(num_samples):
|
||||
images = [google_img[row], fake_yandex[row], yandex_img[row]]
|
||||
for col, image in enumerate(images):
|
||||
axes[row, col].imshow(denormalize_image(image).permute(1, 2, 0))
|
||||
axes[row, col].set_title(titles[col])
|
||||
axes[row, col].axis("off")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
if show:
|
||||
plt.show()
|
||||
plt.close(fig)
|
||||
return output_path
|
||||
|
||||
|
||||
def plot_training_history(trainer, output_dir, show=True):
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_dir / "training_history.png"
|
||||
|
||||
epochs = range(1, len(trainer.g_losses) + 1)
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||||
axes = axes.ravel()
|
||||
|
||||
axes[0].plot(epochs, trainer.g_losses, label="train G")
|
||||
axes[0].plot(epochs, trainer.val_g_losses, label="val G")
|
||||
axes[0].set_title("Generator loss")
|
||||
axes[0].set_xlabel("Epoch")
|
||||
axes[0].legend()
|
||||
axes[0].grid(True, alpha=0.3)
|
||||
|
||||
axes[1].plot(epochs, trainer.d_losses, label="train D")
|
||||
axes[1].plot(epochs, trainer.val_d_losses, label="val D")
|
||||
axes[1].set_title("Discriminator loss")
|
||||
axes[1].set_xlabel("Epoch")
|
||||
axes[1].legend()
|
||||
axes[1].grid(True, alpha=0.3)
|
||||
|
||||
axes[2].plot(epochs, trainer.l1_losses, label="train L1")
|
||||
axes[2].plot(epochs, trainer.val_l1_losses, label="val L1")
|
||||
axes[2].set_title("Paired Yandex L1 loss")
|
||||
axes[2].set_xlabel("Epoch")
|
||||
axes[2].legend()
|
||||
axes[2].grid(True, alpha=0.3)
|
||||
|
||||
axes[3].plot(epochs, trainer.ssim_losses, label="train SSIM")
|
||||
axes[3].plot(epochs, trainer.val_ssim_losses, label="val SSIM")
|
||||
axes[3].set_title("SSIM structure loss")
|
||||
axes[3].set_xlabel("Epoch")
|
||||
axes[3].legend()
|
||||
axes[3].grid(True, alpha=0.3)
|
||||
|
||||
axes[4].plot(epochs, trainer.edge_losses, label="train edge")
|
||||
axes[4].plot(epochs, trainer.val_edge_losses, label="val edge")
|
||||
axes[4].set_title("Sobel edge loss")
|
||||
axes[4].set_xlabel("Epoch")
|
||||
axes[4].legend()
|
||||
axes[4].grid(True, alpha=0.3)
|
||||
|
||||
axes[5].plot(epochs, trainer.val_reconstruction_losses, label="val reconstruction")
|
||||
axes[5].set_title("Best-checkpoint score")
|
||||
axes[5].set_xlabel("Epoch")
|
||||
axes[5].legend()
|
||||
axes[5].grid(True, alpha=0.3)
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
if show:
|
||||
plt.show()
|
||||
plt.close(fig)
|
||||
return output_path
|
||||
|
||||
|
||||
def analyze_training(trainer):
|
||||
return {
|
||||
"best_val_loss": trainer.best_val_loss,
|
||||
"g_losses": trainer.g_losses,
|
||||
"d_losses": trainer.d_losses,
|
||||
"l1_losses": trainer.l1_losses,
|
||||
"ssim_losses": trainer.ssim_losses,
|
||||
"edge_losses": trainer.edge_losses,
|
||||
"val_g_losses": trainer.val_g_losses,
|
||||
"val_d_losses": trainer.val_d_losses,
|
||||
"val_l1_losses": trainer.val_l1_losses,
|
||||
"val_ssim_losses": trainer.val_ssim_losses,
|
||||
"val_edge_losses": trainer.val_edge_losses,
|
||||
"val_reconstruction_losses": trainer.val_reconstruction_losses,
|
||||
}
|
||||
@@ -6,23 +6,30 @@ def create_config():
|
||||
return {
|
||||
# Optimizer params
|
||||
"learning_rate": 2e-4,
|
||||
"discriminator_lr_factor": 0.5,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
# Training params
|
||||
"batch_size": 32,
|
||||
"epochs": 100,
|
||||
"prefer_cuda": True,
|
||||
# GAN params
|
||||
"gan_mode": "vanilla",
|
||||
"lambda_L1": 100.0,
|
||||
"gan_mode": "lsgan",
|
||||
"lambda_GAN": 0.5,
|
||||
"lambda_L1": 150.0,
|
||||
"lambda_SSIM": 25.0,
|
||||
"lambda_edge": 20.0,
|
||||
"discriminator_update_interval": 1,
|
||||
# Regularization
|
||||
"grad_clip": 1.0,
|
||||
# Early stopping
|
||||
"early_stopping_patience": 20,
|
||||
"early_stopping_patience": 25,
|
||||
# Output
|
||||
"output_dir": "runs/gan_training",
|
||||
"output_dir": r"C:\Users\admin\Projects\autopilot\models\GAN\runs",
|
||||
# Logging
|
||||
"log_interval": 10,
|
||||
"save_interval": 5,
|
||||
"num_visual_samples": 4,
|
||||
# Data
|
||||
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
"image_size": [256, 256],
|
||||
@@ -33,4 +40,4 @@ def create_config():
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = create_config()
|
||||
print("Default config:", config)
|
||||
print("Default config:", config)
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Data loader for Yandex-to-Google image translation."""
|
||||
"""Data loader for Google-to-Yandex image translation."""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -10,7 +10,7 @@ from torchvision import transforms
|
||||
|
||||
|
||||
class YaGoDataset(Dataset):
|
||||
"""Dataset loading pairs of Yandex and Google map images."""
|
||||
"""Dataset loading paired Google and Yandex map images."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -159,4 +159,4 @@ if __name__ == "__main__":
|
||||
|
||||
batch = next(iter(train_loader))
|
||||
print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
|
||||
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
|
||||
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
|
||||
60
models/GAN/src/main.py
Normal file
60
models/GAN/src/main.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Executable GAN training pipeline.
|
||||
|
||||
The code is intentionally top-level, mirroring the SiaN notebook style:
|
||||
when this file is included in the generated notebook, variables remain
|
||||
available for debugging and interactive experiments.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from analyze import analyze_training, plot_training_history, visualize_generation
|
||||
from config import create_config
|
||||
from dataloader import create_data_loaders
|
||||
from model import create_gan, get_compatible_device
|
||||
from trainer import create_trainer
|
||||
|
||||
|
||||
config = create_config()
|
||||
device = get_compatible_device(prefer_cuda=config["prefer_cuda"])
|
||||
print(f"Using device: {device}")
|
||||
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
train_split=config["train_split"],
|
||||
image_size=tuple(config["image_size"]),
|
||||
num_workers=config["num_workers"],
|
||||
)
|
||||
print(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
|
||||
|
||||
model = create_gan(
|
||||
gan_mode=config["gan_mode"],
|
||||
lambda_GAN=config["lambda_GAN"],
|
||||
lambda_L1=config["lambda_L1"],
|
||||
lambda_SSIM=config["lambda_SSIM"],
|
||||
lambda_edge=config["lambda_edge"],
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
generator_params = sum(p.numel() for p in model.generator.parameters())
|
||||
discriminator_params = sum(p.numel() for p in model.discriminator.parameters())
|
||||
print(f"Model created: generator={generator_params:,}, discriminator={discriminator_params:,}")
|
||||
|
||||
trainer = create_trainer(model, train_loader, val_loader, config)
|
||||
trainer.train(config["epochs"])
|
||||
|
||||
training_analysis = analyze_training(trainer)
|
||||
images_dir = Path(config["output_dir"]) / "images"
|
||||
history_plot_path = plot_training_history(trainer, images_dir)
|
||||
generation_samples_path = visualize_generation(
|
||||
model=model,
|
||||
data_loader=val_loader,
|
||||
output_dir=images_dir,
|
||||
device=device,
|
||||
num_samples=config["num_visual_samples"],
|
||||
)
|
||||
|
||||
print(f"Training history plot: {history_plot_path}")
|
||||
print(f"Generation samples: {generation_samples_path}")
|
||||
@@ -1,10 +1,37 @@
|
||||
"""GAN model for image translation Yandex -> Google."""
|
||||
"""GAN model for image translation Google -> Yandex."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_compatible_device(prefer_cuda: bool = True, verbose: bool = True) -> torch.device:
|
||||
"""Return CUDA only when the current PyTorch build supports the GPU arch."""
|
||||
if not prefer_cuda or not torch.cuda.is_available():
|
||||
return torch.device("cpu")
|
||||
|
||||
try:
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
arch = f"sm_{major}{minor}"
|
||||
supported_arches = set(torch.cuda.get_arch_list())
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
except Exception as exc:
|
||||
if verbose:
|
||||
print(f"CUDA is visible but cannot be inspected ({exc}); using CPU.")
|
||||
return torch.device("cpu")
|
||||
|
||||
if supported_arches and arch not in supported_arches:
|
||||
if verbose:
|
||||
supported = ", ".join(sorted(supported_arches))
|
||||
print(
|
||||
f"CUDA GPU '{gpu_name}' has capability {arch}, but this PyTorch build "
|
||||
f"supports only: {supported}. Using CPU."
|
||||
)
|
||||
return torch.device("cpu")
|
||||
|
||||
return torch.device("cuda")
|
||||
|
||||
|
||||
class UNetDownBlock(nn.Module):
|
||||
"""Downsampling block for U-Net."""
|
||||
|
||||
@@ -29,7 +56,10 @@ class UNetUpBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
|
||||
self.upconv = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
||||
)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
if dropout > 0:
|
||||
@@ -38,7 +68,6 @@ class UNetUpBlock(nn.Module):
|
||||
self.dropout = None
|
||||
|
||||
def forward(self, x, skip_input):
|
||||
|
||||
x = self.upconv(x)
|
||||
# Pad if needed to match skip connection size
|
||||
if x.shape != skip_input.shape:
|
||||
@@ -54,7 +83,7 @@ class UNetUpBlock(nn.Module):
|
||||
|
||||
|
||||
class GeneratorUNet(nn.Module):
|
||||
"""U-Net generator for Yandex -> Google translation."""
|
||||
"""U-Net generator for Google -> Yandex translation."""
|
||||
|
||||
def __init__(self, in_channels: int = 3, out_channels: int = 3):
|
||||
super().__init__()
|
||||
@@ -85,7 +114,8 @@ class GeneratorUNet(nn.Module):
|
||||
|
||||
# Final
|
||||
self.final = nn.Sequential(
|
||||
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
|
||||
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
||||
nn.Conv2d(128, out_channels, kernel_size=3, padding=1),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
@@ -115,7 +145,7 @@ class GeneratorUNet(nn.Module):
|
||||
|
||||
|
||||
class DiscriminatorPatchGAN(nn.Module):
|
||||
"""PatchGAN discriminator."""
|
||||
"""PatchGAN discriminator for paired source/target images."""
|
||||
|
||||
def __init__(self, in_channels: int = 6):
|
||||
super().__init__()
|
||||
@@ -132,7 +162,6 @@ class DiscriminatorPatchGAN(nn.Module):
|
||||
nn.BatchNorm2d(512),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, img_A, img_B):
|
||||
@@ -167,15 +196,82 @@ class GANLoss(nn.Module):
|
||||
return -prediction.mean() if target_is_real else prediction.mean()
|
||||
|
||||
|
||||
class SSIMLoss(nn.Module):
|
||||
"""Local SSIM loss for normalized image tensors in [-1, 1]."""
|
||||
|
||||
def __init__(self, window_size: int = 11, c1: float = 0.01 ** 2, c2: float = 0.03 ** 2):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.c1 = c1
|
||||
self.c2 = c2
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
pred = (pred + 1.0) * 0.5
|
||||
target = (target + 1.0) * 0.5
|
||||
padding = self.window_size // 2
|
||||
|
||||
mu_pred = F.avg_pool2d(pred, self.window_size, stride=1, padding=padding)
|
||||
mu_target = F.avg_pool2d(target, self.window_size, stride=1, padding=padding)
|
||||
mu_pred_sq = mu_pred.pow(2)
|
||||
mu_target_sq = mu_target.pow(2)
|
||||
mu_pred_target = mu_pred * mu_target
|
||||
|
||||
sigma_pred = F.avg_pool2d(pred * pred, self.window_size, stride=1, padding=padding) - mu_pred_sq
|
||||
sigma_target = F.avg_pool2d(target * target, self.window_size, stride=1, padding=padding) - mu_target_sq
|
||||
sigma_pred_target = F.avg_pool2d(pred * target, self.window_size, stride=1, padding=padding) - mu_pred_target
|
||||
|
||||
ssim_map = (
|
||||
(2 * mu_pred_target + self.c1) * (2 * sigma_pred_target + self.c2)
|
||||
) / (
|
||||
(mu_pred_sq + mu_target_sq + self.c1) * (sigma_pred + sigma_target + self.c2)
|
||||
)
|
||||
return (1.0 - ssim_map.clamp(0, 1)).mean()
|
||||
|
||||
|
||||
class SobelEdgeLoss(nn.Module):
|
||||
"""L1 loss between Sobel edge maps, useful for stable keypoint structure."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
kernel_x = torch.tensor(
|
||||
[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
|
||||
dtype=torch.float32,
|
||||
).view(1, 1, 3, 3)
|
||||
kernel_y = torch.tensor(
|
||||
[[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
|
||||
dtype=torch.float32,
|
||||
).view(1, 1, 3, 3)
|
||||
self.register_buffer("kernel_x", kernel_x)
|
||||
self.register_buffer("kernel_y", kernel_y)
|
||||
|
||||
@staticmethod
|
||||
def _to_gray(x: torch.Tensor) -> torch.Tensor:
|
||||
x = (x + 1.0) * 0.5
|
||||
weights = x.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
|
||||
return (x * weights).sum(dim=1, keepdim=True)
|
||||
|
||||
def _edges(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gray = self._to_gray(x)
|
||||
grad_x = F.conv2d(gray, self.kernel_x, padding=1)
|
||||
grad_y = F.conv2d(gray, self.kernel_y, padding=1)
|
||||
return torch.sqrt(grad_x.pow(2) + grad_y.pow(2) + 1e-6)
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
return F.l1_loss(self._edges(pred), self._edges(target))
|
||||
|
||||
|
||||
class ImageGAN(nn.Module):
|
||||
"""Complete GAN model for image translation."""
|
||||
"""Complete pix2pix-style GAN for Google -> Yandex image translation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 3,
|
||||
output_channels: int = 3,
|
||||
gan_mode: str = "vanilla",
|
||||
lambda_L1: float = 100.0,
|
||||
gan_mode: str = "lsgan",
|
||||
lambda_L1: float = 150.0,
|
||||
lambda_GAN: float = 0.5,
|
||||
lambda_SSIM: float = 25.0,
|
||||
lambda_edge: float = 20.0,
|
||||
use_cuda: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -183,29 +279,36 @@ class ImageGAN(nn.Module):
|
||||
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
|
||||
self.gan_loss = GANLoss(gan_mode)
|
||||
self.l1_loss = nn.L1Loss()
|
||||
self.ssim_loss = SSIMLoss()
|
||||
self.edge_loss = SobelEdgeLoss()
|
||||
self.lambda_L1 = lambda_L1
|
||||
self.lambda_GAN = lambda_GAN
|
||||
self.lambda_SSIM = lambda_SSIM
|
||||
self.lambda_edge = lambda_edge
|
||||
|
||||
self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
|
||||
self.device = get_compatible_device(prefer_cuda=use_cuda)
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, yandex_image):
|
||||
"""Generate Google image from Yandex."""
|
||||
return self.generator(yandex_image)
|
||||
def forward(self, google_image):
|
||||
"""Generate a Yandex-style image from a Google image."""
|
||||
return self.generator(google_image)
|
||||
|
||||
def generator_step(self, yandex_img, real_google_img):
|
||||
"""Compute generator losses."""
|
||||
fake_google = self.generator(yandex_img)
|
||||
fake_pred = self.discriminator(yandex_img, fake_google)
|
||||
gan_loss = self.gan_loss(fake_pred, True)
|
||||
l1_loss = self.l1_loss(fake_google, real_google_img) * self.lambda_L1
|
||||
total_loss = gan_loss + l1_loss
|
||||
return total_loss, gan_loss, l1_loss
|
||||
def generator_step(self, google_img, real_yandex_img):
|
||||
"""Compute generator losses against the paired original Yandex image."""
|
||||
fake_yandex = self.generator(google_img)
|
||||
fake_pred = self.discriminator(google_img, fake_yandex)
|
||||
gan_loss = self.gan_loss(fake_pred, True) * self.lambda_GAN
|
||||
l1_loss = self.l1_loss(fake_yandex, real_yandex_img) * self.lambda_L1
|
||||
ssim_loss = self.ssim_loss(fake_yandex, real_yandex_img) * self.lambda_SSIM
|
||||
edge_loss = self.edge_loss(fake_yandex, real_yandex_img) * self.lambda_edge
|
||||
total_loss = gan_loss + l1_loss + ssim_loss + edge_loss
|
||||
return total_loss, gan_loss, l1_loss, ssim_loss, edge_loss
|
||||
|
||||
def discriminator_step(self, yandex_img, real_google_img, fake_google_img):
|
||||
"""Compute discriminator losses."""
|
||||
real_pred = self.discriminator(yandex_img, real_google_img)
|
||||
def discriminator_step(self, google_img, real_yandex_img, fake_yandex_img):
|
||||
"""Compute discriminator losses for real and generated Yandex targets."""
|
||||
real_pred = self.discriminator(google_img, real_yandex_img)
|
||||
real_loss = self.gan_loss(real_pred, True)
|
||||
fake_pred = self.discriminator(yandex_img, fake_google_img.detach())
|
||||
fake_pred = self.discriminator(google_img, fake_yandex_img.detach())
|
||||
fake_loss = self.gan_loss(fake_pred, False)
|
||||
total_loss = (real_loss + fake_loss) * 0.5
|
||||
return total_loss, real_loss, fake_loss
|
||||
@@ -214,8 +317,11 @@ class ImageGAN(nn.Module):
|
||||
def create_gan(
|
||||
input_channels: int = 3,
|
||||
output_channels: int = 3,
|
||||
gan_mode: str = "vanilla",
|
||||
lambda_L1: float = 100.0,
|
||||
gan_mode: str = "lsgan",
|
||||
lambda_L1: float = 150.0,
|
||||
lambda_GAN: float = 0.5,
|
||||
lambda_SSIM: float = 25.0,
|
||||
lambda_edge: float = 20.0,
|
||||
use_cuda: bool = True,
|
||||
) -> ImageGAN:
|
||||
"""Create a GAN model."""
|
||||
@@ -224,6 +330,9 @@ def create_gan(
|
||||
output_channels=output_channels,
|
||||
gan_mode=gan_mode,
|
||||
lambda_L1=lambda_L1,
|
||||
lambda_GAN=lambda_GAN,
|
||||
lambda_SSIM=lambda_SSIM,
|
||||
lambda_edge=lambda_edge,
|
||||
use_cuda=use_cuda,
|
||||
)
|
||||
|
||||
@@ -252,4 +361,4 @@ if __name__ == "__main__":
|
||||
# Count parameters
|
||||
gen_params = sum(p.numel() for p in model.generator.parameters())
|
||||
disc_params = sum(p.numel() for p in model.discriminator.parameters())
|
||||
print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")
|
||||
print(f"Generator: {gen_params:,} params, Discriminator: {disc_params:,} params")
|
||||
40
models/GAN/src/test_dataloader.py
Normal file
40
models/GAN/src/test_dataloader.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from config import create_config
|
||||
from dataloader import YaGoDataset, create_data_loaders
|
||||
|
||||
|
||||
def get_dataset_info():
|
||||
config = create_config()
|
||||
dataset = YaGoDataset(
|
||||
root_dir=config["data_dir"],
|
||||
image_size=tuple(config["image_size"]),
|
||||
)
|
||||
sample = dataset[0] if len(dataset) else {}
|
||||
return {
|
||||
"size": len(dataset),
|
||||
"sample_keys": list(sample.keys()),
|
||||
"google_shape": tuple(sample["google_img"].shape) if sample else None,
|
||||
"yandex_shape": tuple(sample["yandex_img"].shape) if sample else None,
|
||||
}
|
||||
|
||||
|
||||
def smoke_test_dataloader(batch_size=4):
|
||||
config = create_config()
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=batch_size,
|
||||
train_split=config["train_split"],
|
||||
num_workers=config["num_workers"],
|
||||
image_size=tuple(config["image_size"]),
|
||||
)
|
||||
batch = next(iter(train_loader))
|
||||
return {
|
||||
"train_size": len(train_loader.dataset),
|
||||
"val_size": len(val_loader.dataset),
|
||||
"google_batch_shape": tuple(batch["google_img"].shape),
|
||||
"yandex_batch_shape": tuple(batch["yandex_img"].shape),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(get_dataset_info())
|
||||
print(smoke_test_dataloader())
|
||||
@@ -28,18 +28,26 @@ class GANTrainer:
|
||||
|
||||
# Optimizers
|
||||
lr = config.get("learning_rate", 2e-4)
|
||||
lr_d = config.get("discriminator_learning_rate", lr * config.get("discriminator_lr_factor", 0.5))
|
||||
beta1 = config.get("beta1", 0.5)
|
||||
beta2 = config.get("beta2", 0.999)
|
||||
self.opt_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))
|
||||
self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2))
|
||||
self.opt_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2))
|
||||
|
||||
# Training state
|
||||
self.current_epoch = 0
|
||||
self.best_val_loss = float("inf")
|
||||
self.g_losses = []
|
||||
self.d_losses = []
|
||||
self.l1_losses = []
|
||||
self.ssim_losses = []
|
||||
self.edge_losses = []
|
||||
self.val_g_losses = []
|
||||
self.val_d_losses = []
|
||||
self.val_l1_losses = []
|
||||
self.val_ssim_losses = []
|
||||
self.val_edge_losses = []
|
||||
self.val_reconstruction_losses = []
|
||||
|
||||
# Output dir
|
||||
self.output_dir = Path(config.get("output_dir", "runs/gan"))
|
||||
@@ -54,35 +62,55 @@ class GANTrainer:
|
||||
"""Train for one epoch."""
|
||||
self.model.train()
|
||||
total_g = total_d = 0.0
|
||||
total_l1 = total_ssim = total_edge = 0.0
|
||||
num_batches = len(self.train_loader)
|
||||
d_update_interval = max(1, self.config.get("discriminator_update_interval", 1))
|
||||
|
||||
pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||
for batch in pbar:
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
pbar = tqdm(enumerate(self.train_loader), total=num_batches, desc=f"Epoch {self.current_epoch + 1}")
|
||||
for batch_idx, batch in pbar:
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
|
||||
# Train D
|
||||
self.opt_D.zero_grad()
|
||||
with torch.no_grad():
|
||||
fake_img = self.model.generator(yandex_img)
|
||||
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
|
||||
d_loss.backward()
|
||||
self.opt_D.step()
|
||||
if batch_idx % d_update_interval == 0:
|
||||
self.opt_D.zero_grad()
|
||||
with torch.no_grad():
|
||||
fake_img = self.model.generator(google_img)
|
||||
d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
|
||||
d_loss.backward()
|
||||
self.opt_D.step()
|
||||
else:
|
||||
d_loss = google_img.new_tensor(0.0)
|
||||
|
||||
# Train G
|
||||
self.opt_G.zero_grad()
|
||||
g_loss = self.model.generator_step(yandex_img, google_img)[0]
|
||||
g_loss, gan_loss, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
|
||||
g_loss.backward()
|
||||
self.opt_G.step()
|
||||
|
||||
total_g += g_loss.item()
|
||||
total_d += d_loss.item()
|
||||
pbar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
|
||||
total_l1 += l1_loss.item()
|
||||
total_ssim += ssim_loss.item()
|
||||
total_edge += edge_loss.item()
|
||||
pbar.set_postfix({
|
||||
"g_loss": g_loss.item(),
|
||||
"d_loss": d_loss.item(),
|
||||
"l1": l1_loss.item(),
|
||||
"ssim": ssim_loss.item(),
|
||||
"edge": edge_loss.item(),
|
||||
})
|
||||
|
||||
avg_g = total_g / num_batches
|
||||
avg_d = total_d / num_batches
|
||||
avg_l1 = total_l1 / num_batches
|
||||
avg_ssim = total_ssim / num_batches
|
||||
avg_edge = total_edge / num_batches
|
||||
self.g_losses.append(avg_g)
|
||||
self.d_losses.append(avg_d)
|
||||
self.l1_losses.append(avg_l1)
|
||||
self.ssim_losses.append(avg_ssim)
|
||||
self.edge_losses.append(avg_edge)
|
||||
return avg_g, avg_d
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -90,20 +118,32 @@ class GANTrainer:
|
||||
"""Validate the model."""
|
||||
self.model.eval()
|
||||
total_g = total_d = 0.0
|
||||
total_l1 = total_ssim = total_edge = 0.0
|
||||
|
||||
for batch in tqdm(self.val_loader, desc="Val"):
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
fake_img = self.model.generator(yandex_img)
|
||||
g_loss = self.model.generator_step(yandex_img, google_img)[0]
|
||||
d_loss = self.model.discriminator_step(yandex_img, google_img, fake_img)[0]
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
fake_img = self.model.generator(google_img)
|
||||
g_loss, _, l1_loss, ssim_loss, edge_loss = self.model.generator_step(google_img, yandex_img)
|
||||
d_loss = self.model.discriminator_step(google_img, yandex_img, fake_img)[0]
|
||||
total_g += g_loss.item()
|
||||
total_d += d_loss.item()
|
||||
total_l1 += l1_loss.item()
|
||||
total_ssim += ssim_loss.item()
|
||||
total_edge += edge_loss.item()
|
||||
|
||||
avg_g = total_g / len(self.val_loader)
|
||||
avg_d = total_d / len(self.val_loader)
|
||||
avg_l1 = total_l1 / len(self.val_loader)
|
||||
avg_ssim = total_ssim / len(self.val_loader)
|
||||
avg_edge = total_edge / len(self.val_loader)
|
||||
avg_reconstruction = avg_l1 + avg_ssim + avg_edge
|
||||
self.val_g_losses.append(avg_g)
|
||||
self.val_d_losses.append(avg_d)
|
||||
self.val_l1_losses.append(avg_l1)
|
||||
self.val_ssim_losses.append(avg_ssim)
|
||||
self.val_edge_losses.append(avg_edge)
|
||||
self.val_reconstruction_losses.append(avg_reconstruction)
|
||||
return avg_g, avg_d
|
||||
|
||||
def train(self, num_epochs: int):
|
||||
@@ -117,23 +157,30 @@ class GANTrainer:
|
||||
train_g, train_d = self.train_epoch()
|
||||
val_g, val_d = self.validate()
|
||||
|
||||
# Save best checkpoint
|
||||
val_total = val_g + val_d
|
||||
if val_total < self.best_val_loss:
|
||||
self.best_val_loss = val_total
|
||||
val_reconstruction = self.val_reconstruction_losses[-1]
|
||||
if val_reconstruction < self.best_val_loss:
|
||||
self.best_val_loss = val_reconstruction
|
||||
self.save_checkpoint("best")
|
||||
|
||||
# Periodic checkpoint
|
||||
if (epoch + 1) % self.config.get("save_interval", 5) == 0:
|
||||
self.save_checkpoint(f"epoch_{epoch + 1}")
|
||||
|
||||
print(f"Epoch {epoch + 1}: train_g={train_g:.4f}, train_d={train_d:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}")
|
||||
print(
|
||||
f"Epoch {epoch + 1}: "
|
||||
f"train_g={train_g:.4f}, train_d={train_d:.4f}, "
|
||||
f"train_l1={self.l1_losses[-1]:.4f}, train_ssim={self.ssim_losses[-1]:.4f}, "
|
||||
f"train_edge={self.edge_losses[-1]:.4f}, val_g={val_g:.4f}, val_d={val_d:.4f}, "
|
||||
f"val_l1={self.val_l1_losses[-1]:.4f}, val_ssim={self.val_ssim_losses[-1]:.4f}, "
|
||||
f"val_edge={self.val_edge_losses[-1]:.4f}, val_rec={val_reconstruction:.4f}"
|
||||
)
|
||||
|
||||
# Early stopping
|
||||
patience = self.config.get("early_stopping_patience", 0)
|
||||
if patience > 0 and len(self.val_g_losses) > patience:
|
||||
recent = self.val_g_losses[-patience:]
|
||||
if all(l >= min(self.val_g_losses[:-patience]) for l in recent):
|
||||
if patience > 0 and len(self.val_reconstruction_losses) > patience:
|
||||
recent = self.val_reconstruction_losses[-patience:]
|
||||
previous_best = min(self.val_reconstruction_losses[:-patience])
|
||||
if all(loss >= previous_best for loss in recent):
|
||||
print(f"Early stopping at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
@@ -188,4 +235,4 @@ if __name__ == "__main__":
|
||||
g_loss, d_loss = trainer.train_epoch()
|
||||
print(f"Training step succeeded: G={g_loss:.4f}, D={d_loss:.4f}")
|
||||
except Exception as e:
|
||||
print(f"Training step failed: {e}")
|
||||
print(f"Training step failed: {e}")
|
||||
BIN
practice/Indiv_zadanie_Сагитов.docx
Normal file
BIN
practice/Indiv_zadanie_Сагитов.docx
Normal file
Binary file not shown.
BIN
practice/otchet_po_praktike.docx
Normal file
BIN
practice/otchet_po_praktike.docx
Normal file
Binary file not shown.
BIN
practice/Дневник практики.docx
Normal file
BIN
practice/Дневник практики.docx
Normal file
Binary file not shown.
145
sian_similarity.py
Normal file
145
sian_similarity.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from vision_chunk import VisionChunk
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parent
|
||||
MODEL_FILE = ROOT_DIR / "models" / "SiaN-similarity" / "model.py"
|
||||
DEFAULT_CHECKPOINT_PATH = (
|
||||
ROOT_DIR
|
||||
/ "models"
|
||||
/ "SiaN-similarity"
|
||||
/ "runs"
|
||||
/ "gan_training"
|
||||
/ "checkpoints"
|
||||
/ "best_model.pt"
|
||||
)
|
||||
|
||||
IMAGE_SIZE = (256, 256)
|
||||
DEFAULT_THRESHOLD = 0.5
|
||||
CHECKPOINT_ENV = "SIAN_SIMILARITY_CHECKPOINT"
|
||||
THRESHOLD_ENV = "SIAN_SIMILARITY_THRESHOLD"
|
||||
|
||||
_model: Optional[torch.nn.Module] = None
|
||||
_device: Optional[torch.device] = None
|
||||
|
||||
|
||||
def _load_similarity_class():
|
||||
spec = importlib.util.spec_from_file_location("sian_similarity_model", MODEL_FILE)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot load similarity model from {MODEL_FILE}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module.SimilarityCNN
|
||||
|
||||
|
||||
def _get_checkpoint_path() -> Path:
|
||||
checkpoint_path = os.getenv(CHECKPOINT_ENV)
|
||||
if checkpoint_path:
|
||||
return Path(checkpoint_path).expanduser().resolve()
|
||||
return DEFAULT_CHECKPOINT_PATH
|
||||
|
||||
|
||||
def get_threshold() -> float:
|
||||
threshold = os.getenv(THRESHOLD_ENV)
|
||||
if threshold is None:
|
||||
return DEFAULT_THRESHOLD
|
||||
return float(threshold)
|
||||
|
||||
|
||||
def _get_device() -> torch.device:
|
||||
global _device
|
||||
|
||||
if _device is None:
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return _device
|
||||
|
||||
|
||||
def _get_model() -> torch.nn.Module:
|
||||
global _model
|
||||
|
||||
if _model is not None:
|
||||
return _model
|
||||
|
||||
checkpoint_path = _get_checkpoint_path()
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"SiaN similarity checkpoint not found: {checkpoint_path}. "
|
||||
f"Set {CHECKPOINT_ENV} to another .pt file if needed."
|
||||
)
|
||||
|
||||
SimilarityCNN = _load_similarity_class()
|
||||
device = _get_device()
|
||||
|
||||
model = SimilarityCNN(
|
||||
input_channels=3,
|
||||
backbone_name="resnet18",
|
||||
pretrained=False,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
).to(device)
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
state_dict = checkpoint.get("model_state_dict", checkpoint)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
_model = model
|
||||
return _model
|
||||
|
||||
|
||||
def _chunk_to_tensor(chunk: VisionChunk) -> torch.Tensor:
|
||||
image = chunk.image.convert("RGB").resize(IMAGE_SIZE, Image.BILINEAR)
|
||||
array = np.asarray(image, dtype=np.float32) / 255.0
|
||||
tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)
|
||||
return tensor.to(_get_device())
|
||||
|
||||
|
||||
def get_similarity_score(chunk1: VisionChunk, chunk2: VisionChunk) -> float:
|
||||
if chunk1 is None or chunk2 is None:
|
||||
return 0.0
|
||||
|
||||
model = _get_model()
|
||||
img1 = _chunk_to_tensor(chunk1)
|
||||
img2 = _chunk_to_tensor(chunk2)
|
||||
|
||||
with torch.inference_mode():
|
||||
similarity = model(img1, img2)
|
||||
|
||||
return float(similarity.squeeze().item())
|
||||
|
||||
|
||||
def get_similarity_scores(chunk: VisionChunk, candidates: Sequence[VisionChunk]) -> list[float]:
|
||||
if chunk is None or not candidates:
|
||||
return []
|
||||
|
||||
model = _get_model()
|
||||
img = _chunk_to_tensor(chunk)
|
||||
candidate_images = torch.cat([_chunk_to_tensor(candidate) for candidate in candidates], dim=0)
|
||||
repeated_img = img.expand(candidate_images.shape[0], -1, -1, -1)
|
||||
|
||||
with torch.inference_mode():
|
||||
similarities = model(repeated_img, candidate_images)
|
||||
|
||||
return [float(score) for score in similarities.squeeze(1).detach().cpu().tolist()]
|
||||
|
||||
|
||||
def is_similar(
|
||||
chunk1: VisionChunk,
|
||||
chunk2: VisionChunk,
|
||||
threshold: Optional[float] = None,
|
||||
) -> bool:
|
||||
if threshold is None:
|
||||
threshold = get_threshold()
|
||||
|
||||
return get_similarity_score(chunk1, chunk2) >= threshold
|
||||
Reference in New Issue
Block a user