Compare commits
13 Commits
31c0f13361
...
feat/pitch
| Author | SHA1 | Date | |
|---|---|---|---|
| 0cc210968f | |||
| 6040f3b253 | |||
| 6d4208d100 | |||
|
|
339e5f210c | ||
|
|
5385641d28 | ||
| 64c9215f5b | |||
| 7cd700c1fa | |||
| 05c249ed78 | |||
| fbd0d01b35 | |||
| e00878daad | |||
| ceca8a6e75 | |||
| 6456d18212 | |||
| 3ee3599b87 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
.venv
|
.venv
|
||||||
__pycache__
|
__pycache__
|
||||||
*.png
|
*.png
|
||||||
images
|
trajectories
|
||||||
|
z
|
||||||
|
|||||||
222
autopilot.py
222
autopilot.py
@@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import constants
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -56,14 +57,16 @@ class AutoPilot(Pilot):
|
|||||||
|
|
||||||
# Положение на основе ориентира
|
# Положение на основе ориентира
|
||||||
reserved_pos: Position | None
|
reserved_pos: Position | None
|
||||||
|
proccessing_time: float
|
||||||
|
|
||||||
def __init__(self, points = [], chunks = [], viz_manager=None):
|
def __init__(self, points = [], chunks = [], viz_manager=None, pixel_ratio: float = 1.):
|
||||||
self.prev_chunk = None
|
self.prev_chunk = None
|
||||||
self.pos = Position(0, 0, 1, 0, 0, 0)
|
self.pos = Position(0, 0, 1, 0, 0, 0)
|
||||||
self.chunks = chunks
|
self.chunks = chunks
|
||||||
self.frame_count = 0
|
self.frame_count = 0
|
||||||
self.vis_manager = viz_manager # Менеджер визуализации
|
self.vis_manager = viz_manager # Менеджер визуализации
|
||||||
self.reserved_pos = None
|
self.reserved_pos = None
|
||||||
|
self.pixel_ratio = pixel_ratio
|
||||||
|
|
||||||
# Пороговые значения качества сопоставления/гомографии
|
# Пороговые значения качества сопоставления/гомографии
|
||||||
self.min_inliers: int = 12
|
self.min_inliers: int = 12
|
||||||
@@ -76,6 +79,7 @@ class AutoPilot(Pilot):
|
|||||||
self.target_idx = 0
|
self.target_idx = 0
|
||||||
|
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
|
self.chunk_points = np.array([[chunk.pos.x, chunk.pos.y] for chunk in self.chunks])
|
||||||
|
|
||||||
def get_position(self) -> tuple[float, float]:
|
def get_position(self) -> tuple[float, float]:
|
||||||
return self.pos.x, self.pos.y
|
return self.pos.x, self.pos.y
|
||||||
@@ -91,7 +95,7 @@ class AutoPilot(Pilot):
|
|||||||
h, w = prev_gray.shape[:2]
|
h, w = prev_gray.shape[:2]
|
||||||
|
|
||||||
# Создаем сетку точек для отслеживания (аналогично вашему step=20)
|
# Создаем сетку точек для отслеживания (аналогично вашему step=20)
|
||||||
step = 35
|
step = 20
|
||||||
grid_points = []
|
grid_points = []
|
||||||
for y in range(step, h - step, step):
|
for y in range(step, h - step, step):
|
||||||
for x in range(step, w - step, step):
|
for x in range(step, w - step, step):
|
||||||
@@ -133,9 +137,6 @@ class AutoPilot(Pilot):
|
|||||||
"""
|
"""
|
||||||
self.pos.iapply(homography_matrix)
|
self.pos.iapply(homography_matrix)
|
||||||
|
|
||||||
if self.reserved_pos is not None:
|
|
||||||
self.reserved_pos.iapply(homography_matrix)
|
|
||||||
|
|
||||||
def get_drone_state(self) -> dict:
|
def get_drone_state(self) -> dict:
|
||||||
"""
|
"""
|
||||||
Возвращает текущее состояние БПЛА
|
Возвращает текущее состояние БПЛА
|
||||||
@@ -149,43 +150,122 @@ class AutoPilot(Pilot):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_position_by_chunk(self) -> Position | None:
|
def get_position_by_chunk(self) -> Position | None:
|
||||||
# Пытаемся найти ориентир на картинке:
|
landmark_timer = Timer()
|
||||||
|
landmark_timer.start()
|
||||||
|
|
||||||
|
cur_pos = np.array([self.pos.x, self.pos.y])
|
||||||
|
closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
|
||||||
|
|
||||||
current_chunk = self.prev_chunk
|
current_chunk = self.prev_chunk
|
||||||
landmark_chunk = self.chunks[self.target_idx]
|
landmark_chunk = self.chunks[closest_chunk_idx]
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[LANDMARK]: Closest chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
|
# Краевой случай: отсутствие чанков
|
||||||
|
if current_chunk is None or landmark_chunk is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
landmark_timer.start()
|
||||||
src_pts, dst_pts, matches, kp1, kp2 = landmark_chunk.detect_and_match_keypoints(current_chunk)
|
src_pts, dst_pts, matches, kp1, kp2 = landmark_chunk.detect_and_match_keypoints(current_chunk)
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[LANDMARK]: detect and match keypoints: {landmark_timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
|
landmark_timer.stop()
|
||||||
|
|
||||||
|
# Визуализация (если нужна)
|
||||||
if src_pts is not None and dst_pts is not None and self.vis_manager:
|
if src_pts is not None and dst_pts is not None and self.vis_manager:
|
||||||
was_enabled = self.timer.enabled
|
was_enabled = self.timer.enabled
|
||||||
if was_enabled: self.timer.stop()
|
if was_enabled:
|
||||||
self.vis_manager.update_chunk_matches(landmark_chunk.to_numpy(), current_chunk.to_numpy(), kp1, kp2, matches)
|
self.timer.stop()
|
||||||
if was_enabled: self.timer.start()
|
self.vis_manager.update_chunk_matches(
|
||||||
|
landmark_chunk.to_numpy(),
|
||||||
|
current_chunk.to_numpy(),
|
||||||
|
kp1, kp2, matches
|
||||||
|
)
|
||||||
|
if was_enabled:
|
||||||
|
self.timer.start()
|
||||||
|
|
||||||
if src_pts is not None and dst_pts is not None:
|
landmark_timer.start()
|
||||||
# Оцениваем матрицу трансформации
|
# Краевой случай: нет точек или недостаточно матчей
|
||||||
landmark_transform = self.estimate_transformation_matrix(src_pts, dst_pts)
|
if src_pts is None or dst_pts is None:
|
||||||
# Если ориентир достоверно найден — скорректируем глобальные координаты и угол
|
return None
|
||||||
if landmark_transform is not None:
|
|
||||||
ok_scale = (self.min_scale <= landmark_transform['scale'] <= self.max_scale)
|
num_matches = len(src_pts)
|
||||||
ok_inliers = (landmark_transform.get('inliers', 0) >= self.min_inliers)
|
if num_matches < 20:
|
||||||
ratio = landmark_transform.get('inliers_ratio', 0.0)
|
return None
|
||||||
ok_ratio = (ratio >= self.min_inlier_ratio)
|
|
||||||
rmse = landmark_transform.get('rmse', None)
|
# Оценка матрицы гомографии
|
||||||
ok_rmse = (rmse is not None and rmse <= self.max_reproj_rmse)
|
landmark_timer.loop()
|
||||||
if ok_scale and ok_inliers and ok_ratio and ok_rmse:
|
landmark_transform, mask = estimate_transformation_matrix(src_pts, dst_pts)
|
||||||
# print("[HELP]")
|
num_inliers = int(np.sum(mask))
|
||||||
# print("Matrix", landmark_transform['homography'])
|
|
||||||
# print("Position", self.x, self.y)
|
if constants.DEBUG_FPS:
|
||||||
# print("Position of point", self.points[self.target_idx])
|
print(f"[LANDMARK]: matrix estimation: {landmark_timer.loop() * 1000:.2f} ms")
|
||||||
# print("[PILOT]", rmse, ratio, ok_rmse)
|
|
||||||
# if False:
|
# Краевой случай: матрица не найдена
|
||||||
# Считаем абсолютную позу относительно координат ориентира
|
if landmark_transform is None or mask is None:
|
||||||
landmark_world_x, landmark_world_y = self.points[self.target_idx]
|
return None
|
||||||
landmark = Position(landmark_world_x, landmark_world_y, 1, 0)
|
|
||||||
homography = landmark_transform['homography']
|
# === КРИТЕРИИ ПРИНЯТИЯ РЕШЕНИЯ ===
|
||||||
# homography = np.linalg.inv(homography)
|
|
||||||
print(f" [Pilot] Landmark correction applied (inliers={landmark_transform['inliers']}, ratio={ratio:.2f}, rmse={rmse:.2f})")
|
# 1. Минимальное количество инлайеров (абсолютное)
|
||||||
return landmark @ homography
|
MIN_INLIERS_ABSOLUTE = 6
|
||||||
return None
|
if num_inliers < MIN_INLIERS_ABSOLUTE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 2. Процент инлайеров от общего числа матчей
|
||||||
|
inlier_ratio = num_inliers / num_matches
|
||||||
|
|
||||||
|
if constants.DEBUG_LANDMARK:
|
||||||
|
print("[LANDMARK]: inlier_ratio=", inlier_ratio)
|
||||||
|
|
||||||
|
MIN_INLIER_RATIO = 0.6
|
||||||
|
if inlier_ratio < MIN_INLIER_RATIO:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 3. Проверка качества гомографии (детерминант для выявления вырожденных случаев)
|
||||||
|
det = np.linalg.det(landmark_transform[:2, :2])
|
||||||
|
# Детерминант должен быть близок к 1 (без сильного масштабирования)
|
||||||
|
if abs(det) < 0.1 or abs(det) > 10.0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 4. Проверка на валидность перспективного преобразования
|
||||||
|
# Элементы третьей строки не должны быть слишком большими
|
||||||
|
if abs(landmark_transform[2, 0]) > 0.01 or abs(landmark_transform[2, 1]) > 0.01:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 5. Дополнительная проверка: средняя ошибка репроекции для инлайеров
|
||||||
|
inlier_src = src_pts[mask.ravel() == 1]
|
||||||
|
inlier_dst = dst_pts[mask.ravel() == 1]
|
||||||
|
|
||||||
|
# Преобразуем точки через найденную гомографию
|
||||||
|
transformed_pts = cv2.perspectiveTransform(inlier_src, landmark_transform)
|
||||||
|
|
||||||
|
# Вычисляем ошибку репроекции
|
||||||
|
reprojection_errors = np.sqrt(np.sum((transformed_pts - inlier_dst) ** 2, axis=2))
|
||||||
|
mean_error = np.mean(reprojection_errors)
|
||||||
|
|
||||||
|
MAX_MEAN_REPROJECTION_ERROR = 1.1 # пиксели
|
||||||
|
|
||||||
|
if constants.DEBUG_LANDMARK:
|
||||||
|
print("[LANDMARK]: Mean_error=", mean_error)
|
||||||
|
|
||||||
|
if mean_error > MAX_MEAN_REPROJECTION_ERROR:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 6. Проверка стабильности: если слишком много хороших совпадений, но мало инлайеров - подозрительно
|
||||||
|
if num_matches > 50 and inlier_ratio < 0.15:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# === ВСЕ ПРОВЕРКИ ПРОЙДЕНЫ ===
|
||||||
|
print("[LANDMARK]: Correction Applied")
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[LANDMARK]: time: {landmark_timer.get_elapsed() * 1000:.2f} ms")
|
||||||
|
|
||||||
|
return landmark_chunk.pos.apply(landmark_transform)
|
||||||
|
|
||||||
|
|
||||||
def handle(self, current_chunk: VisionChunk) -> PilotCommand:
|
def handle(self, current_chunk: VisionChunk) -> PilotCommand:
|
||||||
@@ -197,17 +277,11 @@ class AutoPilot(Pilot):
|
|||||||
self.timer.stop()
|
self.timer.stop()
|
||||||
return PilotCommand(processing_time=self.timer.get_elapsed())
|
return PilotCommand(processing_time=self.timer.get_elapsed())
|
||||||
|
|
||||||
# Расстояние до цели
|
|
||||||
distance_to_target = math.sqrt(
|
|
||||||
(self.points[self.target_idx][0] - self.pos.x) ** 2 +
|
|
||||||
(self.points[self.target_idx][1] - self.pos.y) ** 2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Вычисляем оптический поток для покадрового сравнения
|
# Вычисляем оптический поток для покадрового сравнения
|
||||||
matching_timer = Timer()
|
matching_timer = Timer()
|
||||||
matching_timer.start()
|
matching_timer.start()
|
||||||
# src_pts, dst_pts = self.calculate_optical_flow(self.prev_chunk, current_chunk)
|
src_pts, dst_pts = self.calculate_optical_flow(self.prev_chunk, current_chunk)
|
||||||
src_pts, dst_pts, _, _, _ = self.prev_chunk.detect_and_match_keypoints(current_chunk)
|
# src_pts, dst_pts, _, _, _ = self.prev_chunk.detect_and_match_keypoints(current_chunk)
|
||||||
matching_timer.stop()
|
matching_timer.stop()
|
||||||
print(f"Matching calculating: {matching_timer.get_elapsed() * 1000:.2f} ms")
|
print(f"Matching calculating: {matching_timer.get_elapsed() * 1000:.2f} ms")
|
||||||
|
|
||||||
@@ -217,7 +291,7 @@ class AutoPilot(Pilot):
|
|||||||
|
|
||||||
matrix_estimation_timer = Timer()
|
matrix_estimation_timer = Timer()
|
||||||
matrix_estimation_timer.start()
|
matrix_estimation_timer.start()
|
||||||
homography_matrix = estimate_transformation_matrix(src_pts, dst_pts)
|
homography_matrix, _ = estimate_transformation_matrix(src_pts, dst_pts)
|
||||||
matrix_estimation_timer.stop()
|
matrix_estimation_timer.stop()
|
||||||
print(f"Transformation matrix updating: {matrix_estimation_timer.get_elapsed() * 1000:.2f} ms")
|
print(f"Transformation matrix updating: {matrix_estimation_timer.get_elapsed() * 1000:.2f} ms")
|
||||||
|
|
||||||
@@ -235,17 +309,24 @@ class AutoPilot(Pilot):
|
|||||||
|
|
||||||
self.timer.start()
|
self.timer.start()
|
||||||
|
|
||||||
chunk_timer = Timer()
|
|
||||||
chunk_timer.start()
|
|
||||||
|
|
||||||
# Пытаемся найти ориентир на картинке:
|
# Пытаемся найти ориентир на картинке:
|
||||||
self.prev_chunk = current_chunk
|
self.prev_chunk = current_chunk
|
||||||
# npos = self.get_position_by_chunk()
|
# Для улучшения среднего FPS
|
||||||
# if npos is not None:
|
if self.frame_count % 5 == 0:
|
||||||
# self.reserved_pos = npos
|
pos_by_chunk = self.get_position_by_chunk()
|
||||||
|
if pos_by_chunk is not None:
|
||||||
|
self.pos = pos_by_chunk
|
||||||
|
|
||||||
chunk_timer.stop()
|
command = self.make_command()
|
||||||
print(f"Chunk timer: {chunk_timer.get_elapsed() * 1000:.2f} ms")
|
self.timer.reset()
|
||||||
|
return command
|
||||||
|
|
||||||
|
def make_command(self) -> PilotCommand:
|
||||||
|
# Расстояние до цели
|
||||||
|
distance_to_target = math.sqrt(
|
||||||
|
(self.points[self.target_idx][0] - self.pos.x) ** 2 +
|
||||||
|
(self.points[self.target_idx][1] - self.pos.y) ** 2
|
||||||
|
) * self.pixel_ratio
|
||||||
|
|
||||||
if distance_to_target < 35:
|
if distance_to_target < 35:
|
||||||
self.target_idx += 1
|
self.target_idx += 1
|
||||||
@@ -258,16 +339,32 @@ class AutoPilot(Pilot):
|
|||||||
(self.points[self.target_idx][1] - self.pos.y) ** 2
|
(self.points[self.target_idx][1] - self.pos.y) ** 2
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reserved_pos is not None:
|
angle_trajectory = self.pos.yaw + math.pi / 2
|
||||||
self.pos = self.reserved_pos
|
|
||||||
self.reserved_pos = None
|
# Проверка на слепую зону
|
||||||
|
R = 120
|
||||||
|
blind = np.array([
|
||||||
|
[
|
||||||
|
self.pos.x * self.pixel_ratio + R * np.cos(angle_trajectory - np.pi / 2),
|
||||||
|
self.pos.y * self.pixel_ratio + R * np.sin(angle_trajectory - np.pi / 2),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
self.pos.x * self.pixel_ratio + R * np.cos(angle_trajectory + np.pi / 2),
|
||||||
|
self.pos.y * self.pixel_ratio + R * np.sin(angle_trajectory + np.pi / 2),
|
||||||
|
]
|
||||||
|
])
|
||||||
|
|
||||||
|
blind -= self.points[self.target_idx] * self.pixel_ratio
|
||||||
|
blind = np.hypot(blind[:, 0], blind[:, 1])
|
||||||
|
|
||||||
|
print("R: ", blind)
|
||||||
|
if np.min(blind) < R:
|
||||||
|
return PilotCommand(0, 10, False, self.timer.get_elapsed())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Вычисляем угол к цели
|
# Вычисляем угол к цели
|
||||||
target_angle = math.atan2(self.points[self.target_idx][1] - self.pos.y, self.points[self.target_idx][0] - self.pos.x)
|
target_angle = math.atan2(self.points[self.target_idx][1] - self.pos.y, self.points[self.target_idx][0] - self.pos.x)
|
||||||
|
|
||||||
angle_trajectory = self.pos.yaw + math.pi / 2
|
|
||||||
|
|
||||||
# print("[ANGLE]", angle_trajectory, "->", target_angle)
|
|
||||||
|
|
||||||
# Вычисляем разность углов (направление поворота)
|
# Вычисляем разность углов (направление поворота)
|
||||||
angle_diff = target_angle - angle_trajectory
|
angle_diff = target_angle - angle_trajectory
|
||||||
@@ -277,14 +374,13 @@ class AutoPilot(Pilot):
|
|||||||
if angle_diff >= math.pi:
|
if angle_diff >= math.pi:
|
||||||
angle_diff -= 2 * math.pi
|
angle_diff -= 2 * math.pi
|
||||||
|
|
||||||
d_r = max(10, min(35., distance_to_target / 2))
|
d_r = max(5, min(10., distance_to_target / 2))
|
||||||
d_a_limit = d_r / 10 * 0.01
|
d_a_limit = np.radians(5)
|
||||||
|
|
||||||
command = PilotCommand(
|
command = PilotCommand(
|
||||||
max(min(d_a_limit, angle_diff), -d_a_limit),
|
max(min(d_a_limit, angle_diff), -d_a_limit),
|
||||||
d_r, False, self.timer.get_elapsed()
|
d_r, False, self.timer.get_elapsed()
|
||||||
)
|
)
|
||||||
self.timer.reset()
|
|
||||||
return command
|
return command
|
||||||
|
|
||||||
def reset_position(self, x: float = 0.0, y: float = 0.0, angle: float = 0.0):
|
def reset_position(self, x: float = 0.0, y: float = 0.0, angle: float = 0.0):
|
||||||
|
|||||||
16
constants.py
16
constants.py
@@ -1,5 +1,18 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# Ширина 1 пикселя в метрах
|
||||||
|
YANDEX_PIXEL_RATIO = {
|
||||||
|
15: 2830 / 1049,
|
||||||
|
18: 350 / 1049,
|
||||||
|
}
|
||||||
|
|
||||||
|
GOOGLE_PIXEL_RATIO = {
|
||||||
|
15: 2766 / 1031,
|
||||||
|
18: 346 / 1031,
|
||||||
|
}
|
||||||
|
|
||||||
|
WINDOW_HEIGHT = 1031
|
||||||
|
|
||||||
# Ширина и высота снимка в пикселях
|
# Ширина и высота снимка в пикселях
|
||||||
CHUNK_WIDTH = 700
|
CHUNK_WIDTH = 700
|
||||||
|
|
||||||
@@ -17,3 +30,6 @@ K = np.array([
|
|||||||
[0, _K_FOCUS_DISTANCE, _K_CENTER],
|
[0, _K_FOCUS_DISTANCE, _K_CENTER],
|
||||||
[0, 0, 1]
|
[0, 0, 1]
|
||||||
])
|
])
|
||||||
|
|
||||||
|
DEBUG_FPS: bool = False
|
||||||
|
DEBUG_LANDMARK: bool = False
|
||||||
|
|||||||
2
datasets/ya_go_maps/.gitignore
vendored
Normal file
2
datasets/ya_go_maps/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
images/homography_cache
|
||||||
|
*.zip
|
||||||
43
datasets/ya_go_maps/generate_dataset.py
Normal file
43
datasets/ya_go_maps/generate_dataset.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from ...google_map import GoogleMap
|
||||||
|
from ...simulator import Simulator
|
||||||
|
from ...yandex_map import YandexMap
|
||||||
|
|
||||||
|
LAT_MIN, LAT_MAX = 44.960236, 54.967830
|
||||||
|
LON_MIN, LON_MAX = 53.084167, 58.677977
|
||||||
|
|
||||||
|
def create_new_asset(yandex_map, google_map):
|
||||||
|
folder = Path('dataset_ya_go_maps')
|
||||||
|
|
||||||
|
id = 0
|
||||||
|
print(id)
|
||||||
|
while (folder / f"{id:0{4}}_google.png").exists():
|
||||||
|
id += 1
|
||||||
|
|
||||||
|
google_file = folder / f"{id:0{4}}_google.png"
|
||||||
|
yandex_file = folder / f"{id:0{4}}_yandex.png"
|
||||||
|
|
||||||
|
lat = np.random.rand() * (LAT_MAX - LAT_MIN) + LAT_MIN
|
||||||
|
lon = np.random.rand() * (LON_MAX - LON_MIN) + LON_MIN
|
||||||
|
|
||||||
|
yandex_map.open(lat, lon, 18)
|
||||||
|
google_map.open(lat, lon, 18)
|
||||||
|
|
||||||
|
simulator = Simulator()
|
||||||
|
simulator._apply_perspective_transform(yandex_map.make_screenshot()).save(yandex_file)
|
||||||
|
simulator._apply_perspective_transform(google_map.make_screenshot()).save(google_file)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
folder = Path('dataset_ya_go_maps')
|
||||||
|
if not folder.exists():
|
||||||
|
folder.mkdir()
|
||||||
|
|
||||||
|
yandex_map = YandexMap(initial_zoom=15)
|
||||||
|
google_map = GoogleMap(initial_zoom=15)
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
create_new_asset(yandex_map, google_map)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
89
google_map.py
Normal file
89
google_map.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
from selenium.webdriver.common.actions.wheel_input import ScrollOrigin
|
||||||
|
from selenium import webdriver
|
||||||
|
from selenium.webdriver.common.by import By
|
||||||
|
from selenium.webdriver.common.action_chains import ActionChains
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
import constants
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import utility
|
||||||
|
|
||||||
|
def generateURL(lat: float, lon: float, zoom: int):
|
||||||
|
return f"https://www.google.com/maps/@{lon},{lat},{zoom}z"
|
||||||
|
|
||||||
|
class GoogleMap:
|
||||||
|
initial_zoom: int
|
||||||
|
initial_lat: float
|
||||||
|
initial_lon: float
|
||||||
|
pixel_ratio: float
|
||||||
|
|
||||||
|
def __init__(self, initial_lat=49.103814, initial_lon=55.794258, initial_zoom=18):
|
||||||
|
self.initial_lat = initial_lat
|
||||||
|
self.initial_lon = initial_lon
|
||||||
|
self.initial_zoom = initial_zoom
|
||||||
|
self.pixel_ratio = constants.GOOGLE_PIXEL_RATIO[self.initial_zoom]
|
||||||
|
|
||||||
|
options = webdriver.ChromeOptions()
|
||||||
|
# options.add_experimental_option("detach", True)
|
||||||
|
self.driver = webdriver.Chrome(options)
|
||||||
|
self.driver.get(generateURL(initial_lat, initial_lon, initial_zoom))
|
||||||
|
self.driver.maximize_window()
|
||||||
|
|
||||||
|
action = ActionChains(self.driver)
|
||||||
|
sleep(5)
|
||||||
|
self.driver.execute_script('document.querySelector(\'.yHc72.qk5Wte\').click()')
|
||||||
|
sleep(5)
|
||||||
|
|
||||||
|
def open(self, lat, lon, zoom):
|
||||||
|
self.initial_lat = lat
|
||||||
|
self.initial_lon = lon
|
||||||
|
self.initial_zoom = zoom
|
||||||
|
self.pixel_ratio = constants.GOOGLE_PIXEL_RATIO[self.initial_zoom]
|
||||||
|
self.driver.get(generateURL(lat, lon, zoom))
|
||||||
|
|
||||||
|
def save_photo(self, filename: str):
|
||||||
|
im = self.make_screenshot()
|
||||||
|
im.save(filename)
|
||||||
|
|
||||||
|
def destroy(self):
|
||||||
|
self.driver.close()
|
||||||
|
|
||||||
|
def get_size(self) -> tuple[int, int]:
|
||||||
|
html = self.driver.find_element(By.TAG_NAME, 'html')
|
||||||
|
return (html.size['width'], html.size['height'])
|
||||||
|
|
||||||
|
def scroll(self, x: float, y: float, count: int = 1, inner_zoom: bool = True):
|
||||||
|
html = self.driver.find_element(By.TAG_NAME, 'html')
|
||||||
|
|
||||||
|
x_offset = (x - 0.5) * (html.size['width'] - 72) + 72
|
||||||
|
y_offset = (y - 0.5) * html.size['height']
|
||||||
|
action = ActionChains(self.driver)
|
||||||
|
|
||||||
|
for i in range(count-1):
|
||||||
|
action.scroll_from_origin(ScrollOrigin(html, int(x_offset), int(y_offset)), 0, -100 if inner_zoom else 100)
|
||||||
|
action.perform()
|
||||||
|
if i != count - 1:
|
||||||
|
sleep(0.25)
|
||||||
|
|
||||||
|
def move(self, dx: float, dy: float):
|
||||||
|
self.driver.execute_script(utility.google_map_js_move_script(dx, dy))
|
||||||
|
|
||||||
|
def make_as_center(self, x: float, y: float):
|
||||||
|
dx = (x - 0.5) * self.get_size()[0]
|
||||||
|
dy = (0.5 - y) * self.get_size()[1]
|
||||||
|
self.move(dx, dy)
|
||||||
|
sleep(1)
|
||||||
|
|
||||||
|
def make_screenshot(self) -> Image.Image:
|
||||||
|
png = self.driver.get_screenshot_as_png()
|
||||||
|
im = Image.open(BytesIO(png))
|
||||||
|
return utility.cv2_to_pil(np.array(im)[:, 72:])
|
||||||
|
|
||||||
|
def get_geolocation(self):
|
||||||
|
current_url = self.driver.current_url
|
||||||
|
return utility.parse_google_maps_url(current_url)
|
||||||
384
main.py
384
main.py
@@ -1,23 +1,33 @@
|
|||||||
|
from google_map import GoogleMap
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from position import Position
|
||||||
from simulator import Simulator
|
from simulator import Simulator
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from trajectory_drawer import TrajectoryDrawer
|
from trajectory_drawer import TrajectoryDrawer
|
||||||
|
from utility import cv2_to_pil
|
||||||
|
from vision_chunk import VisionChunk
|
||||||
from visualization import VisualizationManager
|
from visualization import VisualizationManager
|
||||||
from yandex_map import YandexMap
|
from yandex_map import YandexMap
|
||||||
from vision_chunk import VisionChunk
|
|
||||||
from utility import cv2_to_pil
|
|
||||||
import random
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import autopilot
|
||||||
|
import constants
|
||||||
import cv2
|
import cv2
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import utility
|
||||||
|
|
||||||
import autopilot
|
def get_map(map_name: str = 'google', lat=49.103814, lon=55.794258, zoom=18):
|
||||||
|
if map_name == 'google': return GoogleMap(lat, lon, zoom)
|
||||||
|
if map_name == 'yandex': return YandexMap(lat, lon, zoom)
|
||||||
|
return None
|
||||||
|
|
||||||
def make_global_photo(filename):
|
def make_global_photo(filename, map_name: str = 'google', lat=49.103814, lon=55.794258, zoom=13):
|
||||||
yandexMap = YandexMap()
|
online_map: YandexMap | GoogleMap = get_map(map_name, lat, lon, zoom)
|
||||||
yandexMap.save_photo(filename)
|
online_map.save_photo(filename)
|
||||||
yandexMap.destroy()
|
online_map.destroy()
|
||||||
|
|
||||||
def get_trajectory_points(bg_img: str) -> list[(float, float)]:
|
def get_trajectory_points(bg_img: str) -> list[(float, float)]:
|
||||||
trajectoryDrawer = TrajectoryDrawer(bg_img)
|
trajectoryDrawer = TrajectoryDrawer(bg_img)
|
||||||
@@ -26,97 +36,132 @@ def get_trajectory_points(bg_img: str) -> list[(float, float)]:
|
|||||||
points = list(map(lambda p: [p[0] / trajectoryDrawer.img.shape[1], p[1] / trajectoryDrawer.img.shape[0]], trajectoryDrawer.points))
|
points = list(map(lambda p: [p[0] / trajectoryDrawer.img.shape[1], p[1] / trajectoryDrawer.img.shape[0]], trajectoryDrawer.points))
|
||||||
return points
|
return points
|
||||||
|
|
||||||
def main():
|
def build(name: str, map_name: str, lat: float, lon: float):
|
||||||
# Скриншот местности
|
|
||||||
# make_global_photo('map.jpg')
|
|
||||||
|
|
||||||
# Получаем траекторию от пользователя
|
# Создание папки с информацией о маршруте
|
||||||
# points = get_trajectory_points('map.jpg')
|
dir = Path('trajectories')
|
||||||
|
if not dir.exists(): dir.mkdir()
|
||||||
|
dir /= name
|
||||||
|
assert not dir.exists()
|
||||||
|
dir.mkdir()
|
||||||
|
dir_chunks = dir / 'chunks'
|
||||||
|
dir_chunks.mkdir()
|
||||||
|
|
||||||
# Trajectory #1
|
make_global_photo('map.jpg', map_name, lat, lon, 15)
|
||||||
# points = [[np.float64(0.5384504359393909), np.float64(0.4084520767967683)], [np.float64(0.4451750568707629), np.float64(0.38213330305374654)], [np.float64(0.49266070439660997), np.float64(0.2789637099811013)], [np.float64(0.36377108968359656), np.float64(0.3263375027185404)], [np.float64(0.3535955937852008), np.float64(0.4337180995900692)]]
|
points = get_trajectory_points('map.jpg')
|
||||||
|
online_map: YandexMap | GoogleMap = get_map(map_name, lat, lon, 15)
|
||||||
|
|
||||||
# Trajectory #2
|
|
||||||
# points = [[np.float64(0.29197731306713737), np.float64(0.3452870198135161)], [np.float64(0.33494051797147517), np.float64(0.2010601397017569)], [np.float64(0.39768940934491587), np.float64(0.25369768718780034)], [np.float64(0.4027771572941138), np.float64(0.4158213334448144)], [np.float64(0.2914120077394487), np.float64(0.5547844588079692)]]
|
|
||||||
|
|
||||||
# Trajectory #3
|
width, height = online_map.get_size()
|
||||||
# points = [[np.float64(0.2755834585641664), np.float64(0.45687862048392835)], [np.float64(0.295934450360958), np.float64(0.5021469113219258)], [np.float64(0.32872215936689997), np.float64(0.4810918923275084)], [np.float64(0.3649017003389739), np.float64(0.5295184360146684)], [np.float64(0.3999506306556705), np.float64(0.49477765467387963)]]
|
|
||||||
|
|
||||||
# Trajectory #4
|
|
||||||
# points = [[np.float64(0.42143223310783934), np.float64(0.6663760594783815)], [np.float64(0.4253893704016599), np.float64(0.5537317078582484)], [np.float64(0.5124463908657128), np.float64(0.5621537154560153)], [np.float64(0.5124463908657128), np.float64(0.6684815613778233)], [np.float64(0.42143223310783934), np.float64(0.6663760594783815)]]
|
|
||||||
|
|
||||||
# Trajectory #5
|
|
||||||
# points = [[np.float64(0.5983728006743884), np.float64(0.7348048712102382)], [np.float64(0.5966768846913225), np.float64(0.5453097002604814)], [np.float64(0.6345523416464622), np.float64(0.7190136069644251)], [np.float64(0.6402053949233488), np.float64(0.5495207040593649)], [np.float64(0.5983728006743884), np.float64(0.7348048712102382)]]
|
|
||||||
|
|
||||||
# Trajectory #6
|
|
||||||
# points = [[np.float64(0.4406526142492536), np.float64(0.28106921188054296)], [np.float64(0.38581799746345413), np.float64(0.2968604761263561)], [np.float64(0.3931669667234066), np.float64(0.353709027411283)], [np.float64(0.4248240650739713), np.float64(0.35265627646156217)], [np.float64(0.40616898926024564), np.float64(0.3179154951207735)]]
|
|
||||||
|
|
||||||
# Trajectory #7
|
|
||||||
# points = [[np.float64(0.5491912371654754), np.float64(0.7505961354560512)], [np.float64(0.5537136797869846), np.float64(0.6863783275230781)], [np.float64(0.5017055896396284), np.float64(0.6653233085286606)], [np.float64(0.5520177638039186), np.float64(0.6042637534448503)], [np.float64(0.5593667330638712), np.float64(0.516885424618018)]]
|
|
||||||
|
|
||||||
# Trajectory #8
|
|
||||||
# points =
|
|
||||||
|
|
||||||
# Trajectory #9
|
|
||||||
# points =
|
|
||||||
|
|
||||||
# Trajectory #10
|
|
||||||
# points =
|
|
||||||
|
|
||||||
print(points)
|
|
||||||
|
|
||||||
# Для каждой точки сделаем приближенный снимок
|
|
||||||
yandexMap = YandexMap()
|
|
||||||
chunks: list[VisionChunk] = []
|
|
||||||
|
|
||||||
plt.ion()
|
|
||||||
for i in range(len(points)):
|
|
||||||
point = points[i]
|
|
||||||
yandexMap.scroll(point[0], point[1], 5, True)
|
|
||||||
sleep(1)
|
|
||||||
cv2_img = yandexMap.make_screenshot(point[0], point[1], 0.2, 0.2)
|
|
||||||
img = cv2_to_pil(cv2_img)
|
|
||||||
chunk = VisionChunk(img)
|
|
||||||
Path('chunks').mkdir(exist_ok=True)
|
|
||||||
chunk.save_image(Path('.') / 'chunks' / f'chunk_{i}.png')
|
|
||||||
plt.subplot(1, len(points), i+1)
|
|
||||||
plt.imshow(img)
|
|
||||||
plt.pause(0.25)
|
|
||||||
yandexMap.scroll(point[0], point[1], 5, False)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# Выделим на каждой картинке ключевые точки
|
|
||||||
for i in range(len(points)):
|
|
||||||
chunk = VisionChunk.load_image(Path('chunks') / f'chunk_{i}.png')
|
|
||||||
chunks.append(chunk)
|
|
||||||
kp, des = chunk.compute_keypoints()
|
|
||||||
|
|
||||||
plt.subplot(1, len(points), i+1)
|
|
||||||
plt.imshow(chunk.image)
|
|
||||||
kp_coords = np.array([j.pt for j in kp])
|
|
||||||
if len(kp_coords) > 0:
|
|
||||||
plt.scatter(kp_coords[:, 0], kp_coords[:, 1], c='red', s=20, alpha=0.7, marker='o')
|
|
||||||
plt.pause(0.2)
|
|
||||||
plt.ioff()
|
|
||||||
|
|
||||||
plt.show(block=True)
|
|
||||||
|
|
||||||
# Начнём симуляцию полёта с первой точки
|
|
||||||
yandexMap.scroll(points[0][0], points[0][1], 5, True)
|
|
||||||
sleep(0.2)
|
|
||||||
yandexMap.make_as_center(*points[0])
|
|
||||||
sleep(1)
|
|
||||||
|
|
||||||
vis_manager = VisualizationManager()
|
|
||||||
width, height = yandexMap.get_size()
|
|
||||||
# print(width, height)
|
|
||||||
points_coords = np.array(list(map(lambda p: [
|
points_coords = np.array(list(map(lambda p: [
|
||||||
(p[0] - points[0][0]) * width, (points[0][1] - p[1]) * height
|
(p[0] - points[0][0]) * width, (points[0][1] - p[1]) * height
|
||||||
], points)))
|
], points)))
|
||||||
points_coords *= 2 ** 4
|
|
||||||
pilot = autopilot.AutoPilot(points_coords, chunks, vis_manager)
|
points_coords *= online_map.pixel_ratio
|
||||||
simulator = Simulator(yandexMap)
|
|
||||||
|
# Начнём симуляцию полёта с первой точки
|
||||||
|
online_map.make_as_center(*points[0])
|
||||||
|
sleep(3)
|
||||||
|
online_map.scroll(0.5, 0.5, 10, True)
|
||||||
|
sleep(2)
|
||||||
|
geo = online_map.get_geolocation()
|
||||||
|
online_map.open(geo['lat'], geo['lon'], 18)
|
||||||
|
sleep(5)
|
||||||
|
|
||||||
|
points_coords_pixel = points_coords.copy() / online_map.pixel_ratio
|
||||||
|
|
||||||
|
pilot = autopilot.AutoPilot(points_coords_pixel, pixel_ratio=online_map.pixel_ratio)
|
||||||
|
simulator = Simulator(online_map)
|
||||||
|
pilot.target_idx = 0
|
||||||
|
|
||||||
|
pilot.pos = simulator.pos.copy()
|
||||||
|
command = pilot.make_command()
|
||||||
|
|
||||||
|
positions: list[Position] = []
|
||||||
|
|
||||||
|
print("points_coords_pixel:", points_coords_pixel)
|
||||||
|
for i in range(5000):
|
||||||
|
|
||||||
|
if command.stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
chunk = simulator.get_chunk()
|
||||||
|
pilot.pos = simulator.pos.copy()
|
||||||
|
command = pilot.make_command()
|
||||||
|
command.velocity /= online_map.pixel_ratio
|
||||||
|
|
||||||
|
print("Position:", simulator.pos)
|
||||||
|
|
||||||
|
# Save Image
|
||||||
|
chunk.save_image(dir_chunks / f"chunk_{len(positions)}.png")
|
||||||
|
|
||||||
|
positions.append(simulator.pos.copy() * online_map.pixel_ratio)
|
||||||
|
|
||||||
|
simulator.handle(command.dangle, command.velocity)
|
||||||
|
if i == 0 and map_name == 'google':
|
||||||
|
simulator.pos.x = 0
|
||||||
|
simulator.pos.y = 0
|
||||||
|
sleep(1.5)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'points': points_coords,
|
||||||
|
'chunk_positions': positions,
|
||||||
|
'initial_geolocation': geo
|
||||||
|
}
|
||||||
|
|
||||||
|
print(points_coords)
|
||||||
|
|
||||||
|
file_positions = dir / 'positions.pkl'
|
||||||
|
with file_positions.open('wb') as file:
|
||||||
|
pickle.dump(data, file)
|
||||||
|
|
||||||
|
print("WRITE POINTS:", points)
|
||||||
|
|
||||||
|
sleep(15)
|
||||||
|
online_map.destroy()
|
||||||
|
|
||||||
|
def run(name: str, map_name: str, ref_min_distance: float):
|
||||||
|
dir = Path('trajectories')
|
||||||
|
assert dir.exists()
|
||||||
|
dir /= name
|
||||||
|
assert dir.exists(), "Укажите корректное название маршрута"
|
||||||
|
dir_chunks = dir / 'chunks'
|
||||||
|
file_positions = dir / 'positions.pkl'
|
||||||
|
|
||||||
|
with file_positions.open('rb') as file:
|
||||||
|
data = pickle.load(file)
|
||||||
|
|
||||||
|
initial_geolocation = data['initial_geolocation']
|
||||||
|
|
||||||
|
lat = initial_geolocation['lat']
|
||||||
|
lon = initial_geolocation['lon']
|
||||||
|
online_map: YandexMap | GoogleMap = get_map(map_name, lat, lon, 18)
|
||||||
|
sleep(2)
|
||||||
|
|
||||||
|
chunks: list[VisionChunk] = []
|
||||||
|
for i in range(len(data['chunk_positions'])):
|
||||||
|
pos = data['chunk_positions'][i]
|
||||||
|
if len(chunks) == 0 or np.hypot(chunks[-1].pos.x - pos.x, chunks[-1].pos.y - pos.y) > ref_min_distance:
|
||||||
|
chunk = VisionChunk.load_image(dir_chunks / f"chunk_{i}.png")
|
||||||
|
chunk.pos = data['chunk_positions'][i] / online_map.pixel_ratio
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
r = 0
|
||||||
|
for i in range(len(data['points']) - 1):
|
||||||
|
r += np.hypot(
|
||||||
|
data['points'][i][0] - data['points'][i+1][0],
|
||||||
|
data['points'][i][1] - data['points'][i+1][1]
|
||||||
|
)
|
||||||
|
print("R: ", r)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
points = data['points'] / online_map.pixel_ratio
|
||||||
|
print("READ POINTS:", points)
|
||||||
|
|
||||||
|
vis_manager = VisualizationManager()
|
||||||
|
pilot = autopilot.AutoPilot(points, chunks, vis_manager, online_map.pixel_ratio)
|
||||||
|
simulator = Simulator(online_map)
|
||||||
pilot.target_idx = 0
|
pilot.target_idx = 0
|
||||||
|
|
||||||
chunk = simulator.get_chunk()
|
chunk = simulator.get_chunk()
|
||||||
@@ -125,38 +170,26 @@ def main():
|
|||||||
vis_manager.update_display()
|
vis_manager.update_display()
|
||||||
vis_manager.pause(1)
|
vis_manager.pause(1)
|
||||||
|
|
||||||
vis_manager.set_target_points(points_coords)
|
vis_manager.set_target_points(data['points'])
|
||||||
|
|
||||||
proc_time = np.array([])
|
proc_time = np.array([])
|
||||||
|
|
||||||
zoom_next_event = random.randint(5, 10)
|
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
chunk_errors = []
|
|
||||||
chunk_improves = []
|
sleep(1)
|
||||||
|
|
||||||
last_chunk_index = 0
|
|
||||||
|
|
||||||
for i in range(10000000000):
|
for i in range(10000000000):
|
||||||
print(f"Image #{i}")
|
if i > 0:
|
||||||
if i == zoom_next_event:
|
simulator.set_pitch(np.sin(i / 10) * 5)
|
||||||
r = random.randint(0, 1)
|
simulator.set_roll(np.sin(i / 15) * 5)
|
||||||
direction = ['up', 'down'][r]
|
simulator.set_zoom(1.0 + np.sin(i / 10) * 0.3)
|
||||||
# simulator.change_zoom(direction)
|
|
||||||
zoom_next_event = i + random.randint(20, 40)
|
|
||||||
|
|
||||||
|
|
||||||
# if i > 0:
|
|
||||||
# simulator.set_pitch(np.sin(i / 10) * 5)
|
|
||||||
# simulator.set_roll(np.sin(i / 15) * 5)
|
|
||||||
# simulator.set_zoom(1.0 + np.sin(i / 10) * 0.3)
|
|
||||||
|
|
||||||
if command.stop:
|
if command.stop:
|
||||||
break
|
break
|
||||||
|
|
||||||
# simulator.handle(command.dangle, command.velocity)
|
|
||||||
chunk = simulator.get_chunk()
|
chunk = simulator.get_chunk()
|
||||||
command = pilot.handle(chunk)
|
command = pilot.handle(chunk)
|
||||||
|
command.velocity /= online_map.pixel_ratio
|
||||||
|
|
||||||
proc_time = np.append(proc_time, command.proccessing_time)
|
proc_time = np.append(proc_time, command.proccessing_time)
|
||||||
|
|
||||||
@@ -167,34 +200,133 @@ def main():
|
|||||||
vis_manager.pause(0.2)
|
vis_manager.pause(0.2)
|
||||||
|
|
||||||
vis_manager.set_target_index(pilot.target_idx)
|
vis_manager.set_target_index(pilot.target_idx)
|
||||||
vis_manager.update_drone_trajectory(pilot.pos.x, pilot.pos.y)
|
vis_manager.update_drone_trajectory(pilot.pos.x * online_map.pixel_ratio, pilot.pos.y * online_map.pixel_ratio)
|
||||||
vis_manager.update_global_map(simulator.pos.x, simulator.pos.y)
|
vis_manager.update_global_map(simulator.pos.x * online_map.pixel_ratio, simulator.pos.y * online_map.pixel_ratio)
|
||||||
vis_manager.update_error_plot(i, pilot.pos.x, pilot.pos.y, simulator.pos.x, simulator.pos.y)
|
vis_manager.update_error_plot(i, pilot.pos.x * online_map.pixel_ratio, pilot.pos.y * online_map.pixel_ratio, simulator.pos.x * online_map.pixel_ratio, simulator.pos.y * online_map.pixel_ratio)
|
||||||
|
|
||||||
errors.append(np.hypot(pilot.pos.x - simulator.pos.x, pilot.pos.y - simulator.pos.y))
|
errors.append(np.hypot((pilot.pos.x - simulator.pos.x) * online_map.pixel_ratio, (pilot.pos.y - simulator.pos.y) * online_map.pixel_ratio))
|
||||||
if last_chunk_index != pilot.target_idx:
|
|
||||||
last_chunk_index = pilot.target_idx
|
|
||||||
chunk_errors.append(errors[-1])
|
|
||||||
chunk_improves.append(errors[-1] - errors[max(len(errors) - 2, 0)])
|
|
||||||
|
|
||||||
vis_manager.update_display()
|
vis_manager.update_display()
|
||||||
vis_manager.pause(0.2)
|
vis_manager.pause(0.2)
|
||||||
|
|
||||||
last_proc_times = proc_time[-10:]
|
last_proc_times = proc_time[-30:]
|
||||||
|
print(F"\nImage #{i}")
|
||||||
print("Average FPS:", 1 / last_proc_times.mean())
|
print("Average FPS:", 1 / last_proc_times.mean())
|
||||||
print("Pilot coords:", pilot.pos)
|
print("Pilot coords:", pilot.pos)
|
||||||
print("Simulator coords:", simulator.pos)
|
print("Simulator coords:", simulator.pos)
|
||||||
|
sleep(0.5)
|
||||||
simulator.handle(command.dangle, command.velocity)
|
simulator.handle(command.dangle, command.velocity)
|
||||||
|
if i == 0 and map_name == 'google':
|
||||||
|
simulator.pos.x = 0
|
||||||
|
simulator.pos.y = 0
|
||||||
|
|
||||||
print("Errors:", errors)
|
print("Errors:", errors)
|
||||||
print("MSE:", (np.array(errors) ** 2).mean())
|
print("MSE:", (np.array(errors) ** 2).mean())
|
||||||
print("RMSE:", (np.array(errors) ** 2).mean() ** 0.5)
|
print("RMSE:", (np.array(errors) ** 2).mean() ** 0.5)
|
||||||
print("Chunk errors:", chunk_errors)
|
|
||||||
print("Chunk error improves:", chunk_improves)
|
|
||||||
print("Average FPS:", 1 / proc_time.mean())
|
print("Average FPS:", 1 / proc_time.mean())
|
||||||
vis_manager.show_final()
|
vis_manager.show_final()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def parse_args():
|
||||||
main()
|
"""Парсер аргументов командной строки"""
|
||||||
|
parser = argparse.ArgumentParser(description='Обработка траекторий')
|
||||||
|
|
||||||
#TODO
|
# Добавляем обязательный аргумент --mode
|
||||||
|
parser.add_argument(
|
||||||
|
'--mode',
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
choices=['standalone', 'build', 'run'],
|
||||||
|
help='Режим работы: standalone, build или run'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Добавляем опциональный аргумент --name
|
||||||
|
parser.add_argument(
|
||||||
|
'--name',
|
||||||
|
type=str,
|
||||||
|
default=utility.generate_folder_name(),
|
||||||
|
help='Название траектории'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Координаты
|
||||||
|
parser.add_argument(
|
||||||
|
'--lat',
|
||||||
|
type=float,
|
||||||
|
default=49.103814,
|
||||||
|
help='Широта (по умолчанию 49.103814)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--lon',
|
||||||
|
type=float,
|
||||||
|
default=55.794258,
|
||||||
|
help='Долгота (по умолчанию 55.794258)'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Источник эталонных изображений (ориентиров)
|
||||||
|
parser.add_argument(
|
||||||
|
'--reference',
|
||||||
|
type=str,
|
||||||
|
default='google',
|
||||||
|
choices=['google', 'yandex'],
|
||||||
|
help='Откуда берутся эталонные изображения (ориентиры): google или yandex (по умолчанию google)'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Место проведения симуляции
|
||||||
|
parser.add_argument(
|
||||||
|
'--simulation',
|
||||||
|
type=str,
|
||||||
|
default='yandex',
|
||||||
|
choices=['google', 'yandex'],
|
||||||
|
help='Где проводится симуляция: google или yandex (по умолчанию yandex)'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Место проведения симуляции
|
||||||
|
parser.add_argument(
|
||||||
|
'--ref-min-distance',
|
||||||
|
type=float,
|
||||||
|
default=100,
|
||||||
|
help='Минимальное расстояние между эталонами'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Место проведения симуляции
|
||||||
|
parser.add_argument(
|
||||||
|
'--debug-fps',
|
||||||
|
action='store_true',
|
||||||
|
help='Включить отладку FPS'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Место проведения симуляции
|
||||||
|
parser.add_argument(
|
||||||
|
'--debug-landmark',
|
||||||
|
action='store_true',
|
||||||
|
help='Включить отладку эталонов'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Парсим аргументы
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Проверяем, что для build и run указан --name
|
||||||
|
if args.mode in ['build', 'run'] and not args.name:
|
||||||
|
parser.error(f"--name обязателен для режима {args.mode}")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
name = args.name
|
||||||
|
mode = args.mode
|
||||||
|
sim: str = args.simulation
|
||||||
|
ref: str = args.reference
|
||||||
|
lat: float = args.lat
|
||||||
|
lon: float = args.lon
|
||||||
|
rmd: float = args.ref_min_distance
|
||||||
|
|
||||||
|
constants.DEBUG_FPS = args.debug_fps
|
||||||
|
constants.DEBUG_LANDMARK = args.debug_landmark
|
||||||
|
|
||||||
|
if mode == 'build' or mode == 'standalone':
|
||||||
|
build(name, ref, lat, lon)
|
||||||
|
|
||||||
|
if mode == 'run' or mode == 'standalone':
|
||||||
|
run(name, sim, rmd)
|
||||||
|
|||||||
BIN
map.jpg
BIN
map.jpg
Binary file not shown.
|
Before Width: | Height: | Size: 3.4 MiB After Width: | Height: | Size: 3.3 MiB |
1
models/GAN/.gitignore
vendored
Normal file
1
models/GAN/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
runs
|
||||||
254
models/GAN/README.md
Normal file
254
models/GAN/README.md
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
# GAN Trainer для преобразования изображений Yandex → Google
|
||||||
|
|
||||||
|
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
|
||||||
|
|
||||||
|
## Структура проекта
|
||||||
|
|
||||||
|
```
|
||||||
|
autopilot/models/GAN/
|
||||||
|
├── gan.py # Основная реализация GAN модели
|
||||||
|
├── trainer.py # Тренер для обучения GAN
|
||||||
|
├── test_trainer.py # Тесты для тренера
|
||||||
|
├── train_example.py # Пример использования тренера
|
||||||
|
└── README.md # Этот файл
|
||||||
|
```
|
||||||
|
|
||||||
|
## Модель GAN
|
||||||
|
|
||||||
|
Модель состоит из двух основных компонентов:
|
||||||
|
|
||||||
|
### 1. Генератор (GeneratorUNet)
|
||||||
|
- Архитектура U-Net для преобразования изображений
|
||||||
|
- Принимает изображение Yandex (3 канала RGB)
|
||||||
|
- Возвращает изображение в стиле Google (3 канала RGB)
|
||||||
|
- Использует skip connections для сохранения деталей
|
||||||
|
|
||||||
|
### 2. Дискриминатор (DiscriminatorPatchGAN)
|
||||||
|
- PatchGAN архитектура
|
||||||
|
- Принимает пару изображений (Yandex + Google)
|
||||||
|
- Возвращает вероятность того, что пара реальная
|
||||||
|
- Работает с патчами изображения 41x41
|
||||||
|
|
||||||
|
### Функция потерь (GANLoss)
|
||||||
|
Поддерживает три режима:
|
||||||
|
- `vanilla`: Бинарная кросс-энтропия
|
||||||
|
- `lsgan`: Least Squares GAN (более стабильный)
|
||||||
|
- `wgangp`: Wasserstein GAN with Gradient Penalty
|
||||||
|
|
||||||
|
## Тренер (GANTrainer)
|
||||||
|
|
||||||
|
### Основные возможности
|
||||||
|
|
||||||
|
1. **Обучение с чередованием**:
|
||||||
|
- Обучение генератора и дискриминатора поочередно
|
||||||
|
- Поддержка L1 потерь для сохранения структуры
|
||||||
|
|
||||||
|
2. **Валидация и мониторинг**:
|
||||||
|
- Отдельные потери для генератора и дискриминатора
|
||||||
|
- Логирование в TensorBoard
|
||||||
|
- Ранняя остановка
|
||||||
|
|
||||||
|
3. **Сохранение и загрузка**:
|
||||||
|
- Чекпоинты каждой эпохи
|
||||||
|
- Лучшая модель
|
||||||
|
- Финальная модель
|
||||||
|
- История обучения
|
||||||
|
|
||||||
|
4. **Оценка модели**:
|
||||||
|
- Метрики на тестовом наборе
|
||||||
|
- Генерация примеров
|
||||||
|
|
||||||
|
### Быстрый старт
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
# Конфигурация
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"batch_size": 4,
|
||||||
|
"output_dir": "runs/gan_training",
|
||||||
|
"gan_mode": "vanilla",
|
||||||
|
"lambda_L1": 100.0,
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Устройство
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Создание модели
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode=config["gan_mode"],
|
||||||
|
lambda_L1=config["lambda_L1"],
|
||||||
|
use_cuda=(device.type == "cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создание даталоадеров (замените на свои данные)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
|
||||||
|
|
||||||
|
# Создание тренера
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Обучение
|
||||||
|
trainer.train(num_epochs=100)
|
||||||
|
|
||||||
|
# Оценка
|
||||||
|
metrics = trainer.evaluate(test_loader)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Конфигурация обучения
|
||||||
|
|
||||||
|
#### Базовая конфигурация
|
||||||
|
```python
|
||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4, # Learning rate
|
||||||
|
"beta1": 0.5, # Adam beta1
|
||||||
|
"beta2": 0.999, # Adam beta2
|
||||||
|
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 4, # Размер батча
|
||||||
|
"epochs": 100, # Количество эпох
|
||||||
|
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "vanilla", # Режим GAN
|
||||||
|
"lambda_L1": 100.0, # Вес L1 потерь
|
||||||
|
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0, # Gradient clipping
|
||||||
|
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Расширенная конфигурация
|
||||||
|
```python
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"batch_size": 8,
|
||||||
|
"epochs": 200,
|
||||||
|
"gan_mode": "lsgan", # Более стабильный LSGAN
|
||||||
|
"lambda_L1": 100.0,
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
"weight_decay": 1e-4, # Weight decay
|
||||||
|
"early_stopping_patience": 30,
|
||||||
|
"early_stopping_min_delta": 1e-4,
|
||||||
|
"output_dir": "runs/gan_advanced",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Методы тренера
|
||||||
|
|
||||||
|
#### Основные методы
|
||||||
|
- `train_epoch()`: Обучение на одной эпохе
|
||||||
|
- `validate()`: Валидация модели
|
||||||
|
- `train(num_epochs)`: Полное обучение
|
||||||
|
- `evaluate(test_loader)`: Оценка на тестовых данных
|
||||||
|
|
||||||
|
#### Управление чекпоинтами
|
||||||
|
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
|
||||||
|
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
|
||||||
|
|
||||||
|
### Выходные файлы
|
||||||
|
|
||||||
|
После обучения создаются следующие файлы:
|
||||||
|
|
||||||
|
```
|
||||||
|
runs/gan_training/
|
||||||
|
├── config.json # Конфигурация обучения
|
||||||
|
├── training_history.json # История потерь
|
||||||
|
├── model_best.pth # Лучшая модель
|
||||||
|
├── model_final.pth # Финальная модель
|
||||||
|
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
|
||||||
|
├── checkpoint_epoch_2.pth
|
||||||
|
├── ...
|
||||||
|
└── tensorboard/ # Логи TensorBoard
|
||||||
|
├── events.out.tfevents...
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
### TensorBoard
|
||||||
|
|
||||||
|
Для визуализации обучения используйте TensorBoard:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tensorboard --logdir runs/gan_training/tensorboard
|
||||||
|
```
|
||||||
|
|
||||||
|
Доступные метрики:
|
||||||
|
- `train/batch_g_loss`: Потери генератора на батче
|
||||||
|
- `train/batch_d_loss`: Потери дискриминатора на батче
|
||||||
|
- `train/batch_g_l1_loss`: L1 потери генератора
|
||||||
|
- `train/epoch_g_loss`: Потери генератора на эпохе
|
||||||
|
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
|
||||||
|
- `val/epoch_g_loss`: Валидационные потери генератора
|
||||||
|
- `val/epoch_d_loss`: Валидационные потери дискриминатора
|
||||||
|
|
||||||
|
### Тестирование
|
||||||
|
|
||||||
|
Запустите тесты для проверки работоспособности:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python models/GAN/test_trainer.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Пример использования
|
||||||
|
|
||||||
|
Полный пример использования смотрите в `train_example.py`.
|
||||||
|
|
||||||
|
### Советы по обучению
|
||||||
|
|
||||||
|
1. **Начальные значения**:
|
||||||
|
- Используйте `gan_mode="lsgan"` для более стабильного обучения
|
||||||
|
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
|
||||||
|
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
|
||||||
|
|
||||||
|
2. **Мониторинг**:
|
||||||
|
- Следите за балансом потерь генератора и дискриминатора
|
||||||
|
- Если потери дискриминатора близки к 0, генератор не обучается
|
||||||
|
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
|
||||||
|
|
||||||
|
3. **Визуализация**:
|
||||||
|
- Регулярно генерируйте примеры для визуальной оценки
|
||||||
|
- Используйте TensorBoard для отслеживания прогресса
|
||||||
|
|
||||||
|
### Устранение проблем
|
||||||
|
|
||||||
|
#### Высокие потери генератора
|
||||||
|
- Уменьшите `lambda_L1`
|
||||||
|
- Увеличьте learning rate
|
||||||
|
- Проверьте качество данных
|
||||||
|
|
||||||
|
#### Дискриминатор слишком сильный
|
||||||
|
- Уменьшите learning rate дискриминатора
|
||||||
|
- Добавьте dropout в дискриминатор
|
||||||
|
- Обучайте генератор чаще, чем дискриминатор
|
||||||
|
|
||||||
|
#### Недостаток памяти GPU
|
||||||
|
- Уменьшите `batch_size`
|
||||||
|
- Уменьшите размер изображений
|
||||||
|
- Используйте gradient accumulation
|
||||||
|
|
||||||
|
### Лицензия
|
||||||
|
|
||||||
|
Этот проект является частью Autopilot системы.
|
||||||
1777
models/GAN/gan.ipynb
Normal file
1777
models/GAN/gan.ipynb
Normal file
File diff suppressed because one or more lines are too long
393
models/GAN/gan.py
Normal file
393
models/GAN/gan.py
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class UNetDownBlock(nn.Module):
|
||||||
|
"""Блок downsampling для U-Net генератора"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
normalize: bool = True,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
layers = [
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if normalize:
|
||||||
|
layers.append(nn.BatchNorm2d(out_channels))
|
||||||
|
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
||||||
|
if dropout > 0:
|
||||||
|
layers.append(nn.Dropout2d(dropout))
|
||||||
|
self.model = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model(x)
|
||||||
|
|
||||||
|
|
||||||
|
class UNetUpBlock(nn.Module):
|
||||||
|
"""Блок upsampling для U-Net генератора"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
|
||||||
|
super().__init__()
|
||||||
|
layers = [
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
]
|
||||||
|
if dropout > 0:
|
||||||
|
layers.append(nn.Dropout2d(dropout))
|
||||||
|
self.model = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.model(x)
|
||||||
|
# Обрезаем skip connection до размера x, если необходимо
|
||||||
|
if x.shape != skip_input.shape:
|
||||||
|
diffY = skip_input.size(2) - x.size(2)
|
||||||
|
diffX = skip_input.size(3) - x.size(3)
|
||||||
|
x = F.pad(
|
||||||
|
x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
|
||||||
|
)
|
||||||
|
x = torch.cat([x, skip_input], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorUNet(nn.Module):
|
||||||
|
"""Генератор на основе U-Net архитектуры для преобразования Yandex → Google"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int = 3, out_channels: int = 3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Downsampling path
|
||||||
|
self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
|
||||||
|
self.down2 = UNetDownBlock(64, 128)
|
||||||
|
self.down3 = UNetDownBlock(128, 256)
|
||||||
|
self.down4 = UNetDownBlock(256, 512)
|
||||||
|
self.down5 = UNetDownBlock(512, 512)
|
||||||
|
self.down6 = UNetDownBlock(512, 512)
|
||||||
|
self.down7 = UNetDownBlock(512, 512)
|
||||||
|
|
||||||
|
# Bottleneck
|
||||||
|
self.bottleneck = nn.Sequential(
|
||||||
|
nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upsampling path
|
||||||
|
self.up1 = UNetUpBlock(512, 512, dropout=0.5)
|
||||||
|
self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||||
|
self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||||
|
self.up4 = UNetUpBlock(1024, 512)
|
||||||
|
self.up5 = UNetUpBlock(1024, 256)
|
||||||
|
self.up6 = UNetUpBlock(512, 128)
|
||||||
|
self.up7 = UNetUpBlock(256, 64)
|
||||||
|
|
||||||
|
# Final layer
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Downsampling
|
||||||
|
d1 = self.down1(x) # 350x350
|
||||||
|
d2 = self.down2(d1) # 175x175
|
||||||
|
d3 = self.down3(d2) # 88x88
|
||||||
|
d4 = self.down4(d3) # 44x44
|
||||||
|
d5 = self.down5(d4) # 22x22
|
||||||
|
d6 = self.down6(d5) # 11x11
|
||||||
|
d7 = self.down7(d6) # 6x6
|
||||||
|
|
||||||
|
# Bottleneck
|
||||||
|
u = self.bottleneck(d7) # 3x3
|
||||||
|
|
||||||
|
# Upsampling with skip connections
|
||||||
|
u = self.up1(u, d7) # 6x6
|
||||||
|
u = self.up2(u, d6) # 11x11
|
||||||
|
u = self.up3(u, d5) # 22x22
|
||||||
|
u = self.up4(u, d4) # 44x44
|
||||||
|
u = self.up5(u, d3) # 88x88
|
||||||
|
u = self.up6(u, d2) # 175x175
|
||||||
|
u = self.up7(u, d1) # 350x350
|
||||||
|
|
||||||
|
# Final output
|
||||||
|
return self.final(u) # 700x700
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorPatchGAN(nn.Module):
|
||||||
|
"""Дискриминатор PatchGAN для изображений 700x700"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_channels: int = 6
|
||||||
|
): # 3 для реального + 3 для сгенерированного
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def discriminator_block(
|
||||||
|
in_filters: int, out_filters: int, normalization: bool = True
|
||||||
|
):
|
||||||
|
"""Блок дискриминатора"""
|
||||||
|
layers = [
|
||||||
|
nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
|
||||||
|
]
|
||||||
|
if normalization:
|
||||||
|
layers.append(nn.BatchNorm2d(out_filters))
|
||||||
|
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
||||||
|
return layers
|
||||||
|
|
||||||
|
self.model = nn.Sequential(
|
||||||
|
*discriminator_block(in_channels, 64, normalization=False), # 350x350
|
||||||
|
*discriminator_block(64, 128), # 175x175
|
||||||
|
*discriminator_block(128, 256), # 88x88
|
||||||
|
*discriminator_block(256, 512), # 44x44
|
||||||
|
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), # 41x41
|
||||||
|
nn.Sigmoid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Принимает пару изображений (реальное и сгенерированное)
|
||||||
|
и возвращает вероятность того, что пара реальная
|
||||||
|
"""
|
||||||
|
# Объединяем два изображения по каналам
|
||||||
|
img_input = torch.cat((img_A, img_B), 1)
|
||||||
|
return self.model(img_input)
|
||||||
|
|
||||||
|
|
||||||
|
class GANLoss(nn.Module):
|
||||||
|
"""Функция потерь для GAN"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gan_mode: str = "vanilla",
|
||||||
|
target_real_label: float = 1.0,
|
||||||
|
target_fake_label: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("real_label", torch.tensor(target_real_label))
|
||||||
|
self.register_buffer("fake_label", torch.tensor(target_fake_label))
|
||||||
|
self.gan_mode = gan_mode
|
||||||
|
|
||||||
|
if gan_mode == "vanilla":
|
||||||
|
self.loss = nn.BCEWithLogitsLoss()
|
||||||
|
elif gan_mode == "lsgan":
|
||||||
|
self.loss = nn.MSELoss()
|
||||||
|
elif gan_mode == "wgangp":
|
||||||
|
self.loss = None
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"GAN mode {gan_mode} not implemented")
|
||||||
|
|
||||||
|
def get_target_tensor(
|
||||||
|
self, prediction: torch.Tensor, target_is_real: bool
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Создает тензор меток"""
|
||||||
|
if target_is_real:
|
||||||
|
target_tensor = self.real_label
|
||||||
|
else:
|
||||||
|
target_tensor = self.fake_label
|
||||||
|
return target_tensor.expand_as(prediction)
|
||||||
|
|
||||||
|
def __call__(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
|
||||||
|
"""Вычисляет потери"""
|
||||||
|
if self.gan_mode in ["vanilla", "lsgan"]:
|
||||||
|
target_tensor = self.get_target_tensor(prediction, target_is_real)
|
||||||
|
loss = self.loss(prediction, target_tensor)
|
||||||
|
elif self.gan_mode == "wgangp":
|
||||||
|
if target_is_real:
|
||||||
|
loss = -prediction.mean()
|
||||||
|
else:
|
||||||
|
loss = prediction.mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGAN(nn.Module):
|
||||||
|
"""Основной класс GAN для преобразования изображений Yandex → Google"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels: int = 3,
|
||||||
|
output_channels: int = 3,
|
||||||
|
gan_mode: str = "vanilla",
|
||||||
|
lambda_L1: float = 100.0,
|
||||||
|
use_cuda: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.generator = GeneratorUNet(input_channels, output_channels)
|
||||||
|
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
|
||||||
|
self.gan_loss = GANLoss(gan_mode)
|
||||||
|
self.l1_loss = nn.L1Loss()
|
||||||
|
self.lambda_L1 = lambda_L1
|
||||||
|
|
||||||
|
self.device = torch.device(
|
||||||
|
"cuda" if use_cuda and torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
|
def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Генерация изображения Google из Yandex"""
|
||||||
|
return self.generator(yandex_image)
|
||||||
|
|
||||||
|
def generator_step(
|
||||||
|
self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Шаг обучения генератора
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
total_loss: общие потери генератора
|
||||||
|
gan_loss: потери GAN
|
||||||
|
l1_loss: потери L1
|
||||||
|
"""
|
||||||
|
# Генерируем изображение
|
||||||
|
fake_google_image = self.generator(yandex_image)
|
||||||
|
|
||||||
|
# Оцениваем дискриминатором
|
||||||
|
fake_pred = self.discriminator(yandex_image, fake_google_image)
|
||||||
|
|
||||||
|
# Потери GAN (пытаемся обмануть дискриминатор)
|
||||||
|
gan_loss = self.gan_loss(fake_pred, True)
|
||||||
|
|
||||||
|
# Потери L1 для сохранения структуры
|
||||||
|
l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1
|
||||||
|
|
||||||
|
# Общие потери
|
||||||
|
total_loss = gan_loss + l1_loss
|
||||||
|
|
||||||
|
return total_loss, gan_loss, l1_loss
|
||||||
|
|
||||||
|
def discriminator_step(
|
||||||
|
self,
|
||||||
|
yandex_image: torch.Tensor,
|
||||||
|
real_google_image: torch.Tensor,
|
||||||
|
fake_google_image: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Шаг обучения дискриминатора
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
total_loss: общие потери дискриминатора
|
||||||
|
real_loss: потери на реальных изображениях
|
||||||
|
fake_loss: потери на сгенерированных изображениях
|
||||||
|
"""
|
||||||
|
# Предсказания для реальных пар
|
||||||
|
real_pred = self.discriminator(yandex_image, real_google_image)
|
||||||
|
real_loss = self.gan_loss(real_pred, True)
|
||||||
|
|
||||||
|
# Предсказания для сгенерированных пар
|
||||||
|
fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
|
||||||
|
fake_loss = self.gan_loss(fake_pred, False)
|
||||||
|
|
||||||
|
# Общие потери дискриминатора
|
||||||
|
total_loss = (real_loss + fake_loss) * 0.5
|
||||||
|
|
||||||
|
return total_loss, real_loss, fake_loss
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
"""Перемещает модель на устройство"""
|
||||||
|
self.generator.to(device)
|
||||||
|
self.discriminator.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def train_mode(self):
|
||||||
|
"""Переключает модель в режим обучения"""
|
||||||
|
self.generator.train()
|
||||||
|
self.discriminator.train()
|
||||||
|
|
||||||
|
def eval_mode(self):
|
||||||
|
"""Переключает модель в режим оценки"""
|
||||||
|
self.generator.eval()
|
||||||
|
self.discriminator.eval()
|
||||||
|
|
||||||
|
def save_checkpoint(self, path: str):
|
||||||
|
"""Сохраняет чекпоинт модели"""
|
||||||
|
checkpoint = {
|
||||||
|
"generator_state_dict": self.generator.state_dict(),
|
||||||
|
"discriminator_state_dict": self.discriminator.state_dict(),
|
||||||
|
"generator_optimizer_state_dict": getattr(
|
||||||
|
self.generator, "optimizer_state_dict", None
|
||||||
|
),
|
||||||
|
"discriminator_optimizer_state_dict": getattr(
|
||||||
|
self.discriminator, "optimizer_state_dict", None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
torch.save(checkpoint, path)
|
||||||
|
|
||||||
|
def load_checkpoint(self, path: str):
|
||||||
|
"""Загружает чекпоинт модели"""
|
||||||
|
checkpoint = torch.load(path, map_location=self.device)
|
||||||
|
self.generator.load_state_dict(checkpoint["generator_state_dict"])
|
||||||
|
self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
|
||||||
|
|
||||||
|
if checkpoint["generator_optimizer_state_dict"] is not None:
|
||||||
|
self.generator.optimizer_state_dict = checkpoint[
|
||||||
|
"generator_optimizer_state_dict"
|
||||||
|
]
|
||||||
|
if checkpoint["discriminator_optimizer_state_dict"] is not None:
|
||||||
|
self.discriminator.optimizer_state_dict = checkpoint[
|
||||||
|
"discriminator_optimizer_state_dict"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_image_gan(
|
||||||
|
input_channels: int = 3,
|
||||||
|
output_channels: int = 3,
|
||||||
|
gan_mode: str = "vanilla",
|
||||||
|
lambda_L1: float = 100.0,
|
||||||
|
use_cuda: bool = True,
|
||||||
|
) -> ImageGAN:
|
||||||
|
"""
|
||||||
|
Создает и возвращает модель GAN для преобразования изображений
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_channels: количество входных каналов (обычно 3 для RGB)
|
||||||
|
output_channels: количество выходных каналов (обычно 3 для RGB)
|
||||||
|
gan_mode: режим GAN ('vanilla', 'lsgan', 'wgangp')
|
||||||
|
lambda_L1: вес L1 потерь
|
||||||
|
use_cuda: использовать ли CUDA если доступно
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageGAN: модель GAN
|
||||||
|
"""
|
||||||
|
return ImageGAN(
|
||||||
|
input_channels=input_channels,
|
||||||
|
output_channels=output_channels,
|
||||||
|
gan_mode=gan_mode,
|
||||||
|
lambda_L1=lambda_L1,
|
||||||
|
use_cuda=use_cuda,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Вспомогательные функции для инициализации весов
|
||||||
|
def weights_init_normal(m):
|
||||||
|
"""Инициализация весов с нормальным распределением"""
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||||
|
elif classname.find("BatchNorm") != -1:
|
||||||
|
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||||
|
nn.init.constant_(m.batch_norm.bias.data, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
|
||||||
|
"""Инициализирует веса генератора и дискриминатора"""
|
||||||
|
generator.apply(weights_init_normal)
|
||||||
|
discriminator.apply(weights_init_normal)
|
||||||
136
models/GAN/minimal_example.py
Normal file
136
models/GAN/minimal_example.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
|
||||||
|
|
||||||
|
Этот пример показывает самый простой способ использования тренера.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
# Добавляем путь к модулям
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleMapDataset(Dataset):
|
||||||
|
"""Простой датасет с фиктивными данными для примера."""
|
||||||
|
|
||||||
|
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.image_size = image_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# Создаем фиктивные изображения
|
||||||
|
# В реальном коде замените на загрузку реальных изображений
|
||||||
|
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
|
||||||
|
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция минимального примера."""
|
||||||
|
print("Минимальный пример использования GAN trainer")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 1. Конфигурация (минимальный набор параметров)
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"batch_size": 4,
|
||||||
|
"output_dir": "runs/gan_minimal",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Устройство (CPU или GPU)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Используемое устройство: {device}")
|
||||||
|
|
||||||
|
# 3. Создание модели
|
||||||
|
print("\nСоздание GAN модели...")
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode="vanilla", # Простейший режим
|
||||||
|
lambda_L1=100.0, # Стандартный вес L1 потерь
|
||||||
|
use_cuda=(device.type == "cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Создание даталоадеров
|
||||||
|
print("Создание даталоадеров...")
|
||||||
|
train_dataset = SimpleMapDataset(num_samples=50)
|
||||||
|
val_dataset = SimpleMapDataset(num_samples=10)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Обучающих примеров: {len(train_dataset)}")
|
||||||
|
print(f" Валидационных примеров: {len(val_dataset)}")
|
||||||
|
|
||||||
|
# 5. Создание тренера
|
||||||
|
print("\nСоздание тренера...")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Обучение на небольшом количестве эпох
|
||||||
|
print("\nЗапуск обучения (3 эпохи для примера)...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
trainer.train(num_epochs=3)
|
||||||
|
|
||||||
|
# 7. Генерация примеров
|
||||||
|
print("\nГенерация примеров преобразования...")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Создаем тестовые данные
|
||||||
|
test_yandex = torch.randn(2, 3, 256, 256).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_google = model(test_yandex)
|
||||||
|
|
||||||
|
print(f"Входные изображения: {test_yandex.shape}")
|
||||||
|
print(f"Сгенерированные изображения: {generated_google.shape}")
|
||||||
|
print(
|
||||||
|
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Сохранение финальной модели
|
||||||
|
print("\nСохранение модели...")
|
||||||
|
model_save_path = "gan_model_minimal.pth"
|
||||||
|
torch.save(model.state_dict(), model_save_path)
|
||||||
|
print(f"Модель сохранена в: {model_save_path}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Минимальный пример завершен!")
|
||||||
|
print("\nДля реального использования:")
|
||||||
|
print("1. Замените SimpleMapDataset на ваш реальный датасет")
|
||||||
|
print("2. Настройте параметры в config")
|
||||||
|
print("3. Увеличьте количество эпох (например, до 100)")
|
||||||
|
print("4. Используйте реальные изображения карт")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
349
models/GAN/test_gan.py
Normal file
349
models/GAN/test_gan.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Добавляем путь к модулю
|
||||||
|
sys.path.append(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
)
|
||||||
|
|
||||||
|
from gan import (
|
||||||
|
DiscriminatorPatchGAN,
|
||||||
|
GeneratorUNet,
|
||||||
|
ImageGAN,
|
||||||
|
create_image_gan,
|
||||||
|
initialize_gan_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generator():
|
||||||
|
"""Тестирование генератора"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Тестирование генератора...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем генератор
|
||||||
|
generator = GeneratorUNet(in_channels=3, out_channels=3)
|
||||||
|
|
||||||
|
# Инициализируем веса
|
||||||
|
generator.apply(
|
||||||
|
lambda m: (
|
||||||
|
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||||
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем тестовый входной тензор (Yandex изображение)
|
||||||
|
batch_size = 2
|
||||||
|
height, width = 700, 700
|
||||||
|
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
|
||||||
|
print(f"Размер входного изображения: {yandex_image.shape}")
|
||||||
|
print(
|
||||||
|
f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Прямой проход
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_image = generator(yandex_image)
|
||||||
|
|
||||||
|
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||||||
|
print(
|
||||||
|
f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверка размеров
|
||||||
|
assert generated_image.shape == (batch_size, 3, height, width), (
|
||||||
|
f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("✓ Генератор работает корректно!")
|
||||||
|
return generator
|
||||||
|
|
||||||
|
|
||||||
|
def test_discriminator():
|
||||||
|
"""Тестирование дискриминатора"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование дискриминатора...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем дискриминатор
|
||||||
|
discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала
|
||||||
|
|
||||||
|
# Инициализируем веса
|
||||||
|
discriminator.apply(
|
||||||
|
lambda m: (
|
||||||
|
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||||
|
if isinstance(m, nn.Conv2d)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем тестовые тензоры
|
||||||
|
batch_size = 2
|
||||||
|
height, width = 700, 700
|
||||||
|
|
||||||
|
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
google_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
|
||||||
|
print(f"Размер Yandex изображения: {yandex_image.shape}")
|
||||||
|
print(f"Размер Google изображения: {google_image.shape}")
|
||||||
|
print(
|
||||||
|
f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Прямой проход
|
||||||
|
with torch.no_grad():
|
||||||
|
prediction = discriminator(yandex_image, google_image)
|
||||||
|
|
||||||
|
print(f"Размер выхода дискриминатора: {prediction.shape}")
|
||||||
|
print(
|
||||||
|
f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверка размеров (PatchGAN выдает карту вероятностей)
|
||||||
|
expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков
|
||||||
|
expected_width = 43
|
||||||
|
assert prediction.shape == (batch_size, 1, expected_height, expected_width), (
|
||||||
|
f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}"
|
||||||
|
)
|
||||||
|
return discriminator
|
||||||
|
|
||||||
|
|
||||||
|
def test_gan_model():
|
||||||
|
"""Тестирование полной GAN модели"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование полной GAN модели...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем GAN модель
|
||||||
|
gan = ImageGAN(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode="vanilla",
|
||||||
|
lambda_L1=100.0,
|
||||||
|
use_cuda=False, # Тестируем на CPU для простоты
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Устройство модели: {gan.device}")
|
||||||
|
print(
|
||||||
|
f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}"
|
||||||
|
)
|
||||||
|
print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}")
|
||||||
|
|
||||||
|
# Создаем тестовые данные
|
||||||
|
batch_size = 2
|
||||||
|
height, width = 700, 700
|
||||||
|
|
||||||
|
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
real_google_image = torch.randn(batch_size, 3, height, width)
|
||||||
|
|
||||||
|
print(f"\nТестирование прямого прохода...")
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_image = gan(yandex_image)
|
||||||
|
|
||||||
|
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||||||
|
|
||||||
|
print(f"\nТестирование шага генератора...")
|
||||||
|
gan.train_mode()
|
||||||
|
|
||||||
|
# Тестируем шаг генератора
|
||||||
|
total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image)
|
||||||
|
|
||||||
|
print(f"Общие потери генератора: {total_loss.item():.6f}")
|
||||||
|
print(f"Потери GAN: {gan_loss.item():.6f}")
|
||||||
|
print(f"Потери L1: {l1_loss.item():.6f}")
|
||||||
|
|
||||||
|
print(f"\nТестирование шага дискриминатора...")
|
||||||
|
# Создаем сгенерированное изображение для дискриминатора
|
||||||
|
with torch.no_grad():
|
||||||
|
fake_google_image = gan.generator(yandex_image)
|
||||||
|
|
||||||
|
total_d_loss, real_loss, fake_loss = gan.discriminator_step(
|
||||||
|
yandex_image, real_google_image, fake_google_image
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}")
|
||||||
|
print(f"Потери на реальных изображениях: {real_loss.item():.6f}")
|
||||||
|
print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}")
|
||||||
|
|
||||||
|
print(f"\nТестирование режимов обучения/оценки...")
|
||||||
|
gan.eval_mode()
|
||||||
|
print(f"Генератор в режиме eval: {not gan.generator.training}")
|
||||||
|
print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}")
|
||||||
|
|
||||||
|
gan.train_mode()
|
||||||
|
print(f"Генератор в режиме train: {gan.generator.training}")
|
||||||
|
print(f"Дискриминатор в режиме train: {gan.discriminator.training}")
|
||||||
|
|
||||||
|
print("\n✓ Полная GAN модель работает корректно!")
|
||||||
|
return gan
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_function():
|
||||||
|
"""Тестирование фабричной функции"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование фабричной функции...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Тестируем разные режимы GAN
|
||||||
|
for gan_mode in ["vanilla", "lsgan"]:
|
||||||
|
print(f"\nСоздание GAN в режиме '{gan_mode}'...")
|
||||||
|
gan = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode=gan_mode,
|
||||||
|
lambda_L1=100.0,
|
||||||
|
use_cuda=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Режим GAN: {gan.gan_loss.gan_mode}")
|
||||||
|
print(f" Вес L1 потерь: {gan.lambda_L1}")
|
||||||
|
print(f" Устройство: {gan.device}")
|
||||||
|
|
||||||
|
# Быстрая проверка прямого прохода
|
||||||
|
batch_size = 1
|
||||||
|
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
generated = gan(yandex_image)
|
||||||
|
|
||||||
|
print(f" Размер выхода: {generated.shape}")
|
||||||
|
print(f" ✓ GAN в режиме '{gan_mode}' создан успешно")
|
||||||
|
|
||||||
|
print("\n✓ Фабричная функция работает корректно!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_weights_initialization():
|
||||||
|
"""Тестирование инициализации весов"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование инициализации весов...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создаем модели
|
||||||
|
generator = GeneratorUNet(3, 3)
|
||||||
|
discriminator = DiscriminatorPatchGAN(6)
|
||||||
|
|
||||||
|
# Инициализируем веса
|
||||||
|
initialize_gan_weights(generator, discriminator)
|
||||||
|
|
||||||
|
# Проверяем средние значения весов
|
||||||
|
def check_weights_mean(model, model_name):
|
||||||
|
conv_weights = []
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if "weight" in name and (
|
||||||
|
"conv" in name.lower() or "Conv" in str(param.__class__)
|
||||||
|
):
|
||||||
|
conv_weights.append(param.data.mean().item())
|
||||||
|
|
||||||
|
if conv_weights:
|
||||||
|
avg_mean = sum(conv_weights) / len(conv_weights)
|
||||||
|
print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}")
|
||||||
|
# Проверяем, что веса инициализированы около 0
|
||||||
|
assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0"
|
||||||
|
|
||||||
|
check_weights_mean(generator, "генераторе")
|
||||||
|
check_weights_mean(discriminator, "дискриминаторе")
|
||||||
|
|
||||||
|
print("✓ Инициализация весов работает корректно!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_usage():
|
||||||
|
"""Тестирование использования памяти"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Тестирование использования памяти...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
# Получаем текущее использование памяти
|
||||||
|
process = psutil.Process(os.getpid())
|
||||||
|
memory_before = process.memory_info().rss / 1024 / 1024 # в MB
|
||||||
|
|
||||||
|
print(f"Память до создания моделей: {memory_before:.2f} MB")
|
||||||
|
|
||||||
|
# Создаем несколько моделей
|
||||||
|
models = []
|
||||||
|
for i in range(3):
|
||||||
|
gan = create_image_gan(use_cuda=False)
|
||||||
|
models.append(gan)
|
||||||
|
|
||||||
|
# Делаем тестовый проход
|
||||||
|
batch_size = 1
|
||||||
|
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||||||
|
real_google_image = torch.randn(batch_size, 3, 700, 700)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = gan(yandex_image)
|
||||||
|
_ = gan.generator_step(yandex_image, real_google_image)
|
||||||
|
|
||||||
|
memory_after = process.memory_info().rss / 1024 / 1024 # в MB
|
||||||
|
memory_used = memory_after - memory_before
|
||||||
|
|
||||||
|
print(f"Память после создания моделей: {memory_after:.2f} MB")
|
||||||
|
print(f"Использовано памяти: {memory_used:.2f} MB")
|
||||||
|
|
||||||
|
# Очищаем модели
|
||||||
|
del models
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
memory_final = process.memory_info().rss / 1024 / 1024
|
||||||
|
print(f"Память после очистки: {memory_final:.2f} MB")
|
||||||
|
|
||||||
|
print("✓ Тестирование памяти завершено!")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция тестирования"""
|
||||||
|
print("Начало тестирования GAN архитектуры для преобразования Yandex → Google")
|
||||||
|
print("Размер изображения: 700x700 пикселей")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Запускаем все тесты
|
||||||
|
test_generator()
|
||||||
|
test_discriminator()
|
||||||
|
test_gan_model()
|
||||||
|
test_factory_function()
|
||||||
|
test_weights_initialization()
|
||||||
|
test_memory_usage()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nАрхитектура GAN готова к использованию для преобразования")
|
||||||
|
print("изображений из стиля Yandex в стиль Google.")
|
||||||
|
print("\nОсновные характеристики:")
|
||||||
|
print(" • Генератор: U-Net архитектура")
|
||||||
|
print(" • Дискриминатор: PatchGAN (43x43 патчей)")
|
||||||
|
print(" • Размер входных/выходных изображений: 700x700")
|
||||||
|
print(" • Поддержка режимов: vanilla, lsgan")
|
||||||
|
print(" • L1 регуляризация для сохранения структуры")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Ошибка при тестировании: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = main()
|
||||||
|
sys.exit(exit_code)
|
||||||
342
models/GAN/test_trainer.py
Normal file
342
models/GAN/test_trainer.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
"""
|
||||||
|
Тестовый скрипт для проверки GAN trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
# Добавляем путь к модулям
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleDataset(Dataset):
|
||||||
|
"""Простой датасет для тестирования."""
|
||||||
|
|
||||||
|
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.image_size = image_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# Создаем случайные изображения для тестирования
|
||||||
|
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||||
|
|
||||||
|
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||||
|
|
||||||
|
|
||||||
|
def test_gan_model():
|
||||||
|
"""Тестирование GAN модели."""
|
||||||
|
print("Тестирование GAN модели...")
|
||||||
|
|
||||||
|
# Создаем модель
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode="vanilla",
|
||||||
|
lambda_L1=100.0,
|
||||||
|
use_cuda=False, # Используем CPU для тестирования
|
||||||
|
)
|
||||||
|
|
||||||
|
# Тестируем forward pass
|
||||||
|
batch_size = 2
|
||||||
|
image_size = (256, 256)
|
||||||
|
yandex_input = torch.randn(batch_size, 3, *image_size)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(yandex_input)
|
||||||
|
|
||||||
|
print(f"Входной размер: {yandex_input.shape}")
|
||||||
|
print(f"Выходной размер: {output.shape}")
|
||||||
|
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
|
||||||
|
|
||||||
|
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
|
||||||
|
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
|
||||||
|
print("✓ Forward pass работает корректно")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_generator_step():
|
||||||
|
"""Тестирование шага генератора."""
|
||||||
|
print("\nТестирование шага генератора...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
model.train_mode()
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
|
||||||
|
# Тестируем generator_step
|
||||||
|
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
|
||||||
|
|
||||||
|
print(f"Total loss: {total_loss.item():.6f}")
|
||||||
|
print(f"GAN loss: {gan_loss.item():.6f}")
|
||||||
|
print(f"L1 loss: {l1_loss.item():.6f}")
|
||||||
|
|
||||||
|
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||||
|
print("✓ Шаг генератора работает корректно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_discriminator_step():
|
||||||
|
"""Тестирование шага дискриминатора."""
|
||||||
|
print("\nТестирование шага дискриминатора...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
model.train_mode()
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||||
|
|
||||||
|
# Генерируем fake изображение
|
||||||
|
with torch.no_grad():
|
||||||
|
fake_google_img = model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Тестируем discriminator_step
|
||||||
|
total_loss, real_loss, fake_loss = model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Total loss: {total_loss.item():.6f}")
|
||||||
|
print(f"Real loss: {real_loss.item():.6f}")
|
||||||
|
print(f"Fake loss: {fake_loss.item():.6f}")
|
||||||
|
|
||||||
|
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||||
|
print("✓ Шаг дискриминатора работает корректно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_trainer_initialization():
|
||||||
|
"""Тестирование инициализации тренера."""
|
||||||
|
print("\nТестирование инициализации тренера...")
|
||||||
|
|
||||||
|
# Создаем модель
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
|
||||||
|
# Создаем даталоадеры
|
||||||
|
train_dataset = SimpleDataset(num_samples=50)
|
||||||
|
val_dataset = SimpleDataset(num_samples=10)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
# Конфигурация
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan",
|
||||||
|
"early_stopping_patience": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Создаем тренер
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Тренер создан успешно")
|
||||||
|
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
|
||||||
|
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
|
||||||
|
print(f"Выходная директория: {trainer.output_dir}")
|
||||||
|
|
||||||
|
assert trainer.output_dir.exists(), "Выходная директория не создана"
|
||||||
|
print("✓ Тренер инициализирован корректно")
|
||||||
|
|
||||||
|
return trainer, train_loader, val_loader
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_epoch():
|
||||||
|
"""Тестирование одной эпохи обучения."""
|
||||||
|
print("\nТестирование одной эпохи обучения...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
train_dataset = SimpleDataset(num_samples=20)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||||
|
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan",
|
||||||
|
}
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Запускаем одну эпоху обучения
|
||||||
|
avg_g_loss, avg_d_loss = trainer.train_epoch()
|
||||||
|
|
||||||
|
print(f"Средние потери за эпоху:")
|
||||||
|
print(f" Генератор: {avg_g_loss:.6f}")
|
||||||
|
print(f" Дискриминатор: {avg_d_loss:.6f}")
|
||||||
|
|
||||||
|
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
|
||||||
|
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||||
|
print("✓ Эпоха обучения завершена успешно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation():
|
||||||
|
"""Тестирование валидации."""
|
||||||
|
print("\nТестирование валидации...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||||
|
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan",
|
||||||
|
}
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Запускаем валидацию
|
||||||
|
val_g_loss, val_d_loss = trainer.validate()
|
||||||
|
|
||||||
|
print(f"Потери на валидации:")
|
||||||
|
print(f" Генератор: {val_g_loss:.6f}")
|
||||||
|
print(f" Дискриминатор: {val_d_loss:.6f}")
|
||||||
|
|
||||||
|
assert val_g_loss > 0, "Потери генератора должны быть положительными"
|
||||||
|
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||||
|
print("✓ Валидация завершена успешно")
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_saving():
|
||||||
|
"""Тестирование сохранения чекпоинтов."""
|
||||||
|
print("\nТестирование сохранения чекпоинтов...")
|
||||||
|
|
||||||
|
model = create_image_gan(use_cuda=False)
|
||||||
|
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||||
|
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
"output_dir": "test_runs/gan_checkpoint",
|
||||||
|
}
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Сохраняем чекпоинт
|
||||||
|
trainer.save_checkpoint(is_best=True)
|
||||||
|
|
||||||
|
# Проверяем, что файлы созданы
|
||||||
|
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
|
||||||
|
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
|
||||||
|
|
||||||
|
for file in checkpoint_files:
|
||||||
|
print(f" - {file.name}")
|
||||||
|
|
||||||
|
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
|
||||||
|
print("✓ Чекпоинты сохранены успешно")
|
||||||
|
|
||||||
|
# Тестируем загрузку чекпоинта
|
||||||
|
checkpoint_path = checkpoint_files[0]
|
||||||
|
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
|
||||||
|
|
||||||
|
# Создаем новую модель и тренер
|
||||||
|
new_model = create_image_gan(use_cuda=False)
|
||||||
|
new_trainer = GANTrainer(
|
||||||
|
model=new_model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Загружаем чекпоинт
|
||||||
|
new_trainer.load_checkpoint(str(checkpoint_path))
|
||||||
|
|
||||||
|
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
|
||||||
|
print("✓ Чекпоинт загружен успешно")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция тестирования."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Начало тестирования GAN trainer")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Запускаем все тесты
|
||||||
|
test_gan_model()
|
||||||
|
test_generator_step()
|
||||||
|
test_discriminator_step()
|
||||||
|
test_trainer_initialization()
|
||||||
|
test_train_epoch()
|
||||||
|
test_validation()
|
||||||
|
test_checkpoint_saving()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Все тесты пройдены успешно! ✓")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nОшибка при тестировании: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Очищаем тестовые директории
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
|
||||||
|
for dir_path in test_dirs:
|
||||||
|
if Path(dir_path).exists():
|
||||||
|
shutil.rmtree(dir_path)
|
||||||
|
|
||||||
|
# Запускаем тесты
|
||||||
|
exit_code = main()
|
||||||
|
|
||||||
|
# Очищаем после тестов
|
||||||
|
for dir_path in test_dirs:
|
||||||
|
if Path(dir_path).exists():
|
||||||
|
shutil.rmtree(dir_path)
|
||||||
|
|
||||||
|
exit(exit_code)
|
||||||
347
models/GAN/train_example.py
Normal file
347
models/GAN/train_example.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""
|
||||||
|
Пример обучения GAN модели для преобразования Yandex → Google карт.
|
||||||
|
|
||||||
|
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Добавляем путь к модулям
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from models.GAN.gan import create_image_gan
|
||||||
|
from models.GAN.trainer import GANTrainer
|
||||||
|
|
||||||
|
|
||||||
|
def create_simple_config():
|
||||||
|
"""Создает простую конфигурацию для обучения."""
|
||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 4,
|
||||||
|
"epochs": 100,
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
|
||||||
|
"lambda_L1": 100.0, # Вес L1 потерь
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan_training",
|
||||||
|
# Логирование
|
||||||
|
"log_interval": 10, # Логировать каждые N батчей
|
||||||
|
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_advanced_config():
|
||||||
|
"""Создает расширенную конфигурацию для обучения."""
|
||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
# Планировщик learning rate
|
||||||
|
"use_scheduler": True,
|
||||||
|
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
|
||||||
|
"scheduler_start_epoch": 50,
|
||||||
|
"scheduler_end_epoch": 100,
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 8,
|
||||||
|
"epochs": 200,
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "lsgan", # LSGAN обычно более стабилен
|
||||||
|
"lambda_L1": 100.0,
|
||||||
|
# Аугментация данных
|
||||||
|
"augmentation": {
|
||||||
|
"random_crop": True,
|
||||||
|
"crop_size": 256,
|
||||||
|
"random_flip": True,
|
||||||
|
"color_jitter": True,
|
||||||
|
"brightness": 0.2,
|
||||||
|
"contrast": 0.2,
|
||||||
|
"saturation": 0.2,
|
||||||
|
"hue": 0.1,
|
||||||
|
},
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
"weight_decay": 1e-4,
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 30,
|
||||||
|
"early_stopping_min_delta": 1e-4,
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan_advanced",
|
||||||
|
# Логирование
|
||||||
|
"log_interval": 20,
|
||||||
|
"save_interval": 10,
|
||||||
|
"save_best_only": True, # Сохранять только лучшую модель
|
||||||
|
# Визуализация
|
||||||
|
"visualize_samples": True,
|
||||||
|
"num_visualize": 4,
|
||||||
|
"visualize_interval": 5, # Визуализировать каждые N эпох
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def print_config_summary(config):
|
||||||
|
"""Печатает сводку конфигурации."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Конфигурация обучения GAN")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print(f"\nПараметры модели:")
|
||||||
|
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
|
||||||
|
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
|
||||||
|
|
||||||
|
print(f"\nПараметры обучения:")
|
||||||
|
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
|
||||||
|
print(f" Batch size: {config.get('batch_size', 4)}")
|
||||||
|
print(f" Эпох: {config.get('epochs', 100)}")
|
||||||
|
print(f" Beta1: {config.get('beta1', 0.5)}")
|
||||||
|
print(f" Beta2: {config.get('beta2', 0.999)}")
|
||||||
|
|
||||||
|
if config.get("use_scheduler", False):
|
||||||
|
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
|
||||||
|
|
||||||
|
print(f"\nРегуляризация:")
|
||||||
|
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
|
||||||
|
if "weight_decay" in config:
|
||||||
|
print(f" Weight decay: {config['weight_decay']}")
|
||||||
|
|
||||||
|
print(f"\nРанняя остановка:")
|
||||||
|
if config.get("early_stopping_patience", 0) > 0:
|
||||||
|
print(f" Patience: {config['early_stopping_patience']} эпох")
|
||||||
|
if "early_stopping_min_delta" in config:
|
||||||
|
print(f" Min delta: {config['early_stopping_min_delta']}")
|
||||||
|
|
||||||
|
print(f"\nВыходные данные:")
|
||||||
|
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
|
||||||
|
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
|
||||||
|
|
||||||
|
print(f"\nЛогирование:")
|
||||||
|
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_training():
|
||||||
|
"""Настраивает обучение."""
|
||||||
|
print("Настройка обучения GAN...")
|
||||||
|
|
||||||
|
# Выбираем конфигурацию
|
||||||
|
use_advanced = False # Измените на True для расширенной конфигурации
|
||||||
|
|
||||||
|
if use_advanced:
|
||||||
|
config = create_advanced_config()
|
||||||
|
else:
|
||||||
|
config = create_simple_config()
|
||||||
|
|
||||||
|
# Печатаем сводку конфигурации
|
||||||
|
print_config_summary(config)
|
||||||
|
|
||||||
|
# Устройство
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"\nИспользуемое устройство: {device}")
|
||||||
|
|
||||||
|
if device.type == "cuda":
|
||||||
|
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(
|
||||||
|
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем модель
|
||||||
|
print("\nСоздание модели...")
|
||||||
|
model = create_image_gan(
|
||||||
|
input_channels=3,
|
||||||
|
output_channels=3,
|
||||||
|
gan_mode=config.get("gan_mode", "vanilla"),
|
||||||
|
lambda_L1=config.get("lambda_L1", 100.0),
|
||||||
|
use_cuda=(device.type == "cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создаем даталоадеры
|
||||||
|
print("\nСоздание даталоадеров...")
|
||||||
|
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
|
||||||
|
# Пример:
|
||||||
|
# from your_dataset_module import create_data_loaders
|
||||||
|
# train_loader, val_loader = create_data_loaders(
|
||||||
|
# data_dir="ваш/путь/к/данным",
|
||||||
|
# batch_size=config["batch_size"],
|
||||||
|
# image_size=(256, 256),
|
||||||
|
# augment=config.get("augmentation", None),
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Для примера создаем фиктивные даталоадеры
|
||||||
|
# ВАЖНО: Замените это на реальные данные!
|
||||||
|
print(" ВНИМАНИЕ: Используются фиктивные данные!")
|
||||||
|
print(" Замените на реальные даталоадеры!")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
class DummyDataset(Dataset):
|
||||||
|
def __init__(self, num_samples=100):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# Фиктивные данные для примера
|
||||||
|
yandex_img = torch.randn(3, 256, 256)
|
||||||
|
google_img = torch.randn(3, 256, 256)
|
||||||
|
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||||
|
|
||||||
|
train_dataset = DummyDataset(num_samples=100)
|
||||||
|
val_dataset = DummyDataset(num_samples=20)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.get("batch_size", 4),
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config.get("batch_size", 4),
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" Размер обучающего набора: {len(train_dataset)}")
|
||||||
|
print(f" Размер валидационного набора: {len(val_dataset)}")
|
||||||
|
print(f" Батчей в эпохе: {len(train_loader)}")
|
||||||
|
|
||||||
|
# Создаем тренер
|
||||||
|
print("\nСоздание тренера...")
|
||||||
|
trainer = GANTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return trainer, config
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(trainer, config):
|
||||||
|
"""Запускает обучение модели."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Начало обучения")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
epochs = config.get("epochs", 100)
|
||||||
|
|
||||||
|
try:
|
||||||
|
trainer.train(num_epochs=epochs)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Обучение завершено успешно!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nОбучение прервано пользователем.")
|
||||||
|
print("Сохранение текущего состояния...")
|
||||||
|
trainer.save_checkpoint(is_best=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\nОшибка при обучении: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# Пытаемся сохранить чекпоинт при ошибке
|
||||||
|
try:
|
||||||
|
trainer.save_checkpoint(is_best=False)
|
||||||
|
print("Текущее состояние сохранено.")
|
||||||
|
except:
|
||||||
|
print("Не удалось сохранить состояние.")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model(trainer, test_loader=None):
|
||||||
|
"""Оценивает обученную модель."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Оценка модели")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if test_loader is None:
|
||||||
|
print("Тестовый даталоадер не предоставлен.")
|
||||||
|
print("Используется валидационный даталоадер для оценки.")
|
||||||
|
test_loader = trainer.val_loader
|
||||||
|
|
||||||
|
metrics = trainer.evaluate(test_loader)
|
||||||
|
|
||||||
|
print("\nМетрики оценки:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f" {key}: {value:.6f}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def generate_examples(model, device, num_examples=4):
|
||||||
|
"""Генерирует примеры преобразования."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Генерация примеров")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Создаем фиктивные входные данные
|
||||||
|
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
google_output = model(yandex_input)
|
||||||
|
|
||||||
|
print(f"Сгенерировано {num_examples} примеров")
|
||||||
|
print(f"Размер входных данных: {yandex_input.shape}")
|
||||||
|
print(f"Размер выходных данных: {google_output.shape}")
|
||||||
|
|
||||||
|
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
|
||||||
|
print("\nПримеры сгенерированы.")
|
||||||
|
print("В реальном коде сохраняйте их как изображения для визуализации.")
|
||||||
|
|
||||||
|
return yandex_input, google_output
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Пример обучения GAN для преобразования Yandex → Google")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Настройка
|
||||||
|
trainer, config = setup_training()
|
||||||
|
|
||||||
|
# Обучение
|
||||||
|
train_model(trainer, config)
|
||||||
|
|
||||||
|
# Оценка (требует реальных тестовых данных)
|
||||||
|
# evaluate_model(trainer)
|
||||||
|
|
||||||
|
# Генерация примеров
|
||||||
|
# generate_examples(trainer.model, trainer.device)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Скрипт завершен.")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nСледующие шаги:")
|
||||||
|
print("1. Замените фиктивные даталоадеры на реальные данные")
|
||||||
|
print("2. Настройте параметры в create_simple_config()")
|
||||||
|
print("3. Запустите обучение с реальными данными")
|
||||||
|
print("4. Визуализируйте результаты")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
415
models/GAN/trainer.py
Normal file
415
models/GAN/trainer.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Type aliases
|
||||||
|
ModuleType = nn.Module
|
||||||
|
TensorType = torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class GANTrainer:
|
||||||
|
"""Trainer class for GAN model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: ModuleType,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the GAN trainer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: GAN model (ImageGAN)
|
||||||
|
train_loader: Training data loader
|
||||||
|
val_loader: Validation data loader
|
||||||
|
device: Device to run training on
|
||||||
|
config: Training configuration dictionary
|
||||||
|
"""
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.device = device
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Optimizers
|
||||||
|
lr = config.get("learning_rate", 2e-4)
|
||||||
|
beta1 = config.get("beta1", 0.5)
|
||||||
|
beta2 = config.get("beta2", 0.999)
|
||||||
|
|
||||||
|
# Separate optimizers for generator and discriminator
|
||||||
|
# Note: self.model is expected to have .generator and .discriminator attributes
|
||||||
|
self.optimizer_G = optim.Adam(
|
||||||
|
self.model.generator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||||
|
)
|
||||||
|
self.optimizer_D = optim.Adam(
|
||||||
|
self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training state
|
||||||
|
self.current_epoch = 0
|
||||||
|
self.best_val_loss = float("inf")
|
||||||
|
self.g_losses: List[float] = []
|
||||||
|
self.d_losses: List[float] = []
|
||||||
|
self.val_g_losses: List[float] = []
|
||||||
|
self.val_d_losses: List[float] = []
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
self.output_dir = Path(config.get("output_dir", "runs/gan"))
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# TensorBoard writer
|
||||||
|
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||||
|
|
||||||
|
# Save configuration
|
||||||
|
config_path = self.output_dir / "config.json"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Training configuration saved to {config_path}")
|
||||||
|
# Access parameters through the model's generator and discriminator
|
||||||
|
generator_params = sum(p.numel() for p in self.model.generator.parameters())
|
||||||
|
discriminator_params = sum(
|
||||||
|
p.numel() for p in self.model.discriminator.parameters()
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Generator has {generator_params:,} parameters")
|
||||||
|
print(f"Discriminator has {discriminator_params:,} parameters")
|
||||||
|
|
||||||
|
def train_epoch(self) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Train for one epoch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (average generator loss, average discriminator loss)
|
||||||
|
"""
|
||||||
|
self.model.train()
|
||||||
|
total_g_loss = 0.0
|
||||||
|
total_d_loss = 0.0
|
||||||
|
num_batches = len(self.train_loader)
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||||
|
for batch_idx, batch in enumerate(progress_bar):
|
||||||
|
# Move data to device
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
|
||||||
|
# ========== Train Discriminator ==========
|
||||||
|
self.optimizer_D.zero_grad()
|
||||||
|
|
||||||
|
# Generate fake image
|
||||||
|
with torch.no_grad():
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Discriminator loss - returns tuple of tensors
|
||||||
|
d_loss_tuple = self.model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||||
|
|
||||||
|
# Backward and optimize discriminator
|
||||||
|
d_loss.backward()
|
||||||
|
self.optimizer_D.step()
|
||||||
|
|
||||||
|
# ========== Train Generator ==========
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
|
||||||
|
# Generate fake image
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Generator loss - returns tuple of tensors
|
||||||
|
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||||
|
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||||
|
|
||||||
|
# Backward and optimize generator
|
||||||
|
g_loss.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_g_loss += g_loss.item()
|
||||||
|
total_d_loss += d_loss.item()
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix(
|
||||||
|
{
|
||||||
|
"g_loss": g_loss.item(),
|
||||||
|
"d_loss": d_loss.item(),
|
||||||
|
"g_l1": g_l1_loss.item(),
|
||||||
|
"d_real": d_real_loss.item(),
|
||||||
|
"d_fake": d_fake_loss.item(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log batch losses to TensorBoard
|
||||||
|
global_step = self.current_epoch * num_batches + batch_idx
|
||||||
|
self.writer.add_scalar("train/batch_g_loss", g_loss.item(), global_step)
|
||||||
|
self.writer.add_scalar("train/batch_d_loss", d_loss.item(), global_step)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/batch_g_l1_loss", g_l1_loss.item(), global_step
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/batch_d_real_loss", d_real_loss.item(), global_step
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/batch_d_fake_loss", d_fake_loss.item(), global_step
|
||||||
|
)
|
||||||
|
|
||||||
|
avg_g_loss = total_g_loss / num_batches
|
||||||
|
avg_d_loss = total_d_loss / num_batches
|
||||||
|
self.g_losses.append(avg_g_loss)
|
||||||
|
self.d_losses.append(avg_d_loss)
|
||||||
|
|
||||||
|
return avg_g_loss, avg_d_loss
|
||||||
|
|
||||||
|
def validate(self) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Validate the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (average generator validation loss, average discriminator validation loss)
|
||||||
|
"""
|
||||||
|
self.model.eval()
|
||||||
|
total_g_loss = 0.0
|
||||||
|
total_d_loss = 0.0
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||||
|
for batch in progress_bar:
|
||||||
|
# Move data to device
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Generate fake image
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Generator loss - returns tuple of tensors
|
||||||
|
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||||
|
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||||
|
|
||||||
|
# Discriminator loss - returns tuple of tensors
|
||||||
|
d_loss_tuple = self.model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_g_loss += g_loss.item()
|
||||||
|
total_d_loss += d_loss.item()
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
|
||||||
|
|
||||||
|
avg_g_loss = total_g_loss / len(self.val_loader)
|
||||||
|
avg_d_loss = total_d_loss / len(self.val_loader)
|
||||||
|
self.val_g_losses.append(avg_g_loss)
|
||||||
|
self.val_d_losses.append(avg_d_loss)
|
||||||
|
|
||||||
|
return avg_g_loss, avg_d_loss
|
||||||
|
|
||||||
|
def save_checkpoint(self, is_best: bool = False):
|
||||||
|
"""
|
||||||
|
Save training checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_best: Whether this is the best model so far
|
||||||
|
"""
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": self.current_epoch,
|
||||||
|
"model_state_dict": self.model.state_dict(),
|
||||||
|
"optimizer_G_state_dict": self.optimizer_G.state_dict(),
|
||||||
|
"optimizer_D_state_dict": self.optimizer_D.state_dict(),
|
||||||
|
"g_losses": self.g_losses,
|
||||||
|
"d_losses": self.d_losses,
|
||||||
|
"val_g_losses": self.val_g_losses,
|
||||||
|
"val_d_losses": self.val_d_losses,
|
||||||
|
"best_val_loss": self.best_val_loss,
|
||||||
|
"config": self.config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save regular checkpoint
|
||||||
|
checkpoint_path = (
|
||||||
|
self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
|
||||||
|
)
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
# Save best model
|
||||||
|
if is_best:
|
||||||
|
best_path = self.output_dir / "model_best.pth"
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
print(f"Best model saved to {best_path}")
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False):
|
||||||
|
"""
|
||||||
|
Load training checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: Path to checkpoint file
|
||||||
|
resume_training: Если True, продолжить обучение с сохраненной эпохи
|
||||||
|
"""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
|
||||||
|
self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])
|
||||||
|
|
||||||
|
self.current_epoch = checkpoint["epoch"]
|
||||||
|
self.g_losses = checkpoint["g_losses"]
|
||||||
|
self.d_losses = checkpoint["d_losses"]
|
||||||
|
self.val_g_losses = checkpoint["val_g_losses"]
|
||||||
|
self.val_d_losses = checkpoint["val_d_losses"]
|
||||||
|
self.best_val_loss = checkpoint["best_val_loss"]
|
||||||
|
|
||||||
|
if resume_training:
|
||||||
|
print(f"Resuming training from epoch {self.current_epoch + 1}")
|
||||||
|
else:
|
||||||
|
print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")
|
||||||
|
|
||||||
|
def train(self, num_epochs: int, start_epoch: int = 0):
|
||||||
|
"""
|
||||||
|
Train the model for specified number of epochs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_epochs: Number of epochs to train
|
||||||
|
start_epoch: Starting epoch (useful when resuming training)
|
||||||
|
"""
|
||||||
|
print(
|
||||||
|
f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
|
||||||
|
)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, start_epoch + num_epochs):
|
||||||
|
self.current_epoch = epoch
|
||||||
|
|
||||||
|
# Train for one epoch
|
||||||
|
train_g_loss, train_d_loss = self.train_epoch()
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
val_g_loss, val_d_loss = self.validate()
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
|
||||||
|
self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
|
||||||
|
self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
|
||||||
|
self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)
|
||||||
|
|
||||||
|
# Print epoch summary
|
||||||
|
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||||
|
print(" Generator:")
|
||||||
|
print(f" Train Loss: {train_g_loss:.6f}")
|
||||||
|
print(f" Val Loss: {val_g_loss:.6f}")
|
||||||
|
print(" Discriminator:")
|
||||||
|
print(f" Train Loss: {train_d_loss:.6f}")
|
||||||
|
print(f" Val Loss: {val_d_loss:.6f}")
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
val_total_loss = val_g_loss + val_d_loss
|
||||||
|
is_best = val_total_loss < self.best_val_loss
|
||||||
|
if is_best:
|
||||||
|
self.best_val_loss = val_total_loss
|
||||||
|
|
||||||
|
self.save_checkpoint(is_best=is_best)
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if self.config.get("early_stopping_patience", 0) > 0:
|
||||||
|
val_losses = [
|
||||||
|
g + d for g, d in zip(self.val_g_losses, self.val_d_losses)
|
||||||
|
]
|
||||||
|
if (
|
||||||
|
epoch - np.argmin(val_losses)
|
||||||
|
>= self.config["early_stopping_patience"]
|
||||||
|
):
|
||||||
|
print(f"Early stopping at epoch {epoch + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Training completed
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||||
|
print(f"Best validation total loss: {self.best_val_loss:.6f}")
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
final_model_path = self.output_dir / "model_final.pth"
|
||||||
|
torch.save(self.model.state_dict(), final_model_path)
|
||||||
|
print(f"Final model saved to {final_model_path}")
|
||||||
|
|
||||||
|
# Save training history
|
||||||
|
history_path = self.output_dir / "training_history.json"
|
||||||
|
history = {
|
||||||
|
"g_losses": self.g_losses,
|
||||||
|
"d_losses": self.d_losses,
|
||||||
|
"val_g_losses": self.val_g_losses,
|
||||||
|
"val_d_losses": self.val_d_losses,
|
||||||
|
"best_val_loss": self.best_val_loss,
|
||||||
|
"total_epochs": self.current_epoch + 1,
|
||||||
|
}
|
||||||
|
with open(history_path, "w") as f:
|
||||||
|
json.dump(history, f, indent=2)
|
||||||
|
print(f"Training history saved to {history_path}")
|
||||||
|
|
||||||
|
# Close TensorBoard writer
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
def evaluate(self, test_loader: DataLoader) -> Dict:
|
||||||
|
"""
|
||||||
|
Evaluate the model on test data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_loader: Test data loader
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with evaluation metrics
|
||||||
|
"""
|
||||||
|
self.model.eval()
|
||||||
|
total_g_loss = 0.0
|
||||||
|
total_d_loss = 0.0
|
||||||
|
|
||||||
|
print("Evaluating model on test set...")
|
||||||
|
|
||||||
|
for batch in tqdm(test_loader, desc="Evaluation"):
|
||||||
|
# Move data to device
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Generate fake image
|
||||||
|
fake_google_img = self.model.generator(yandex_img)
|
||||||
|
|
||||||
|
# Generator loss - returns tuple of tensors
|
||||||
|
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||||
|
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||||
|
|
||||||
|
# Discriminator loss - returns tuple of tensors
|
||||||
|
d_loss_tuple = self.model.discriminator_step(
|
||||||
|
yandex_img, google_img, fake_google_img
|
||||||
|
)
|
||||||
|
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_g_loss += g_loss.item()
|
||||||
|
total_d_loss += d_loss.item()
|
||||||
|
|
||||||
|
avg_g_loss = total_g_loss / len(test_loader)
|
||||||
|
avg_d_loss = total_d_loss / len(test_loader)
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"generator_loss": avg_g_loss,
|
||||||
|
"discriminator_loss": avg_d_loss,
|
||||||
|
"total_loss": avg_g_loss + avg_d_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("\nTest Results:")
|
||||||
|
print(f" Generator Loss: {avg_g_loss:.6f}")
|
||||||
|
print(f" Discriminator Loss: {avg_d_loss:.6f}")
|
||||||
|
print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
1
models/SiaN/.gitignore
vendored
Normal file
1
models/SiaN/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
runs
|
||||||
295
models/SiaN/README_homography.md
Normal file
295
models/SiaN/README_homography.md
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
# Homography Estimation System
|
||||||
|
|
||||||
|
This system provides a complete pipeline for estimating homography matrices between Google and Yandex map images using deep learning.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Homography estimation is crucial for aligning images from different sources (Google Maps and Yandex Maps in this case). The system includes:
|
||||||
|
|
||||||
|
1. **Dataset handling** - Loading and preprocessing image pairs
|
||||||
|
2. **Data augmentation** - Homography-based augmentation for robust training
|
||||||
|
3. **CNN model** - Deep learning model for homography estimation
|
||||||
|
4. **Training pipeline** - Complete training and evaluation workflow
|
||||||
|
5. **Inference tools** - Tools for using trained models on new data
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
- Python 3.8+
|
||||||
|
- PyTorch 1.9+
|
||||||
|
- OpenCV
|
||||||
|
- PIL/Pillow
|
||||||
|
- NumPy
|
||||||
|
|
||||||
|
### Install dependencies
|
||||||
|
```bash
|
||||||
|
pip install torch torchvision opencv-python pillow numpy matplotlib tqdm tensorboard
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataset Structure
|
||||||
|
|
||||||
|
The system expects image pairs in the following format:
|
||||||
|
```
|
||||||
|
dataset/
|
||||||
|
├── 0000_google.png
|
||||||
|
├── 0000_yandex.png
|
||||||
|
├── 0001_google.png
|
||||||
|
├── 0001_yandex.png
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Each pair consists of:
|
||||||
|
- `{idx:04d}_google.png` - Google map image
|
||||||
|
- `{idx:04d}_yandex.png` - Yandex map image
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Explore the dataset
|
||||||
|
```python
|
||||||
|
from models.homography import HomographyDataset
|
||||||
|
|
||||||
|
dataset = HomographyDataset(
|
||||||
|
root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
augment=True,
|
||||||
|
image_size=(256, 256)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Found {len(dataset)} image pairs")
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"Sample homography matrix:\n{sample['homography'].numpy()}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Train a model
|
||||||
|
```bash
|
||||||
|
python models/train_homography.py \
|
||||||
|
--data_dir "C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images" \
|
||||||
|
--epochs 50 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--lr 1e-3 \
|
||||||
|
--output_dir "runs/my_experiment"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Perform inference
|
||||||
|
```bash
|
||||||
|
python models/infer_homography.py \
|
||||||
|
--model_path "runs/my_experiment/checkpoint_best.pth" \
|
||||||
|
--mode single \
|
||||||
|
--google_path "path/to/google.png" \
|
||||||
|
--yandex_path "path/to/yandex.png" \
|
||||||
|
--output_vis "alignment_result.png"
|
||||||
|
```
|
||||||
|
|
||||||
|
## File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
models/
|
||||||
|
├── homography.py # Dataset class and data loaders
|
||||||
|
├── homography_cnn.py # CNN model architecture
|
||||||
|
├── train_homography.py # Training script
|
||||||
|
├── infer_homography.py # Inference script
|
||||||
|
├── example_homography.py # Example usage
|
||||||
|
├── homography.ipynb # Jupyter notebook (empty)
|
||||||
|
└── README_homography.md # This file
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Architecture
|
||||||
|
|
||||||
|
The homography estimation model (`HomographyCNN`) consists of:
|
||||||
|
|
||||||
|
1. **Dual encoders** - Separate feature extraction for Google and Yandex images
|
||||||
|
2. **Residual blocks** - For deep feature learning
|
||||||
|
3. **Fusion layers** - Combine features from both images
|
||||||
|
4. **Regression head** - Predict 3x3 homography matrix
|
||||||
|
|
||||||
|
### Key Features:
|
||||||
|
- Residual connections for stable training
|
||||||
|
- Batch normalization options
|
||||||
|
- Dropout for regularization
|
||||||
|
- Geometric consistency loss
|
||||||
|
|
||||||
|
## Training Configuration
|
||||||
|
|
||||||
|
Default training parameters:
|
||||||
|
- **Optimizer**: Adam with learning rate 1e-3
|
||||||
|
- **Loss function**: Combined matrix + geometric + regularization loss
|
||||||
|
- **Batch size**: 32
|
||||||
|
- **Image size**: 256x256
|
||||||
|
- **Train/val split**: 80/20
|
||||||
|
|
||||||
|
## Inference Modes
|
||||||
|
|
||||||
|
The inference script supports three modes:
|
||||||
|
|
||||||
|
### 1. Single image pair
|
||||||
|
```bash
|
||||||
|
python infer_homography.py --mode single \
|
||||||
|
--google_path google.png --yandex_path yandex.png
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Dataset evaluation
|
||||||
|
```bash
|
||||||
|
python infer_homography.py --mode dataset \
|
||||||
|
--dataset_dir path/to/dataset --num_samples 100
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Batch processing
|
||||||
|
```bash
|
||||||
|
python infer_homography.py --mode batch \
|
||||||
|
--input_dir path/to/input --output_dir path/to/output
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation Metrics
|
||||||
|
|
||||||
|
The system computes several metrics:
|
||||||
|
- **Matrix MSE**: Mean squared error of homography matrix elements
|
||||||
|
- **Corner error**: Average pixel error at image corners
|
||||||
|
- **Geometric consistency**: Warping error across grid points
|
||||||
|
|
||||||
|
## Data Augmentation
|
||||||
|
|
||||||
|
The dataset applies homography-based augmentation:
|
||||||
|
- Random rotation (-30° to 30°)
|
||||||
|
- Random scaling (0.8x to 1.2x)
|
||||||
|
- Random translation (-50 to 50 pixels)
|
||||||
|
- Small perspective distortion
|
||||||
|
|
||||||
|
## Integration with Autopilot System
|
||||||
|
|
||||||
|
The homography estimation can be integrated into the autopilot system:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from models.infer_homography import HomographyInference
|
||||||
|
|
||||||
|
# Initialize inference
|
||||||
|
inference = HomographyInference(model_path="path/to/model.pth")
|
||||||
|
|
||||||
|
# During flight loop
|
||||||
|
homography = inference.predict(google_img, yandex_img)
|
||||||
|
# Use homography to update drone position
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Out of memory**
|
||||||
|
- Reduce batch size
|
||||||
|
- Use smaller image size
|
||||||
|
- Enable gradient checkpointing
|
||||||
|
|
||||||
|
2. **Poor convergence**
|
||||||
|
- Adjust learning rate
|
||||||
|
- Increase model capacity
|
||||||
|
- Add more data augmentation
|
||||||
|
|
||||||
|
3. **Inference errors**
|
||||||
|
- Check image formats (must be RGB)
|
||||||
|
- Verify model was trained with same image size
|
||||||
|
- Ensure proper normalization
|
||||||
|
|
||||||
|
### Debugging Tips
|
||||||
|
|
||||||
|
- Use `example_homography.py` to test components
|
||||||
|
- Enable TensorBoard for training visualization
|
||||||
|
- Check homography matrix normalization (last element should be ~1)
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Custom Model Architecture
|
||||||
|
```python
|
||||||
|
from models.homography_cnn import create_homography_model
|
||||||
|
|
||||||
|
model = create_homography_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=(512, 512),
|
||||||
|
hidden_channels=128,
|
||||||
|
num_blocks=6,
|
||||||
|
dropout_rate=0.5
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Loss Function
|
||||||
|
```python
|
||||||
|
from models.homography_cnn import HomographyLoss
|
||||||
|
|
||||||
|
loss_fn = HomographyLoss(
|
||||||
|
matrix_weight=0.7,
|
||||||
|
geometric_weight=0.3,
|
||||||
|
reg_weight=0.05,
|
||||||
|
grid_size=16
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Transfer Learning
|
||||||
|
```python
|
||||||
|
# Load pretrained model
|
||||||
|
checkpoint = torch.load("pretrained.pth")
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
|
# Fine-tune on new data
|
||||||
|
for param in model.google_encoder.parameters():
|
||||||
|
param.requires_grad = False # Freeze encoder
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Tips
|
||||||
|
|
||||||
|
1. **Training speed**
|
||||||
|
- Use multiple GPU workers for data loading
|
||||||
|
- Enable mixed precision training
|
||||||
|
- Use gradient accumulation for larger effective batch sizes
|
||||||
|
|
||||||
|
2. **Memory efficiency**
|
||||||
|
- Use gradient checkpointing
|
||||||
|
- Implement progressive resizing
|
||||||
|
- Use memory-efficient optimizers
|
||||||
|
|
||||||
|
3. **Inference speed**
|
||||||
|
- Use TensorRT or ONNX for deployment
|
||||||
|
- Implement model quantization
|
||||||
|
- Use batch inference when possible
|
||||||
|
|
||||||
|
## Future Improvements
|
||||||
|
|
||||||
|
1. **Model enhancements**
|
||||||
|
- Transformer-based architecture
|
||||||
|
- Multi-scale feature fusion
|
||||||
|
- Uncertainty estimation
|
||||||
|
|
||||||
|
2. **Training improvements**
|
||||||
|
- Self-supervised pre-training
|
||||||
|
- Curriculum learning
|
||||||
|
- Adversarial training
|
||||||
|
|
||||||
|
3. **Deployment features**
|
||||||
|
- Real-time inference optimization
|
||||||
|
- Mobile deployment
|
||||||
|
- Web API
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use this system in your research, please cite:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@software{homography_estimation_2024,
|
||||||
|
title = {Homography Estimation for Map Alignment},
|
||||||
|
author = {Autopilot Team},
|
||||||
|
year = {2024},
|
||||||
|
url = {https://github.com/your-repo/homography-estimation}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
||||||
|
|
||||||
|
## Support
|
||||||
|
|
||||||
|
For questions and issues:
|
||||||
|
1. Check the troubleshooting section
|
||||||
|
2. Review the example scripts
|
||||||
|
3. Open an issue on GitHub
|
||||||
|
4. Contact the development team
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*Last updated: October 2024*
|
||||||
345
models/SiaN/example_homography.py
Normal file
345
models/SiaN/example_homography.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""
|
||||||
|
Example script demonstrating the complete homography estimation workflow.
|
||||||
|
|
||||||
|
This script shows how to:
|
||||||
|
1. Load the ya_go_maps dataset
|
||||||
|
2. Create and train a homography estimation model
|
||||||
|
3. Perform inference on new image pairs
|
||||||
|
4. Visualize results
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from models.homography import HomographyDataset, create_data_loaders
|
||||||
|
from models.homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||||
|
from models.infer_homography import HomographyInference
|
||||||
|
from models.train_homography import HomographyTrainer
|
||||||
|
|
||||||
|
|
||||||
|
def example_dataset_loading():
|
||||||
|
"""Example 1: Loading and exploring the dataset."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Example 1: Loading and exploring the dataset")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Path to the dataset
|
||||||
|
dataset_path = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images"
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
dataset = HomographyDataset(
|
||||||
|
root_dir=dataset_path,
|
||||||
|
augment=True,
|
||||||
|
image_size=(256, 256),
|
||||||
|
cache_homographies=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Dataset size: {len(dataset)} image pairs")
|
||||||
|
|
||||||
|
# Get a sample
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"\nSample keys: {list(sample.keys())}")
|
||||||
|
print(f"Google image shape: {sample['google_img'].shape}")
|
||||||
|
print(f"Yandex image shape: {sample['yandex_img'].shape}")
|
||||||
|
print(f"Homography shape: {sample['homography'].shape}")
|
||||||
|
|
||||||
|
# Show sample homography matrix
|
||||||
|
print(f"\nSample homography matrix:")
|
||||||
|
print(sample["homography"].numpy())
|
||||||
|
|
||||||
|
# Get sample without augmentation for visualization
|
||||||
|
raw_sample = dataset.get_sample_without_augmentation(0)
|
||||||
|
|
||||||
|
# Visualize sample
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
|
||||||
|
|
||||||
|
axes[0].imshow(raw_sample["google_img"])
|
||||||
|
axes[0].set_title("Google Map")
|
||||||
|
axes[0].axis("off")
|
||||||
|
|
||||||
|
axes[1].imshow(raw_sample["yandex_img"])
|
||||||
|
axes[1].set_title("Yandex Map")
|
||||||
|
axes[1].axis("off")
|
||||||
|
|
||||||
|
plt.suptitle("Sample Image Pair (without augmentation)")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def example_model_creation():
|
||||||
|
"""Example 2: Creating and testing the model."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example 2: Creating and testing the model")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
model = HomographyCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create dummy input
|
||||||
|
batch_size = 2
|
||||||
|
height, width = 256, 256
|
||||||
|
|
||||||
|
google_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
yandex_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
|
||||||
|
# Test forward pass
|
||||||
|
print("\nTesting forward pass...")
|
||||||
|
output = model(google_img, yandex_img, return_matrix=True)
|
||||||
|
print(f"Output shape: {output.shape}") # Should be (2, 3, 3)
|
||||||
|
print(f"Sample output matrix:")
|
||||||
|
print(output[0].cpu().detach().numpy())
|
||||||
|
|
||||||
|
# Test loss function
|
||||||
|
print("\nTesting loss function...")
|
||||||
|
target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
|
||||||
|
loss_fn = HomographyLoss(
|
||||||
|
matrix_weight=1.0,
|
||||||
|
geometric_weight=0.5,
|
||||||
|
reg_weight=0.1,
|
||||||
|
grid_size=8,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
loss = loss_fn(output, target_homography, google_img, yandex_img)
|
||||||
|
print(f"Loss value: {loss.item():.6f}")
|
||||||
|
|
||||||
|
# Test metrics
|
||||||
|
print("\nTesting metrics...")
|
||||||
|
metrics = loss_fn.compute_metrics(output, target_homography)
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f"{key}: {value:.6f}")
|
||||||
|
|
||||||
|
return model, loss_fn
|
||||||
|
|
||||||
|
|
||||||
|
def example_data_loaders():
|
||||||
|
"""Example 3: Creating data loaders for training."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example 3: Creating data loaders for training")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Path to the dataset
|
||||||
|
dataset_path = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images"
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=dataset_path,
|
||||||
|
batch_size=16,
|
||||||
|
train_split=0.8,
|
||||||
|
num_workers=0, # Use 0 for debugging, increase for training
|
||||||
|
image_size=(256, 256),
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
# Get a batch from train loader
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
print(f"\nBatch keys: {list(batch.keys())}")
|
||||||
|
print(f"Batch size: {batch['google_img'].shape[0]}")
|
||||||
|
print(f"Image shape: {batch['google_img'].shape[1:]}")
|
||||||
|
|
||||||
|
return train_loader, val_loader
|
||||||
|
|
||||||
|
|
||||||
|
def example_training_config():
|
||||||
|
"""Example 4: Setting up training configuration."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example 4: Training configuration")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Training configuration
|
||||||
|
config = {
|
||||||
|
# Model config
|
||||||
|
"model_type": "cnn",
|
||||||
|
"hidden_channels": 64,
|
||||||
|
"num_blocks": 4,
|
||||||
|
"dropout_rate": 0.3,
|
||||||
|
"use_batch_norm": True,
|
||||||
|
"image_size": [256, 256],
|
||||||
|
# Training config
|
||||||
|
"epochs": 50,
|
||||||
|
"batch_size": 16,
|
||||||
|
"learning_rate": 1e-3,
|
||||||
|
"weight_decay": 1e-4,
|
||||||
|
"optimizer": "adam",
|
||||||
|
"scheduler": "plateau",
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
# Loss config
|
||||||
|
"matrix_weight": 1.0,
|
||||||
|
"geometric_weight": 0.5,
|
||||||
|
"reg_weight": 0.1,
|
||||||
|
"grid_size": 8,
|
||||||
|
# Data config
|
||||||
|
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
"train_split": 0.8,
|
||||||
|
"num_workers": 0,
|
||||||
|
# Output config
|
||||||
|
"output_dir": "runs/example_training",
|
||||||
|
"seed": 42,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("Training configuration:")
|
||||||
|
for key, value in config.items():
|
||||||
|
print(f" {key}: {value}")
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def example_inference():
|
||||||
|
"""Example 5: Performing inference with a trained model."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example 5: Inference example")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Note: This example assumes you have a trained model
|
||||||
|
# For demonstration, we'll show the code structure
|
||||||
|
|
||||||
|
print("Inference workflow:")
|
||||||
|
print("1. Load trained model")
|
||||||
|
print("2. Preprocess input images")
|
||||||
|
print("3. Predict homography matrix")
|
||||||
|
print("4. Visualize alignment")
|
||||||
|
|
||||||
|
# Example code structure (commented out since we don't have a trained model yet)
|
||||||
|
"""
|
||||||
|
# Create inference object
|
||||||
|
inference = HomographyInference(
|
||||||
|
model_path="runs/homography/checkpoint_best.pth",
|
||||||
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
google_img = Image.open("path/to/google.png")
|
||||||
|
yandex_img = Image.open("path/to/yandex.png")
|
||||||
|
|
||||||
|
# Predict homography
|
||||||
|
homography = inference.predict(google_img, yandex_img)
|
||||||
|
|
||||||
|
print(f"Predicted homography matrix:")
|
||||||
|
print(homography.cpu().numpy())
|
||||||
|
|
||||||
|
# Visualize alignment
|
||||||
|
inference.visualize_alignment(
|
||||||
|
google_img,
|
||||||
|
yandex_img,
|
||||||
|
homography.cpu().numpy(),
|
||||||
|
save_path="alignment_visualization.png",
|
||||||
|
show=True,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
print(
|
||||||
|
"\nNote: To run actual inference, first train a model using train_homography.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def example_quick_start():
|
||||||
|
"""Example 6: Quick start guide."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example 6: Quick Start Guide")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("\nTo get started with homography estimation:")
|
||||||
|
print("\n1. First, explore the dataset:")
|
||||||
|
print(
|
||||||
|
' python -c "from models.homography import HomographyDataset; '
|
||||||
|
"dataset = HomographyDataset(r'C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images'); "
|
||||||
|
"print(f'Found {len(dataset)} image pairs')\""
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n2. Train a model:")
|
||||||
|
print(" python models/train_homography.py \\")
|
||||||
|
print(
|
||||||
|
' --data_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\'
|
||||||
|
)
|
||||||
|
print(" --epochs 50 \\")
|
||||||
|
print(" --batch_size 16 \\")
|
||||||
|
print(' --output_dir "runs/my_experiment"')
|
||||||
|
|
||||||
|
print("\n3. Perform inference on a single image pair:")
|
||||||
|
print(" python models/infer_homography.py \\")
|
||||||
|
print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\')
|
||||||
|
print(" --mode single \\")
|
||||||
|
print(' --google_path "path/to/google.png" \\')
|
||||||
|
print(' --yandex_path "path/to/yandex.png" \\')
|
||||||
|
print(' --output_vis "alignment_result.png"')
|
||||||
|
|
||||||
|
print("\n4. Evaluate on the entire dataset:")
|
||||||
|
print(" python models/infer_homography.py \\")
|
||||||
|
print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\')
|
||||||
|
print(" --mode dataset \\")
|
||||||
|
print(
|
||||||
|
' --dataset_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\'
|
||||||
|
)
|
||||||
|
print(' --save_results "evaluation_results.json"')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run all examples."""
|
||||||
|
print("Homography Estimation Workflow Examples")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Example 1: Dataset loading
|
||||||
|
dataset = example_dataset_loading()
|
||||||
|
|
||||||
|
# Example 2: Model creation
|
||||||
|
model, loss_fn = example_model_creation()
|
||||||
|
|
||||||
|
# Example 3: Data loaders
|
||||||
|
train_loader, val_loader = example_data_loaders()
|
||||||
|
|
||||||
|
# Example 4: Training configuration
|
||||||
|
config = example_training_config()
|
||||||
|
|
||||||
|
# Example 5: Inference
|
||||||
|
example_inference()
|
||||||
|
|
||||||
|
# Example 6: Quick start
|
||||||
|
example_quick_start()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All examples completed successfully!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("\nNext steps:")
|
||||||
|
print("1. Run the training script to train a model")
|
||||||
|
print("2. Use the inference script to test the trained model")
|
||||||
|
print("3. Integrate the homography estimation into your autopilot system")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError running examples: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
2104
models/SiaN/homography.ipynb
Normal file
2104
models/SiaN/homography.ipynb
Normal file
File diff suppressed because one or more lines are too long
434
models/SiaN/homography.py
Normal file
434
models/SiaN/homography.py
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class HomographyDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset for homography estimation between Yandex and Google map image pairs.
|
||||||
|
|
||||||
|
This dataset loads pairs of images (Yandex and Google maps) and provides
|
||||||
|
homography matrices for data augmentation and training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root_dir: str,
|
||||||
|
transform=None,
|
||||||
|
augment: bool = True,
|
||||||
|
max_samples: Optional[int] = None,
|
||||||
|
image_size: Tuple[int, int] = (700, 700),
|
||||||
|
cache_homographies: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the HomographyDataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png)
|
||||||
|
transform: Optional torchvision transforms to apply
|
||||||
|
augment: Whether to apply homography-based data augmentation
|
||||||
|
max_samples: Maximum number of samples to load (None for all)
|
||||||
|
image_size: Target size for images (height, width)
|
||||||
|
cache_homographies: Whether to cache generated homography matrices to disk
|
||||||
|
"""
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.transform = transform
|
||||||
|
self.augment = augment
|
||||||
|
self.image_size = image_size
|
||||||
|
self.cache_homographies = cache_homographies
|
||||||
|
|
||||||
|
# Find all image pairs
|
||||||
|
self.image_pairs = self._discover_image_pairs()
|
||||||
|
|
||||||
|
if max_samples is not None:
|
||||||
|
self.image_pairs = self.image_pairs[:max_samples]
|
||||||
|
|
||||||
|
print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")
|
||||||
|
|
||||||
|
# Create directory for cached homographies if needed
|
||||||
|
if cache_homographies:
|
||||||
|
self.homography_cache_dir = os.path.join(root_dir, "homography_cache")
|
||||||
|
os.makedirs(self.homography_cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def _discover_image_pairs(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Discover all Google-Yandex image pairs in the dataset directory."""
|
||||||
|
image_pairs = []
|
||||||
|
|
||||||
|
# Get all Google images
|
||||||
|
google_files = [
|
||||||
|
f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
|
||||||
|
]
|
||||||
|
|
||||||
|
for google_file in sorted(google_files):
|
||||||
|
# Extract index from filename
|
||||||
|
idx_str = google_file.split("_")[0]
|
||||||
|
try:
|
||||||
|
idx = int(idx_str)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if corresponding Yandex image exists
|
||||||
|
yandex_file = f"{idx:04d}_yandex.png"
|
||||||
|
yandex_path = os.path.join(self.root_dir, yandex_file)
|
||||||
|
|
||||||
|
if os.path.exists(yandex_path):
|
||||||
|
image_pairs.append(
|
||||||
|
{
|
||||||
|
"idx": idx,
|
||||||
|
"google_path": os.path.join(self.root_dir, google_file),
|
||||||
|
"yandex_path": yandex_path,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_pairs
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of image pairs in the dataset."""
|
||||||
|
return len(self.image_pairs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Get a sample from the dataset.
|
||||||
|
|
||||||
|
Returns a dictionary with:
|
||||||
|
- 'google_img': Google map image tensor
|
||||||
|
- 'yandex_img': Yandex map image tensor
|
||||||
|
- 'homography': Ground truth homography matrix (3x3)
|
||||||
|
- 'idx': Sample index
|
||||||
|
"""
|
||||||
|
pair_info = self.image_pairs[idx]
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
google_img = Image.open(pair_info["google_path"]).convert("RGB")
|
||||||
|
yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB")
|
||||||
|
|
||||||
|
# Resize images to target size
|
||||||
|
google_img = google_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
yandex_img = yandex_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get or generate homography matrix
|
||||||
|
homography_matrix = self._get_homography_matrix(pair_info["idx"])
|
||||||
|
|
||||||
|
# Apply data augmentation if enabled
|
||||||
|
if self.augment:
|
||||||
|
google_img, yandex_img, homography_matrix = self._apply_augmentation(
|
||||||
|
google_img, yandex_img, homography_matrix
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert images to tensors
|
||||||
|
if self.transform:
|
||||||
|
google_img = self.transform(google_img)
|
||||||
|
yandex_img = self.transform(yandex_img)
|
||||||
|
else:
|
||||||
|
# Default conversion to tensor
|
||||||
|
google_img = (
|
||||||
|
torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
|
||||||
|
)
|
||||||
|
yandex_img = (
|
||||||
|
torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert homography to tensor
|
||||||
|
homography_tensor = torch.from_numpy(homography_matrix).float()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": homography_tensor,
|
||||||
|
"idx": torch.tensor(pair_info["idx"], dtype=torch.long),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_homography_matrix(self, idx: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get homography matrix for a given index.
|
||||||
|
|
||||||
|
If cached homography exists, load it. Otherwise generate a new one.
|
||||||
|
"""
|
||||||
|
if self.cache_homographies:
|
||||||
|
cache_path = os.path.join(
|
||||||
|
self.homography_cache_dir, f"{idx:04d}_homography.npy"
|
||||||
|
)
|
||||||
|
if os.path.exists(cache_path):
|
||||||
|
return np.load(cache_path)
|
||||||
|
|
||||||
|
# Generate new homography matrix
|
||||||
|
homography_matrix = self.generate_random_homography()
|
||||||
|
|
||||||
|
# Cache if enabled
|
||||||
|
if self.cache_homographies:
|
||||||
|
np.save(cache_path, homography_matrix)
|
||||||
|
|
||||||
|
return homography_matrix
|
||||||
|
|
||||||
|
def generate_random_homography(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate a random homography matrix for data augmentation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: 3x3 homography matrix.
|
||||||
|
"""
|
||||||
|
# Generate random affine transformation parameters
|
||||||
|
angle = np.random.uniform(-30, 30) # rotation in degrees
|
||||||
|
scale = np.random.uniform(0.8, 1.2) # scaling factor
|
||||||
|
tx = np.random.uniform(-50, 50) # translation in x
|
||||||
|
ty = np.random.uniform(-50, 50) # translation in y
|
||||||
|
|
||||||
|
# Convert angle to radians
|
||||||
|
theta = np.radians(angle)
|
||||||
|
|
||||||
|
# Create affine transformation matrix
|
||||||
|
affine_matrix = np.array(
|
||||||
|
[
|
||||||
|
[scale * np.cos(theta), -scale * np.sin(theta), tx],
|
||||||
|
[scale * np.sin(theta), scale * np.cos(theta), ty],
|
||||||
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add small perspective distortion
|
||||||
|
perspective = np.random.uniform(-0.001, 0.001, (2, 3))
|
||||||
|
perspective = np.vstack([perspective, [0, 0, 0]])
|
||||||
|
|
||||||
|
homography_matrix = affine_matrix + perspective
|
||||||
|
|
||||||
|
return homography_matrix
|
||||||
|
|
||||||
|
def _apply_augmentation(
|
||||||
|
self,
|
||||||
|
google_img: Image.Image,
|
||||||
|
yandex_img: Image.Image,
|
||||||
|
base_homography: np.ndarray,
|
||||||
|
) -> Tuple[Image.Image, Image.Image, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Apply homography-based data augmentation to image pair.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image
|
||||||
|
yandex_img: Yandex map image
|
||||||
|
base_homography: Base homography matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography)
|
||||||
|
"""
|
||||||
|
# Generate augmentation homography
|
||||||
|
aug_homography = self.generate_random_homography()
|
||||||
|
|
||||||
|
# Combine with base homography
|
||||||
|
combined_homography = aug_homography @ base_homography
|
||||||
|
|
||||||
|
# Apply augmentation to both images
|
||||||
|
google_aug = self._apply_homography_to_image(google_img, aug_homography)
|
||||||
|
yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography)
|
||||||
|
|
||||||
|
return google_aug, yandex_aug, combined_homography
|
||||||
|
|
||||||
|
def _apply_homography_to_image(
|
||||||
|
self, img: Image.Image, homography: np.ndarray
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Apply homography transformation to a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: PIL Image to transform
|
||||||
|
homography: 3x3 homography matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed PIL Image
|
||||||
|
"""
|
||||||
|
# Convert to numpy array
|
||||||
|
img_np = np.array(img)
|
||||||
|
|
||||||
|
# Get image dimensions
|
||||||
|
h, w = img_np.shape[:2]
|
||||||
|
|
||||||
|
# Apply homography transformation
|
||||||
|
transformed = cv2.warpPerspective(
|
||||||
|
img_np,
|
||||||
|
homography,
|
||||||
|
(w, h),
|
||||||
|
flags=cv2.INTER_LINEAR,
|
||||||
|
borderMode=cv2.BORDER_REFLECT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert back to PIL Image
|
||||||
|
return Image.fromarray(transformed)
|
||||||
|
|
||||||
|
def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get a sample without data augmentation.
|
||||||
|
|
||||||
|
Useful for visualization and evaluation.
|
||||||
|
"""
|
||||||
|
pair_info = self.image_pairs[idx]
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
google_img = Image.open(pair_info["google_path"]).convert("RGB")
|
||||||
|
yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB")
|
||||||
|
|
||||||
|
# Resize
|
||||||
|
google_img = google_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
yandex_img = yandex_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get homography matrix
|
||||||
|
homography_matrix = self._get_homography_matrix(pair_info["idx"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": homography_matrix,
|
||||||
|
"idx": pair_info["idx"],
|
||||||
|
"google_path": pair_info["google_path"],
|
||||||
|
"yandex_path": pair_info["yandex_path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_data_loaders(
|
||||||
|
root_dir: str,
|
||||||
|
batch_size: int = 32,
|
||||||
|
train_split: float = 0.8,
|
||||||
|
num_workers: int = 4,
|
||||||
|
image_size: Tuple[int, int] = (256, 256),
|
||||||
|
augment_train: bool = True,
|
||||||
|
augment_val: bool = False,
|
||||||
|
) -> Tuple[DataLoader, DataLoader]:
|
||||||
|
"""
|
||||||
|
Create train and validation data loaders for homography estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing image pairs
|
||||||
|
batch_size: Batch size for data loaders
|
||||||
|
train_split: Fraction of data to use for training
|
||||||
|
num_workers: Number of worker processes for data loading
|
||||||
|
image_size: Target image size (height, width)
|
||||||
|
augment_train: Whether to augment training data
|
||||||
|
augment_val: Whether to augment validation data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (train_loader, val_loader)
|
||||||
|
"""
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
# Define transforms
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create full dataset
|
||||||
|
full_dataset = HomographyDataset(
|
||||||
|
root_dir=root_dir,
|
||||||
|
transform=transform,
|
||||||
|
augment=False, # We'll handle augmentation separately
|
||||||
|
image_size=image_size,
|
||||||
|
cache_homographies=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split dataset
|
||||||
|
dataset_size = len(full_dataset)
|
||||||
|
train_size = int(train_split * dataset_size)
|
||||||
|
val_size = dataset_size - train_size
|
||||||
|
|
||||||
|
# Create indices for splitting
|
||||||
|
indices = list(range(dataset_size))
|
||||||
|
random.shuffle(indices)
|
||||||
|
train_indices = indices[:train_size]
|
||||||
|
val_indices = indices[train_size:]
|
||||||
|
|
||||||
|
# Create subset samplers
|
||||||
|
from torch.utils.data import Subset
|
||||||
|
|
||||||
|
train_dataset = Subset(full_dataset, train_indices)
|
||||||
|
val_dataset = Subset(full_dataset, val_indices)
|
||||||
|
|
||||||
|
# Apply augmentation by overriding __getitem__ for train dataset
|
||||||
|
if augment_train:
|
||||||
|
|
||||||
|
class AugmentedSubset(Subset):
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
sample = self.dataset[self.indices[idx]]
|
||||||
|
# Apply augmentation
|
||||||
|
google_img = sample["google_img"]
|
||||||
|
yandex_img = sample["yandex_img"]
|
||||||
|
homography = sample["homography"]
|
||||||
|
|
||||||
|
# Generate augmentation homography
|
||||||
|
aug_homography = torch.from_numpy(
|
||||||
|
full_dataset.generate_random_homography()
|
||||||
|
).float()
|
||||||
|
|
||||||
|
# Combine homographies
|
||||||
|
combined_homography = aug_homography @ homography
|
||||||
|
|
||||||
|
# Apply augmentation (simplified - in practice would warp images)
|
||||||
|
# For now, we just return the combined homography
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": combined_homography,
|
||||||
|
"idx": sample["idx"],
|
||||||
|
}
|
||||||
|
|
||||||
|
train_dataset = AugmentedSubset(full_dataset, train_indices)
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, val_loader
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example usage
|
||||||
|
dataset = HomographyDataset(
|
||||||
|
root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
augment=True,
|
||||||
|
image_size=(256, 256),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Dataset size: {len(dataset)}")
|
||||||
|
|
||||||
|
# Get a sample
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"Sample keys: {list(sample.keys())}")
|
||||||
|
print(f"Google image shape: {sample['google_img'].shape}")
|
||||||
|
print(f"Yandex image shape: {sample['yandex_img'].shape}")
|
||||||
|
print(f"Homography shape: {sample['homography'].shape}")
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
batch_size=16,
|
||||||
|
train_split=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
551
models/SiaN/homography_cnn.py
Normal file
551
models/SiaN/homography_cnn.py
Normal file
@@ -0,0 +1,551 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class HomographyCNN(nn.Module):
|
||||||
|
"""
|
||||||
|
CNN model for homography estimation between two images.
|
||||||
|
|
||||||
|
This model takes two images (Google and Yandex maps) as input and
|
||||||
|
outputs a 3x3 homography matrix that transforms one image to align with the other.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels: int = 3,
|
||||||
|
hidden_channels: int = 64,
|
||||||
|
num_blocks: int = 4,
|
||||||
|
dropout_rate: float = 0.3,
|
||||||
|
use_batch_norm: bool = True,
|
||||||
|
output_size: int = 9, # Flattened 3x3 homography matrix
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the HomographyCNN model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_channels: Number of input channels per image (3 for RGB)
|
||||||
|
hidden_channels: Base number of channels in the network
|
||||||
|
num_blocks: Number of convolutional blocks
|
||||||
|
dropout_rate: Dropout rate for regularization
|
||||||
|
use_batch_norm: Whether to use batch normalization
|
||||||
|
output_size: Size of output vector (9 for flattened 3x3 matrix)
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_channels = input_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.use_batch_norm = use_batch_norm
|
||||||
|
|
||||||
|
# Feature extraction for each image separately
|
||||||
|
self.google_encoder = self._build_encoder()
|
||||||
|
self.yandex_encoder = self._build_encoder()
|
||||||
|
|
||||||
|
# Fusion layers to combine features from both images
|
||||||
|
self.fusion_layers = self._build_fusion_layers()
|
||||||
|
|
||||||
|
# Regression head for homography estimation
|
||||||
|
self.regression_head = self._build_regression_head(output_size)
|
||||||
|
|
||||||
|
# Initialize weights
|
||||||
|
self._initialize_weights()
|
||||||
|
|
||||||
|
def _build_encoder(self) -> nn.Module:
|
||||||
|
"""Build the encoder network for a single image."""
|
||||||
|
layers = []
|
||||||
|
in_channels = self.input_channels
|
||||||
|
out_channels = self.hidden_channels
|
||||||
|
|
||||||
|
# First convolutional block
|
||||||
|
layers.append(
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)
|
||||||
|
)
|
||||||
|
if self.use_batch_norm:
|
||||||
|
layers.append(nn.BatchNorm2d(out_channels))
|
||||||
|
layers.append(nn.ReLU(inplace=True))
|
||||||
|
layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||||
|
|
||||||
|
# Additional convolutional blocks
|
||||||
|
for i in range(self.num_blocks):
|
||||||
|
block_in_channels = out_channels
|
||||||
|
block_out_channels = out_channels * 2 if i < 2 else out_channels
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
ResidualBlock(
|
||||||
|
in_channels=block_in_channels,
|
||||||
|
out_channels=block_out_channels,
|
||||||
|
stride=1 if i == 0 else 2,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
use_batch_norm=self.use_batch_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if i < 2:
|
||||||
|
out_channels = block_out_channels
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _build_fusion_layers(self) -> nn.Module:
|
||||||
|
"""Build layers to fuse features from both images."""
|
||||||
|
# After encoding, each image has hidden_channels * 4 features
|
||||||
|
fused_channels = (
|
||||||
|
self.hidden_channels * 8
|
||||||
|
) # Concatenated features from both images
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
# Reduce dimensionality
|
||||||
|
nn.Conv2d(
|
||||||
|
fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(self.hidden_channels * 4)
|
||||||
|
if self.use_batch_norm
|
||||||
|
else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout2d(self.dropout_rate),
|
||||||
|
# Further processing
|
||||||
|
nn.Conv2d(
|
||||||
|
self.hidden_channels * 4,
|
||||||
|
self.hidden_channels * 2,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(self.hidden_channels * 2)
|
||||||
|
if self.use_batch_norm
|
||||||
|
else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout2d(self.dropout_rate),
|
||||||
|
# Global pooling
|
||||||
|
nn.AdaptiveAvgPool2d((1, 1)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _build_regression_head(self, output_size: int) -> nn.Module:
|
||||||
|
"""Build the regression head for homography estimation."""
|
||||||
|
# Input size after fusion and global pooling
|
||||||
|
input_features = self.hidden_channels * 2
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(input_features, 512),
|
||||||
|
nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(self.dropout_rate),
|
||||||
|
nn.Linear(512, 256),
|
||||||
|
nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(self.dropout_rate),
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(self.dropout_rate),
|
||||||
|
nn.Linear(128, output_size),
|
||||||
|
]
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _initialize_weights(self):
|
||||||
|
"""Initialize model weights."""
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
nn.init.normal_(m.weight, 0, 0.01)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
google_img: torch.Tensor,
|
||||||
|
yandex_img: torch.Tensor,
|
||||||
|
return_matrix: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image tensor of shape (B, C, H, W)
|
||||||
|
yandex_img: Yandex map image tensor of shape (B, C, H, W)
|
||||||
|
return_matrix: If True, return 3x3 matrix; if False, return flattened vector
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Homography matrix tensor of shape (B, 3, 3) or flattened vector of shape (B, 9)
|
||||||
|
"""
|
||||||
|
# Extract features from both images
|
||||||
|
google_features = self.google_encoder(google_img)
|
||||||
|
yandex_features = self.yandex_encoder(yandex_img)
|
||||||
|
|
||||||
|
# Concatenate features along channel dimension
|
||||||
|
combined_features = torch.cat([google_features, yandex_features], dim=1)
|
||||||
|
|
||||||
|
# Fuse features
|
||||||
|
fused_features = self.fusion_layers(combined_features)
|
||||||
|
|
||||||
|
# Regression to get homography parameters
|
||||||
|
homography_flat = self.regression_head(fused_features)
|
||||||
|
|
||||||
|
if return_matrix:
|
||||||
|
# Reshape to 3x3 matrix
|
||||||
|
batch_size = homography_flat.shape[0]
|
||||||
|
homography_matrix = homography_flat.view(batch_size, 3, 3)
|
||||||
|
|
||||||
|
# Ensure the last element is 1 (homogeneous coordinate normalization)
|
||||||
|
# Add small epsilon to prevent division by zero
|
||||||
|
epsilon = 1e-8
|
||||||
|
homography_matrix = homography_matrix / (
|
||||||
|
homography_matrix[:, 2, 2].view(-1, 1, 1) + epsilon
|
||||||
|
)
|
||||||
|
|
||||||
|
return homography_matrix
|
||||||
|
else:
|
||||||
|
return homography_flat
|
||||||
|
|
||||||
|
def predict_homography(
|
||||||
|
self,
|
||||||
|
google_img: torch.Tensor,
|
||||||
|
yandex_img: torch.Tensor,
|
||||||
|
normalize: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Predict homography matrix with optional normalization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image tensor
|
||||||
|
yandex_img: Yandex map image tensor
|
||||||
|
normalize: Whether to normalize the homography matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Predicted homography matrix
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
homography = self.forward(google_img, yandex_img, return_matrix=True)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
# Normalize so that last element is 1
|
||||||
|
homography = homography / homography[:, 2, 2].view(-1, 1, 1)
|
||||||
|
|
||||||
|
return homography
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
"""Residual block with optional downsampling."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
stride: int = 1,
|
||||||
|
dropout_rate: float = 0.3,
|
||||||
|
use_batch_norm: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
padding=1,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
self.dropout1 = (
|
||||||
|
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
|
||||||
|
self.relu2 = nn.ReLU(inplace=True)
|
||||||
|
self.dropout2 = (
|
||||||
|
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shortcut connection
|
||||||
|
self.shortcut = nn.Sequential()
|
||||||
|
if stride != 1 or in_channels != out_channels:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
identity = self.shortcut(x)
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu1(out)
|
||||||
|
out = self.dropout1(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu2(out)
|
||||||
|
out = self.dropout2(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class HomographyLoss(nn.Module):
|
||||||
|
"""
|
||||||
|
Custom loss function for homography estimation.
|
||||||
|
|
||||||
|
Combines multiple loss terms:
|
||||||
|
1. Matrix element-wise L2 loss
|
||||||
|
2. Geometric consistency loss (warping error)
|
||||||
|
3. Determinant regularization (to prevent degenerate matrices)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
matrix_weight: float = 1.0,
|
||||||
|
geometric_weight: float = 0.5,
|
||||||
|
reg_weight: float = 0.1,
|
||||||
|
grid_size: int = 8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.matrix_weight = matrix_weight
|
||||||
|
self.geometric_weight = geometric_weight
|
||||||
|
self.reg_weight = reg_weight
|
||||||
|
self.grid_size = grid_size
|
||||||
|
|
||||||
|
# Create grid of points for geometric loss
|
||||||
|
self.register_buffer(
|
||||||
|
"grid_points",
|
||||||
|
self._create_grid_points(grid_size),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_grid_points(self, grid_size: int) -> torch.Tensor:
|
||||||
|
"""Create a grid of points for geometric consistency loss."""
|
||||||
|
x = torch.linspace(-1, 1, grid_size)
|
||||||
|
y = torch.linspace(-1, 1, grid_size)
|
||||||
|
grid_y, grid_x = torch.meshgrid(y, x, indexing="ij")
|
||||||
|
grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)
|
||||||
|
# Add homogeneous coordinate
|
||||||
|
ones = torch.ones(grid_points.shape[0], 1)
|
||||||
|
grid_points = torch.cat([grid_points, ones], dim=1)
|
||||||
|
return grid_points.T # Shape: (3, grid_size*grid_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pred_homography: torch.Tensor,
|
||||||
|
target_homography: torch.Tensor,
|
||||||
|
google_img: Optional[torch.Tensor] = None,
|
||||||
|
yandex_img: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute homography loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred_homography: Predicted homography matrices (B, 3, 3)
|
||||||
|
target_homography: Target homography matrices (B, 3, 3)
|
||||||
|
google_img: Google images (optional, for geometric loss)
|
||||||
|
yandex_img: Yandex images (optional, for geometric loss)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined loss value
|
||||||
|
"""
|
||||||
|
batch_size = pred_homography.shape[0]
|
||||||
|
|
||||||
|
# 1. Matrix element-wise L2 loss
|
||||||
|
matrix_loss = F.mse_loss(pred_homography, target_homography)
|
||||||
|
|
||||||
|
# 2. Geometric consistency loss (if images provided)
|
||||||
|
geometric_loss = torch.tensor(0.0, device=pred_homography.device)
|
||||||
|
if google_img is not None and yandex_img is not None:
|
||||||
|
# Warp grid points with predicted homography
|
||||||
|
grid_points = self.grid_points.unsqueeze(0).expand(batch_size, -1, -1)
|
||||||
|
warped_points = torch.bmm(pred_homography, grid_points)
|
||||||
|
|
||||||
|
# Normalize homogeneous coordinates
|
||||||
|
warped_points = warped_points / (warped_points[:, 2:3, :] + 1e-8)
|
||||||
|
|
||||||
|
# Warp with target homography for comparison
|
||||||
|
target_warped_points = torch.bmm(target_homography, grid_points)
|
||||||
|
target_warped_points = target_warped_points / (
|
||||||
|
target_warped_points[:, 2:3, :] + 1e-8
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute point-wise distance
|
||||||
|
geometric_loss = F.mse_loss(
|
||||||
|
warped_points[:, :2, :], target_warped_points[:, :2, :]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Regularization loss (prevent degenerate matrices)
|
||||||
|
# Encourage determinant to be close to 1
|
||||||
|
pred_det = torch.det(pred_homography)
|
||||||
|
reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))
|
||||||
|
|
||||||
|
# Combine losses
|
||||||
|
total_loss = (
|
||||||
|
self.matrix_weight * matrix_loss
|
||||||
|
+ self.geometric_weight * geometric_loss
|
||||||
|
+ self.reg_weight * reg_loss
|
||||||
|
)
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
def compute_metrics(
|
||||||
|
self,
|
||||||
|
pred_homography: torch.Tensor,
|
||||||
|
target_homography: torch.Tensor,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Compute evaluation metrics for homography estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred_homography: Predicted homography matrices
|
||||||
|
target_homography: Target homography matrices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of metrics
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
# Normalize matrices
|
||||||
|
pred_norm = pred_homography / pred_homography[:, 2, 2].view(-1, 1, 1)
|
||||||
|
target_norm = target_homography / target_homography[:, 2, 2].view(-1, 1, 1)
|
||||||
|
|
||||||
|
# Matrix L2 error
|
||||||
|
matrix_error = F.mse_loss(pred_norm, target_norm, reduction="none").mean(
|
||||||
|
dim=(1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Corner error (warp 4 corners of the image)
|
||||||
|
corners = torch.tensor(
|
||||||
|
[[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=pred_homography.device,
|
||||||
|
).T # Shape: (3, 4)
|
||||||
|
|
||||||
|
corners = corners.unsqueeze(0).expand(pred_homography.shape[0], -1, -1)
|
||||||
|
|
||||||
|
pred_corners = torch.bmm(pred_norm, corners)
|
||||||
|
pred_corners = pred_corners / (pred_corners[:, 2:3, :] + 1e-8)
|
||||||
|
|
||||||
|
target_corners = torch.bmm(target_norm, corners)
|
||||||
|
target_corners = target_corners / (target_corners[:, 2:3, :] + 1e-8)
|
||||||
|
|
||||||
|
corner_error = torch.mean(
|
||||||
|
torch.norm(pred_corners[:, :2, :] - target_corners[:, :2, :], dim=1),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Average corner error in pixels (assuming image coordinates in [-1, 1])
|
||||||
|
# Convert to pixel error if image size is known
|
||||||
|
avg_corner_error = corner_error.mean().item()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"matrix_mse": matrix_error.mean().item(),
|
||||||
|
"corner_error": avg_corner_error,
|
||||||
|
"corner_error_px": avg_corner_error * 128, # Assuming 256x256 images
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_homography_model(
|
||||||
|
model_type: str = "cnn",
|
||||||
|
input_size: Tuple[int, int] = (256, 256),
|
||||||
|
**kwargs,
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Factory function to create homography estimation model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: Type of model to create ('cnn' or 'resnet')
|
||||||
|
input_size: Input image size (height, width)
|
||||||
|
**kwargs: Additional arguments passed to model constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Homography estimation model
|
||||||
|
"""
|
||||||
|
if model_type == "cnn":
|
||||||
|
return HomographyCNN(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Test the model
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
model = HomographyCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create dummy input
|
||||||
|
batch_size = 4
|
||||||
|
height, width = 256, 256
|
||||||
|
|
||||||
|
google_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
yandex_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
|
||||||
|
# Test forward pass
|
||||||
|
print("\nTesting forward pass...")
|
||||||
|
output = model(google_img, yandex_img, return_matrix=True)
|
||||||
|
print(f"Output shape: {output.shape}") # Should be (4, 3, 3)
|
||||||
|
print(f"Sample output:\n{output[0]}")
|
||||||
|
|
||||||
|
# Test prediction
|
||||||
|
print("\nTesting prediction...")
|
||||||
|
pred = model.predict_homography(google_img, yandex_img)
|
||||||
|
print(f"Prediction shape: {pred.shape}")
|
||||||
|
print(f"Last element (should be ~1): {pred[0, 2, 2]:.6f}")
|
||||||
|
|
||||||
|
# Test loss function
|
||||||
|
print("\nTesting loss function...")
|
||||||
|
target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
|
||||||
|
loss_fn = HomographyLoss(
|
||||||
|
matrix_weight=1.0,
|
||||||
|
geometric_weight=0.5,
|
||||||
|
reg_weight=0.1,
|
||||||
|
grid_size=8,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
loss = loss_fn(output, target_homography, google_img, yandex_img)
|
||||||
|
print(f"Loss value: {loss.item():.6f}")
|
||||||
|
|
||||||
|
# Test metrics
|
||||||
|
print("\nTesting metrics...")
|
||||||
|
metrics = loss_fn.compute_metrics(output, target_homography)
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f"{key}: {value:.6f}")
|
||||||
|
|
||||||
|
# Test model factory
|
||||||
|
print("\nTesting model factory...")
|
||||||
|
model2 = create_homography_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=(256, 256),
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=32,
|
||||||
|
num_blocks=3,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\nAll tests completed successfully!")
|
||||||
553
models/SiaN/infer_homography.py
Normal file
553
models/SiaN/infer_homography.py
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
"""
|
||||||
|
Inference script for homography estimation between Google and Yandex map images.
|
||||||
|
|
||||||
|
This script loads a trained homography estimation model and performs inference
|
||||||
|
on new image pairs or the test dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from homography import HomographyDataset
|
||||||
|
from homography_cnn import HomographyCNN, create_homography_model
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
class HomographyInference:
|
||||||
|
"""Class for performing inference with homography estimation model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
config_path: Optional[str] = None,
|
||||||
|
device: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the inference class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to trained model checkpoint
|
||||||
|
config_path: Path to model configuration file (optional)
|
||||||
|
device: Device to run inference on ('cuda' or 'cpu')
|
||||||
|
"""
|
||||||
|
# Set device
|
||||||
|
if device is None:
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
else:
|
||||||
|
self.device = torch.device(device)
|
||||||
|
|
||||||
|
print(f"Using device: {self.device}")
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
if config_path is None:
|
||||||
|
# Try to find config in the same directory as model
|
||||||
|
model_dir = Path(model_path).parent
|
||||||
|
config_path = model_dir / "config.json"
|
||||||
|
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
self.config = json.load(f)
|
||||||
|
print(f"Loaded configuration from {config_path}")
|
||||||
|
else:
|
||||||
|
# Use default configuration
|
||||||
|
self.config = {
|
||||||
|
"image_size": [256, 256],
|
||||||
|
"hidden_channels": 64,
|
||||||
|
"num_blocks": 4,
|
||||||
|
"dropout_rate": 0.3,
|
||||||
|
"use_batch_norm": True,
|
||||||
|
}
|
||||||
|
print("Using default configuration")
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
self.model = self._create_model()
|
||||||
|
self._load_model(model_path)
|
||||||
|
|
||||||
|
# Set up transforms
|
||||||
|
self.transform = self._create_transforms()
|
||||||
|
|
||||||
|
# Set model to evaluation mode
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def _create_model(self) -> HomographyCNN:
|
||||||
|
"""Create model based on configuration."""
|
||||||
|
image_size = self.config.get("image_size", [256, 256])
|
||||||
|
|
||||||
|
model = create_homography_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=tuple(image_size),
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=self.config.get("hidden_channels", 64),
|
||||||
|
num_blocks=self.config.get("num_blocks", 4),
|
||||||
|
dropout_rate=self.config.get("dropout_rate", 0.3),
|
||||||
|
use_batch_norm=self.config.get("use_batch_norm", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _load_model(self, model_path: str):
|
||||||
|
"""Load model weights from checkpoint."""
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||||
|
|
||||||
|
# Load checkpoint
|
||||||
|
checkpoint = torch.load(model_path, map_location=self.device)
|
||||||
|
|
||||||
|
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||||
|
# Trainer checkpoint format
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
else:
|
||||||
|
# Raw model weights format
|
||||||
|
self.model.load_state_dict(checkpoint)
|
||||||
|
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
print(f"Loaded model from {model_path}")
|
||||||
|
|
||||||
|
def _create_transforms(self):
|
||||||
|
"""Create image transforms for inference."""
|
||||||
|
return transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(tuple(self.config.get("image_size", [256, 256]))),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess_images(
|
||||||
|
self, google_img: Image.Image, yandex_img: Image.Image
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Preprocess images for inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image (PIL Image)
|
||||||
|
yandex_img: Yandex map image (PIL Image)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of preprocessed image tensors
|
||||||
|
"""
|
||||||
|
# Convert to RGB if needed
|
||||||
|
if google_img.mode != "RGB":
|
||||||
|
google_img = google_img.convert("RGB")
|
||||||
|
if yandex_img.mode != "RGB":
|
||||||
|
yandex_img = yandex_img.convert("RGB")
|
||||||
|
|
||||||
|
# Apply transforms
|
||||||
|
google_tensor = self.transform(google_img).unsqueeze(0) # Add batch dimension
|
||||||
|
yandex_tensor = self.transform(yandex_img).unsqueeze(0)
|
||||||
|
|
||||||
|
return google_tensor, yandex_tensor
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
google_img: Image.Image,
|
||||||
|
yandex_img: Image.Image,
|
||||||
|
return_matrix: bool = True,
|
||||||
|
normalize: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Predict homography matrix for image pair.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image (PIL Image)
|
||||||
|
yandex_img: Yandex map image (PIL Image)
|
||||||
|
return_matrix: If True, return 3x3 matrix; if False, return flattened vector
|
||||||
|
normalize: Whether to normalize the homography matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Predicted homography matrix or vector
|
||||||
|
"""
|
||||||
|
# Preprocess images
|
||||||
|
google_tensor, yandex_tensor = self.preprocess_images(google_img, yandex_img)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
google_tensor = google_tensor.to(self.device)
|
||||||
|
yandex_tensor = yandex_tensor.to(self.device)
|
||||||
|
|
||||||
|
# Perform inference
|
||||||
|
with torch.no_grad():
|
||||||
|
homography = self.model(
|
||||||
|
google_tensor, yandex_tensor, return_matrix=return_matrix
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_matrix and normalize:
|
||||||
|
# Normalize so that last element is 1
|
||||||
|
homography = homography / homography[:, 2, 2].view(-1, 1, 1)
|
||||||
|
|
||||||
|
return homography.squeeze(0) # Remove batch dimension
|
||||||
|
|
||||||
|
def predict_from_paths(
|
||||||
|
self,
|
||||||
|
google_path: str,
|
||||||
|
yandex_path: str,
|
||||||
|
return_matrix: bool = True,
|
||||||
|
normalize: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Predict homography matrix from image file paths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_path: Path to Google map image
|
||||||
|
yandex_path: Path to Yandex map image
|
||||||
|
return_matrix: If True, return 3x3 matrix; if False, return flattened vector
|
||||||
|
normalize: Whether to normalize the homography matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Predicted homography matrix or vector
|
||||||
|
"""
|
||||||
|
# Load images
|
||||||
|
google_img = Image.open(google_path)
|
||||||
|
yandex_img = Image.open(yandex_path)
|
||||||
|
|
||||||
|
return self.predict(google_img, yandex_img, return_matrix, normalize)
|
||||||
|
|
||||||
|
def warp_image(
|
||||||
|
self,
|
||||||
|
img: Image.Image,
|
||||||
|
homography: np.ndarray,
|
||||||
|
output_size: Optional[Tuple[int, int]] = None,
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Warp image using homography matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: Input image (PIL Image)
|
||||||
|
homography: 3x3 homography matrix (numpy array)
|
||||||
|
output_size: Output image size (width, height). If None, uses input size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Warped image (PIL Image)
|
||||||
|
"""
|
||||||
|
# Convert to numpy array
|
||||||
|
img_np = np.array(img)
|
||||||
|
|
||||||
|
# Get output size
|
||||||
|
if output_size is None:
|
||||||
|
output_size = (img_np.shape[1], img_np.shape[0])
|
||||||
|
|
||||||
|
# Apply homography transformation
|
||||||
|
warped_np = cv2.warpPerspective(
|
||||||
|
img_np,
|
||||||
|
homography,
|
||||||
|
output_size,
|
||||||
|
flags=cv2.INTER_LINEAR,
|
||||||
|
borderMode=cv2.BORDER_REFLECT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert back to PIL Image
|
||||||
|
return Image.fromarray(warped_np)
|
||||||
|
|
||||||
|
def visualize_alignment(
|
||||||
|
self,
|
||||||
|
google_img: Image.Image,
|
||||||
|
yandex_img: Image.Image,
|
||||||
|
homography: np.ndarray,
|
||||||
|
save_path: Optional[str] = None,
|
||||||
|
show: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Visualize alignment between images using homography.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image
|
||||||
|
yandex_img: Yandex map image
|
||||||
|
homography: Homography matrix
|
||||||
|
save_path: Path to save visualization (optional)
|
||||||
|
show: Whether to display the visualization
|
||||||
|
"""
|
||||||
|
# Warp yandex image to align with google
|
||||||
|
yandex_warped = self.warp_image(yandex_img, homography)
|
||||||
|
|
||||||
|
# Convert images to numpy arrays for visualization
|
||||||
|
google_np = np.array(google_img)
|
||||||
|
yandex_np = np.array(yandex_img)
|
||||||
|
yandex_warped_np = np.array(yandex_warped)
|
||||||
|
|
||||||
|
# Create visualization
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||||
|
|
||||||
|
# Original images
|
||||||
|
axes[0, 0].imshow(google_np)
|
||||||
|
axes[0, 0].set_title("Google Map (Original)")
|
||||||
|
axes[0, 0].axis("off")
|
||||||
|
|
||||||
|
axes[0, 1].imshow(yandex_np)
|
||||||
|
axes[0, 1].set_title("Yandex Map (Original)")
|
||||||
|
axes[0, 1].axis("off")
|
||||||
|
|
||||||
|
# Warped image
|
||||||
|
axes[1, 0].imshow(yandex_warped_np)
|
||||||
|
axes[1, 0].set_title("Yandex Map (Warped)")
|
||||||
|
axes[1, 0].axis("off")
|
||||||
|
|
||||||
|
# Overlay (50% transparency)
|
||||||
|
overlay = cv2.addWeighted(google_np, 0.5, yandex_warped_np, 0.5, 0)
|
||||||
|
axes[1, 1].imshow(overlay)
|
||||||
|
axes[1, 1].set_title("Overlay (Google + Warped Yandex)")
|
||||||
|
axes[1, 1].axis("off")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
||||||
|
print(f"Visualization saved to {save_path}")
|
||||||
|
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
else:
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def evaluate_on_dataset(
|
||||||
|
self,
|
||||||
|
dataset_dir: str,
|
||||||
|
num_samples: Optional[int] = None,
|
||||||
|
save_dir: Optional[str] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Evaluate model on a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_dir: Directory containing image pairs
|
||||||
|
num_samples: Number of samples to evaluate (None for all)
|
||||||
|
save_dir: Directory to save visualizations (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of evaluation metrics
|
||||||
|
"""
|
||||||
|
# Create dataset
|
||||||
|
dataset = HomographyDataset(
|
||||||
|
root_dir=dataset_dir,
|
||||||
|
transform=None, # We'll handle transforms manually
|
||||||
|
augment=False,
|
||||||
|
image_size=tuple(self.config.get("image_size", [256, 256])),
|
||||||
|
cache_homographies=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_samples is not None:
|
||||||
|
indices = list(range(min(num_samples, len(dataset))))
|
||||||
|
else:
|
||||||
|
indices = list(range(len(dataset)))
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
corner_errors = []
|
||||||
|
|
||||||
|
print(f"Evaluating on {len(indices)} samples...")
|
||||||
|
|
||||||
|
for idx in indices:
|
||||||
|
# Get sample without augmentation
|
||||||
|
sample = dataset.get_sample_without_augmentation(idx)
|
||||||
|
|
||||||
|
google_img = sample["google_img"]
|
||||||
|
yandex_img = sample["yandex_img"]
|
||||||
|
target_homography = sample["homography"]
|
||||||
|
|
||||||
|
# Predict homography
|
||||||
|
pred_homography = self.predict(
|
||||||
|
google_img, yandex_img, return_matrix=True, normalize=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to numpy
|
||||||
|
pred_homography_np = pred_homography.cpu().numpy()
|
||||||
|
target_homography_np = target_homography
|
||||||
|
|
||||||
|
# Compute matrix error
|
||||||
|
matrix_error = np.mean((pred_homography_np - target_homography_np) ** 2)
|
||||||
|
errors.append(matrix_error)
|
||||||
|
|
||||||
|
# Compute corner error
|
||||||
|
corners = np.array(
|
||||||
|
[
|
||||||
|
[-1, -1, 1],
|
||||||
|
[1, -1, 1],
|
||||||
|
[1, 1, 1],
|
||||||
|
[-1, 1, 1],
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
).T
|
||||||
|
|
||||||
|
pred_corners = pred_homography_np @ corners
|
||||||
|
pred_corners = pred_corners / (pred_corners[2:3, :] + 1e-8)
|
||||||
|
|
||||||
|
target_corners = target_homography_np @ corners
|
||||||
|
target_corners = target_corners / (target_corners[2:3, :] + 1e-8)
|
||||||
|
|
||||||
|
corner_error = np.mean(
|
||||||
|
np.linalg.norm(pred_corners[:2, :] - target_corners[:2, :], axis=0)
|
||||||
|
)
|
||||||
|
corner_errors.append(corner_error)
|
||||||
|
|
||||||
|
# Save visualization if requested
|
||||||
|
if save_dir:
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
vis_path = os.path.join(save_dir, f"sample_{idx:04d}.png")
|
||||||
|
self.visualize_alignment(
|
||||||
|
google_img,
|
||||||
|
yandex_img,
|
||||||
|
pred_homography_np,
|
||||||
|
save_path=vis_path,
|
||||||
|
show=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = {
|
||||||
|
"mean_matrix_error": float(np.mean(errors)),
|
||||||
|
"std_matrix_error": float(np.std(errors)),
|
||||||
|
"mean_corner_error": float(np.mean(corner_errors)),
|
||||||
|
"std_corner_error": float(np.std(corner_errors)),
|
||||||
|
"median_corner_error": float(np.median(corner_errors)),
|
||||||
|
"max_corner_error": float(np.max(corner_errors)),
|
||||||
|
"min_corner_error": float(np.min(corner_errors)),
|
||||||
|
}
|
||||||
|
|
||||||
|
print("\nEvaluation Results:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f" {key}: {value:.6f}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main inference function."""
|
||||||
|
parser = argparse.ArgumentParser(description="Inference for homography estimation")
|
||||||
|
|
||||||
|
# Model arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to trained model checkpoint",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_path",
|
||||||
|
type=str,
|
||||||
|
help="Path to model configuration file (optional)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
choices=["cuda", "cpu"],
|
||||||
|
help="Device to run inference on",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inference mode
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
type=str,
|
||||||
|
default="single",
|
||||||
|
choices=["single", "dataset", "batch"],
|
||||||
|
help="Inference mode",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Single image mode
|
||||||
|
parser.add_argument(
|
||||||
|
"--google_path",
|
||||||
|
type=str,
|
||||||
|
help="Path to Google map image (single mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--yandex_path",
|
||||||
|
type=str,
|
||||||
|
help="Path to Yandex map image (single mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_vis",
|
||||||
|
type=str,
|
||||||
|
help="Path to save visualization (single mode)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset mode
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_dir",
|
||||||
|
type=str,
|
||||||
|
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
help="Directory containing image pairs (dataset mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_samples",
|
||||||
|
type=int,
|
||||||
|
help="Number of samples to evaluate (dataset mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_vis_dir",
|
||||||
|
type=str,
|
||||||
|
help="Directory to save visualizations (dataset mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_results",
|
||||||
|
type=str,
|
||||||
|
help="Path to save evaluation results (dataset mode)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create inference object
|
||||||
|
inference = HomographyInference(
|
||||||
|
model_path=args.model_path,
|
||||||
|
config_path=args.config_path,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.mode == "single":
|
||||||
|
# Single image pair inference
|
||||||
|
if not args.google_path or not args.yandex_path:
|
||||||
|
raise ValueError(
|
||||||
|
"Both --google_path and --yandex_path are required for single mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Processing single image pair:")
|
||||||
|
print(f" Google: {args.google_path}")
|
||||||
|
print(f" Yandex: {args.yandex_path}")
|
||||||
|
|
||||||
|
# Predict homography
|
||||||
|
homography = inference.predict_from_paths(args.google_path, args.yandex_path)
|
||||||
|
|
||||||
|
print(f"\nPredicted homography matrix:")
|
||||||
|
print(homography.cpu().numpy())
|
||||||
|
|
||||||
|
# Visualize alignment
|
||||||
|
if args.output_vis:
|
||||||
|
google_img = Image.open(args.google_path)
|
||||||
|
yandex_img = Image.open(args.yandex_path)
|
||||||
|
inference.visualize_alignment(
|
||||||
|
google_img,
|
||||||
|
yandex_img,
|
||||||
|
homography.cpu().numpy(),
|
||||||
|
save_path=args.output_vis,
|
||||||
|
show=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif args.mode == "dataset":
|
||||||
|
# Evaluate on dataset
|
||||||
|
metrics = inference.evaluate_on_dataset(
|
||||||
|
dataset_dir=args.dataset_dir,
|
||||||
|
num_samples=args.num_samples,
|
||||||
|
save_dir=args.save_vis_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.save_results:
|
||||||
|
with open(args.save_results, "w") as f:
|
||||||
|
json.dump(metrics, f, indent=2)
|
||||||
|
print(f"\nResults saved to {args.save_results}")
|
||||||
|
|
||||||
|
elif args.mode == "batch":
|
||||||
|
# Batch processing (placeholder for future implementation)
|
||||||
|
print("Batch mode not yet implemented")
|
||||||
|
# Could implement processing multiple image pairs from a directory
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown mode: {args.mode}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
611
models/SiaN/train_homography.py
Normal file
611
models/SiaN/train_homography.py
Normal file
@@ -0,0 +1,611 @@
|
|||||||
|
"""
|
||||||
|
Training script for homography estimation between Google and Yandex map images.
|
||||||
|
|
||||||
|
This script trains a CNN model to estimate homography matrices that align
|
||||||
|
Google map images with Yandex map images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from homography import HomographyDataset, create_data_loaders
|
||||||
|
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class HomographyTrainer:
|
||||||
|
"""Trainer class for homography estimation model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
config: Dict,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the trainer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Homography estimation model
|
||||||
|
train_loader: Training data loader
|
||||||
|
val_loader: Validation data loader
|
||||||
|
device: Device to run training on
|
||||||
|
config: Training configuration dictionary
|
||||||
|
"""
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.device = device
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Loss function
|
||||||
|
self.criterion = HomographyLoss(
|
||||||
|
matrix_weight=config.get("matrix_weight", 1.0),
|
||||||
|
geometric_weight=config.get("geometric_weight", 0.5),
|
||||||
|
reg_weight=config.get("reg_weight", 0.1),
|
||||||
|
grid_size=config.get("grid_size", 8),
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
optimizer_name = config.get("optimizer", "adam").lower()
|
||||||
|
lr = config.get("learning_rate", 1e-3)
|
||||||
|
weight_decay = config.get("weight_decay", 1e-4)
|
||||||
|
|
||||||
|
if optimizer_name == "adam":
|
||||||
|
self.optimizer = optim.Adam(
|
||||||
|
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
elif optimizer_name == "adamw":
|
||||||
|
self.optimizer = optim.AdamW(
|
||||||
|
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
elif optimizer_name == "sgd":
|
||||||
|
self.optimizer = optim.SGD(
|
||||||
|
self.model.parameters(),
|
||||||
|
lr=lr,
|
||||||
|
momentum=config.get("momentum", 0.9),
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown optimizer: {optimizer_name}")
|
||||||
|
|
||||||
|
# Learning rate scheduler
|
||||||
|
scheduler_name = config.get("scheduler", "plateau").lower()
|
||||||
|
if scheduler_name == "plateau":
|
||||||
|
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
self.optimizer,
|
||||||
|
mode="min",
|
||||||
|
factor=config.get("scheduler_factor", 0.5),
|
||||||
|
patience=config.get("scheduler_patience", 5),
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
elif scheduler_name == "cosine":
|
||||||
|
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||||
|
self.optimizer,
|
||||||
|
T_max=config.get("epochs", 100),
|
||||||
|
eta_min=config.get("min_lr", 1e-6),
|
||||||
|
)
|
||||||
|
elif scheduler_name == "step":
|
||||||
|
self.scheduler = optim.lr_scheduler.StepLR(
|
||||||
|
self.optimizer,
|
||||||
|
step_size=config.get("step_size", 30),
|
||||||
|
gamma=config.get("gamma", 0.1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.scheduler = None
|
||||||
|
|
||||||
|
# Training state
|
||||||
|
self.current_epoch = 0
|
||||||
|
self.best_val_loss = float("inf")
|
||||||
|
self.train_losses: List[float] = []
|
||||||
|
self.val_losses: List[float] = []
|
||||||
|
self.val_metrics: List[Dict] = []
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
self.output_dir = Path(config.get("output_dir", "runs/homography"))
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# TensorBoard writer
|
||||||
|
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||||
|
|
||||||
|
# Save configuration
|
||||||
|
config_path = self.output_dir / "config.json"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Training configuration saved to {config_path}")
|
||||||
|
print(
|
||||||
|
f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_epoch(self) -> float:
|
||||||
|
"""
|
||||||
|
Train for one epoch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Average training loss for the epoch
|
||||||
|
"""
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
num_batches = len(self.train_loader)
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||||
|
for batch_idx, batch in enumerate(progress_bar):
|
||||||
|
# Move data to device
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target_homography = batch["homography"].to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||||
|
|
||||||
|
# Compute loss
|
||||||
|
loss = self.criterion(
|
||||||
|
pred_homography,
|
||||||
|
target_homography,
|
||||||
|
google_img,
|
||||||
|
yandex_img,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backward pass
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Gradient clipping
|
||||||
|
if self.config.get("grad_clip", 1.0) > 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(
|
||||||
|
self.model.parameters(),
|
||||||
|
self.config.get("grad_clip", 1.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer step
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix({"loss": loss.item()})
|
||||||
|
|
||||||
|
# Log batch loss to TensorBoard
|
||||||
|
global_step = self.current_epoch * num_batches + batch_idx
|
||||||
|
self.writer.add_scalar("train/batch_loss", loss.item(), global_step)
|
||||||
|
|
||||||
|
avg_loss = total_loss / num_batches
|
||||||
|
self.train_losses.append(avg_loss)
|
||||||
|
|
||||||
|
return avg_loss
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate(self) -> Tuple[float, Dict]:
|
||||||
|
"""
|
||||||
|
Validate the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (average validation loss, validation metrics)
|
||||||
|
"""
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||||
|
for batch in progress_bar:
|
||||||
|
# Move data to device
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target_homography = batch["homography"].to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||||
|
|
||||||
|
# Compute loss
|
||||||
|
loss = self.criterion(
|
||||||
|
pred_homography,
|
||||||
|
target_homography,
|
||||||
|
google_img,
|
||||||
|
yandex_img,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = self.criterion.compute_metrics(pred_homography, target_homography)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_loss += loss.item()
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix({"loss": loss.item()})
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(self.val_loader)
|
||||||
|
self.val_losses.append(avg_loss)
|
||||||
|
|
||||||
|
# Aggregate metrics
|
||||||
|
avg_metrics = {}
|
||||||
|
for key in all_metrics[0].keys():
|
||||||
|
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||||
|
|
||||||
|
self.val_metrics.append(avg_metrics)
|
||||||
|
|
||||||
|
return avg_loss, avg_metrics
|
||||||
|
|
||||||
|
def save_checkpoint(self, is_best: bool = False):
|
||||||
|
"""Save model checkpoint."""
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": self.current_epoch,
|
||||||
|
"model_state_dict": self.model.state_dict(),
|
||||||
|
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||||
|
"train_losses": self.train_losses,
|
||||||
|
"val_losses": self.val_losses,
|
||||||
|
"val_metrics": self.val_metrics,
|
||||||
|
"best_val_loss": self.best_val_loss,
|
||||||
|
"config": self.config,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.scheduler is not None:
|
||||||
|
checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
|
||||||
|
|
||||||
|
# Save latest checkpoint
|
||||||
|
checkpoint_path = self.output_dir / "checkpoint_latest.pth"
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
# Save best checkpoint
|
||||||
|
if is_best:
|
||||||
|
best_path = self.output_dir / "checkpoint_best.pth"
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
print(f"Best model saved to {best_path}")
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str):
|
||||||
|
"""Load model checkpoint."""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
|
||||||
|
if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
|
||||||
|
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||||
|
|
||||||
|
self.current_epoch = checkpoint["epoch"]
|
||||||
|
self.train_losses = checkpoint["train_losses"]
|
||||||
|
self.val_losses = checkpoint["val_losses"]
|
||||||
|
self.val_metrics = checkpoint["val_metrics"]
|
||||||
|
self.best_val_loss = checkpoint["best_val_loss"]
|
||||||
|
|
||||||
|
print(f"Loaded checkpoint from epoch {self.current_epoch}")
|
||||||
|
|
||||||
|
def train(self, num_epochs: int):
|
||||||
|
"""
|
||||||
|
Train the model for specified number of epochs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_epochs: Number of epochs to train
|
||||||
|
"""
|
||||||
|
print(f"Starting training for {num_epochs} epochs...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
self.current_epoch = epoch
|
||||||
|
|
||||||
|
# Train for one epoch
|
||||||
|
train_loss = self.train_epoch()
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
val_loss, val_metrics = self.validate()
|
||||||
|
|
||||||
|
# Update learning rate scheduler
|
||||||
|
if self.scheduler is not None:
|
||||||
|
if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
|
self.scheduler.step(val_loss)
|
||||||
|
else:
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
|
||||||
|
self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
|
||||||
|
for metric_name, metric_value in val_metrics.items():
|
||||||
|
self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)
|
||||||
|
|
||||||
|
# Print epoch summary
|
||||||
|
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||||
|
print(f" Train Loss: {train_loss:.6f}")
|
||||||
|
print(f" Val Loss: {val_loss:.6f}")
|
||||||
|
print(" Val Metrics:")
|
||||||
|
for metric_name, metric_value in val_metrics.items():
|
||||||
|
print(f" {metric_name}: {metric_value:.6f}")
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
is_best = val_loss < self.best_val_loss
|
||||||
|
if is_best:
|
||||||
|
self.best_val_loss = val_loss
|
||||||
|
|
||||||
|
self.save_checkpoint(is_best=is_best)
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if self.config.get("early_stopping_patience", 0) > 0:
|
||||||
|
if (
|
||||||
|
epoch - np.argmin(self.val_losses)
|
||||||
|
>= self.config["early_stopping_patience"]
|
||||||
|
):
|
||||||
|
print(f"Early stopping at epoch {epoch + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Training completed
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||||
|
print(f"Best validation loss: {self.best_val_loss:.6f}")
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
final_model_path = self.output_dir / "model_final.pth"
|
||||||
|
torch.save(self.model.state_dict(), final_model_path)
|
||||||
|
print(f"Final model saved to {final_model_path}")
|
||||||
|
|
||||||
|
# Close TensorBoard writer
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
|
||||||
|
"""
|
||||||
|
Evaluate the model on test data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_loader: Test data loader (uses validation loader if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of evaluation metrics
|
||||||
|
"""
|
||||||
|
if test_loader is None:
|
||||||
|
test_loader = self.val_loader
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
print("Evaluating model...")
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(test_loader):
|
||||||
|
# Move data to device
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target_homography = batch["homography"].to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = self.criterion.compute_metrics(
|
||||||
|
pred_homography, target_homography
|
||||||
|
)
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
|
||||||
|
# Aggregate metrics
|
||||||
|
avg_metrics = {}
|
||||||
|
for key in all_metrics[0].keys():
|
||||||
|
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||||
|
|
||||||
|
# Print evaluation results
|
||||||
|
print("\nEvaluation Results:")
|
||||||
|
for metric_name, metric_value in avg_metrics.items():
|
||||||
|
print(f" {metric_name}: {metric_value:.6f}")
|
||||||
|
|
||||||
|
# Save evaluation results
|
||||||
|
eval_path = self.output_dir / "evaluation_results.json"
|
||||||
|
with open(eval_path, "w") as f:
|
||||||
|
json.dump(avg_metrics, f, indent=2)
|
||||||
|
print(f"Evaluation results saved to {eval_path}")
|
||||||
|
|
||||||
|
return avg_metrics
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""Parse command line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(description="Train homography estimation model")
|
||||||
|
|
||||||
|
# Data arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_dir",
|
||||||
|
type=str,
|
||||||
|
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
help="Directory containing image pairs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=32, help="Batch size for training"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_size",
|
||||||
|
type=int,
|
||||||
|
nargs=2,
|
||||||
|
default=[256, 256],
|
||||||
|
help="Image size (height width)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_split", type=float, default=0.8, help="Train/validation split ratio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_workers", type=int, default=4, help="Number of data loader workers"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden_channels", type=int, default=64, help="Number of hidden channels"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_blocks", type=int, default=4, help="Number of convolutional blocks"
|
||||||
|
)
|
||||||
|
parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_batch_norm", action="store_true", help="Use batch normalization"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training arguments
|
||||||
|
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
|
||||||
|
parser.add_argument(
|
||||||
|
"--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"]
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scheduler",
|
||||||
|
type=str,
|
||||||
|
default="plateau",
|
||||||
|
choices=["plateau", "cosine", "step", "none"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad_clip", type=float, default=1.0, help="Gradient clipping value"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loss arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--geometric_weight",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="Weight for geometric loss",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
type=str,
|
||||||
|
default="runs/homography",
|
||||||
|
help="Output directory for checkpoints and logs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume",
|
||||||
|
type=str,
|
||||||
|
help="Path to checkpoint to resume training from",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval_only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only evaluate the model (no training)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", type=int, default=42, help="Random seed for reproducibility"
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main training function."""
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Set random seeds for reproducibility
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(args.seed)
|
||||||
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
print("Creating data loaders...")
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=args.data_dir,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
train_split=args.train_split,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
image_size=tuple(args.image_size),
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
print("Creating model...")
|
||||||
|
model = create_homography_model(
|
||||||
|
model_type=args.model_type,
|
||||||
|
input_size=tuple(args.image_size),
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=args.hidden_channels,
|
||||||
|
num_blocks=args.num_blocks,
|
||||||
|
dropout_rate=args.dropout_rate,
|
||||||
|
use_batch_norm=args.use_batch_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create trainer configuration
|
||||||
|
config = {
|
||||||
|
# Model config
|
||||||
|
"model_type": args.model_type,
|
||||||
|
"hidden_channels": args.hidden_channels,
|
||||||
|
"num_blocks": args.num_blocks,
|
||||||
|
"dropout_rate": args.dropout_rate,
|
||||||
|
"use_batch_norm": args.use_batch_norm,
|
||||||
|
"image_size": args.image_size,
|
||||||
|
# Training config
|
||||||
|
"epochs": args.epochs,
|
||||||
|
"batch_size": args.batch_size,
|
||||||
|
"learning_rate": args.lr,
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
"optimizer": args.optimizer,
|
||||||
|
"scheduler": args.scheduler,
|
||||||
|
"grad_clip": args.grad_clip,
|
||||||
|
# Loss config
|
||||||
|
"matrix_weight": args.matrix_weight,
|
||||||
|
"geometric_weight": args.geometric_weight,
|
||||||
|
"reg_weight": args.reg_weight,
|
||||||
|
"grid_size": 8,
|
||||||
|
# Data config
|
||||||
|
"data_dir": args.data_dir,
|
||||||
|
"train_split": args.train_split,
|
||||||
|
"num_workers": args.num_workers,
|
||||||
|
# Output config
|
||||||
|
"output_dir": args.output_dir,
|
||||||
|
"seed": args.seed,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create trainer
|
||||||
|
trainer = HomographyTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resume from checkpoint if specified
|
||||||
|
if args.resume:
|
||||||
|
print(f"Resuming from checkpoint: {args.resume}")
|
||||||
|
trainer.load_checkpoint(args.resume)
|
||||||
|
|
||||||
|
# Evaluate only mode
|
||||||
|
if args.eval_only:
|
||||||
|
trainer.evaluate()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train(num_epochs=args.epochs)
|
||||||
|
|
||||||
|
# Final evaluation
|
||||||
|
print("\nPerforming final evaluation...")
|
||||||
|
trainer.evaluate()
|
||||||
|
|
||||||
|
print("\nTraining completed successfully!")
|
||||||
|
print(f"Results saved to: {args.output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
611
models/SiaN/train_homography_.py
Normal file
611
models/SiaN/train_homography_.py
Normal file
@@ -0,0 +1,611 @@
|
|||||||
|
"""
|
||||||
|
Training script for homography estimation between Google and Yandex map images.
|
||||||
|
|
||||||
|
This script trains a CNN model to estimate homography matrices that align
|
||||||
|
Google map images with Yandex map images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from homography import HomographyDataset, create_data_loaders
|
||||||
|
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class HomographyTrainer:
|
||||||
|
"""Trainer class for homography estimation model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
config: Dict,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the trainer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Homography estimation model
|
||||||
|
train_loader: Training data loader
|
||||||
|
val_loader: Validation data loader
|
||||||
|
device: Device to run training on
|
||||||
|
config: Training configuration dictionary
|
||||||
|
"""
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.device = device
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Loss function
|
||||||
|
self.criterion = HomographyLoss(
|
||||||
|
matrix_weight=config.get("matrix_weight", 1.0),
|
||||||
|
geometric_weight=config.get("geometric_weight", 0.5),
|
||||||
|
reg_weight=config.get("reg_weight", 0.1),
|
||||||
|
grid_size=config.get("grid_size", 8),
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
optimizer_name = config.get("optimizer", "adam").lower()
|
||||||
|
lr = config.get("learning_rate", 1e-3)
|
||||||
|
weight_decay = config.get("weight_decay", 1e-4)
|
||||||
|
|
||||||
|
if optimizer_name == "adam":
|
||||||
|
self.optimizer = optim.Adam(
|
||||||
|
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
elif optimizer_name == "adamw":
|
||||||
|
self.optimizer = optim.AdamW(
|
||||||
|
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
elif optimizer_name == "sgd":
|
||||||
|
self.optimizer = optim.SGD(
|
||||||
|
self.model.parameters(),
|
||||||
|
lr=lr,
|
||||||
|
momentum=config.get("momentum", 0.9),
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown optimizer: {optimizer_name}")
|
||||||
|
|
||||||
|
# Learning rate scheduler
|
||||||
|
scheduler_name = config.get("scheduler", "plateau").lower()
|
||||||
|
if scheduler_name == "plateau":
|
||||||
|
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
self.optimizer,
|
||||||
|
mode="min",
|
||||||
|
factor=config.get("scheduler_factor", 0.5),
|
||||||
|
patience=config.get("scheduler_patience", 5),
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
elif scheduler_name == "cosine":
|
||||||
|
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||||
|
self.optimizer,
|
||||||
|
T_max=config.get("epochs", 100),
|
||||||
|
eta_min=config.get("min_lr", 1e-6),
|
||||||
|
)
|
||||||
|
elif scheduler_name == "step":
|
||||||
|
self.scheduler = optim.lr_scheduler.StepLR(
|
||||||
|
self.optimizer,
|
||||||
|
step_size=config.get("step_size", 30),
|
||||||
|
gamma=config.get("gamma", 0.1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.scheduler = None
|
||||||
|
|
||||||
|
# Training state
|
||||||
|
self.current_epoch = 0
|
||||||
|
self.best_val_loss = float("inf")
|
||||||
|
self.train_losses: List[float] = []
|
||||||
|
self.val_losses: List[float] = []
|
||||||
|
self.val_metrics: List[Dict] = []
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
self.output_dir = Path(config.get("output_dir", "runs/homography"))
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# TensorBoard writer
|
||||||
|
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||||
|
|
||||||
|
# Save configuration
|
||||||
|
config_path = self.output_dir / "config.json"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Training configuration saved to {config_path}")
|
||||||
|
print(
|
||||||
|
f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_epoch(self) -> float:
|
||||||
|
"""
|
||||||
|
Train for one epoch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Average training loss for the epoch
|
||||||
|
"""
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
num_batches = len(self.train_loader)
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||||
|
for batch_idx, batch in enumerate(progress_bar):
|
||||||
|
# Move data to device
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target_homography = batch["homography"].to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||||
|
|
||||||
|
# Compute loss
|
||||||
|
loss = self.criterion(
|
||||||
|
pred_homography,
|
||||||
|
target_homography,
|
||||||
|
google_img,
|
||||||
|
yandex_img,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backward pass
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Gradient clipping
|
||||||
|
if self.config.get("grad_clip", 1.0) > 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(
|
||||||
|
self.model.parameters(),
|
||||||
|
self.config.get("grad_clip", 1.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer step
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix({"loss": loss.item()})
|
||||||
|
|
||||||
|
# Log batch loss to TensorBoard
|
||||||
|
global_step = self.current_epoch * num_batches + batch_idx
|
||||||
|
self.writer.add_scalar("train/batch_loss", loss.item(), global_step)
|
||||||
|
|
||||||
|
avg_loss = total_loss / num_batches
|
||||||
|
self.train_losses.append(avg_loss)
|
||||||
|
|
||||||
|
return avg_loss
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate(self) -> Tuple[float, Dict]:
|
||||||
|
"""
|
||||||
|
Validate the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (average validation loss, validation metrics)
|
||||||
|
"""
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||||
|
for batch in progress_bar:
|
||||||
|
# Move data to device
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target_homography = batch["homography"].to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||||
|
|
||||||
|
# Compute loss
|
||||||
|
loss = self.criterion(
|
||||||
|
pred_homography,
|
||||||
|
target_homography,
|
||||||
|
google_img,
|
||||||
|
yandex_img,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = self.criterion.compute_metrics(pred_homography, target_homography)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
total_loss += loss.item()
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
|
||||||
|
# Update progress bar
|
||||||
|
progress_bar.set_postfix({"loss": loss.item()})
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(self.val_loader)
|
||||||
|
self.val_losses.append(avg_loss)
|
||||||
|
|
||||||
|
# Aggregate metrics
|
||||||
|
avg_metrics = {}
|
||||||
|
for key in all_metrics[0].keys():
|
||||||
|
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||||
|
|
||||||
|
self.val_metrics.append(avg_metrics)
|
||||||
|
|
||||||
|
return avg_loss, avg_metrics
|
||||||
|
|
||||||
|
def save_checkpoint(self, is_best: bool = False):
|
||||||
|
"""Save model checkpoint."""
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": self.current_epoch,
|
||||||
|
"model_state_dict": self.model.state_dict(),
|
||||||
|
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||||
|
"train_losses": self.train_losses,
|
||||||
|
"val_losses": self.val_losses,
|
||||||
|
"val_metrics": self.val_metrics,
|
||||||
|
"best_val_loss": self.best_val_loss,
|
||||||
|
"config": self.config,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.scheduler is not None:
|
||||||
|
checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
|
||||||
|
|
||||||
|
# Save latest checkpoint
|
||||||
|
checkpoint_path = self.output_dir / "checkpoint_latest.pth"
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
# Save best checkpoint
|
||||||
|
if is_best:
|
||||||
|
best_path = self.output_dir / "checkpoint_best.pth"
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
print(f"Best model saved to {best_path}")
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str):
|
||||||
|
"""Load model checkpoint."""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
|
||||||
|
if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
|
||||||
|
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||||
|
|
||||||
|
self.current_epoch = checkpoint["epoch"]
|
||||||
|
self.train_losses = checkpoint["train_losses"]
|
||||||
|
self.val_losses = checkpoint["val_losses"]
|
||||||
|
self.val_metrics = checkpoint["val_metrics"]
|
||||||
|
self.best_val_loss = checkpoint["best_val_loss"]
|
||||||
|
|
||||||
|
print(f"Loaded checkpoint from epoch {self.current_epoch}")
|
||||||
|
|
||||||
|
def train(self, num_epochs: int):
|
||||||
|
"""
|
||||||
|
Train the model for specified number of epochs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_epochs: Number of epochs to train
|
||||||
|
"""
|
||||||
|
print(f"Starting training for {num_epochs} epochs...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
self.current_epoch = epoch
|
||||||
|
|
||||||
|
# Train for one epoch
|
||||||
|
train_loss = self.train_epoch()
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
val_loss, val_metrics = self.validate()
|
||||||
|
|
||||||
|
# Update learning rate scheduler
|
||||||
|
if self.scheduler is not None:
|
||||||
|
if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
|
self.scheduler.step(val_loss)
|
||||||
|
else:
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
|
||||||
|
self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
|
||||||
|
for metric_name, metric_value in val_metrics.items():
|
||||||
|
self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)
|
||||||
|
|
||||||
|
# Print epoch summary
|
||||||
|
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||||
|
print(f" Train Loss: {train_loss:.6f}")
|
||||||
|
print(f" Val Loss: {val_loss:.6f}")
|
||||||
|
print(" Val Metrics:")
|
||||||
|
for metric_name, metric_value in val_metrics.items():
|
||||||
|
print(f" {metric_name}: {metric_value:.6f}")
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
is_best = val_loss < self.best_val_loss
|
||||||
|
if is_best:
|
||||||
|
self.best_val_loss = val_loss
|
||||||
|
|
||||||
|
self.save_checkpoint(is_best=is_best)
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if self.config.get("early_stopping_patience", 0) > 0:
|
||||||
|
if (
|
||||||
|
epoch - np.argmin(self.val_losses)
|
||||||
|
>= self.config["early_stopping_patience"]
|
||||||
|
):
|
||||||
|
print(f"Early stopping at epoch {epoch + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Training completed
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||||
|
print(f"Best validation loss: {self.best_val_loss:.6f}")
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
final_model_path = self.output_dir / "model_final.pth"
|
||||||
|
torch.save(self.model.state_dict(), final_model_path)
|
||||||
|
print(f"Final model saved to {final_model_path}")
|
||||||
|
|
||||||
|
# Close TensorBoard writer
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
|
||||||
|
"""
|
||||||
|
Evaluate the model on test data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_loader: Test data loader (uses validation loader if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of evaluation metrics
|
||||||
|
"""
|
||||||
|
if test_loader is None:
|
||||||
|
test_loader = self.val_loader
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
print("Evaluating model...")
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(test_loader):
|
||||||
|
# Move data to device
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target_homography = batch["homography"].to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = self.criterion.compute_metrics(
|
||||||
|
pred_homography, target_homography
|
||||||
|
)
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
|
||||||
|
# Aggregate metrics
|
||||||
|
avg_metrics = {}
|
||||||
|
for key in all_metrics[0].keys():
|
||||||
|
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||||
|
|
||||||
|
# Print evaluation results
|
||||||
|
print("\nEvaluation Results:")
|
||||||
|
for metric_name, metric_value in avg_metrics.items():
|
||||||
|
print(f" {metric_name}: {metric_value:.6f}")
|
||||||
|
|
||||||
|
# Save evaluation results
|
||||||
|
eval_path = self.output_dir / "evaluation_results.json"
|
||||||
|
with open(eval_path, "w") as f:
|
||||||
|
json.dump(avg_metrics, f, indent=2)
|
||||||
|
print(f"Evaluation results saved to {eval_path}")
|
||||||
|
|
||||||
|
return avg_metrics
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""Parse command line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(description="Train homography estimation model")
|
||||||
|
|
||||||
|
# Data arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_dir",
|
||||||
|
type=str,
|
||||||
|
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
help="Directory containing image pairs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=32, help="Batch size for training"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_size",
|
||||||
|
type=int,
|
||||||
|
nargs=2,
|
||||||
|
default=[256, 256],
|
||||||
|
help="Image size (height width)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_split", type=float, default=0.8, help="Train/validation split ratio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_workers", type=int, default=4, help="Number of data loader workers"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden_channels", type=int, default=64, help="Number of hidden channels"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_blocks", type=int, default=4, help="Number of convolutional blocks"
|
||||||
|
)
|
||||||
|
parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_batch_norm", action="store_true", help="Use batch normalization"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training arguments
|
||||||
|
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
|
||||||
|
parser.add_argument(
|
||||||
|
"--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"]
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scheduler",
|
||||||
|
type=str,
|
||||||
|
default="plateau",
|
||||||
|
choices=["plateau", "cosine", "step", "none"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad_clip", type=float, default=1.0, help="Gradient clipping value"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loss arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--geometric_weight",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="Weight for geometric loss",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
type=str,
|
||||||
|
default="runs/homography",
|
||||||
|
help="Output directory for checkpoints and logs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume",
|
||||||
|
type=str,
|
||||||
|
help="Path to checkpoint to resume training from",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval_only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only evaluate the model (no training)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", type=int, default=42, help="Random seed for reproducibility"
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main training function."""
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Set random seeds for reproducibility
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(args.seed)
|
||||||
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
print("Creating data loaders...")
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=args.data_dir,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
train_split=args.train_split,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
image_size=tuple(args.image_size),
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
print("Creating model...")
|
||||||
|
model = create_homography_model(
|
||||||
|
model_type=args.model_type,
|
||||||
|
input_size=tuple(args.image_size),
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=args.hidden_channels,
|
||||||
|
num_blocks=args.num_blocks,
|
||||||
|
dropout_rate=args.dropout_rate,
|
||||||
|
use_batch_norm=args.use_batch_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create trainer configuration
|
||||||
|
config = {
|
||||||
|
# Model config
|
||||||
|
"model_type": args.model_type,
|
||||||
|
"hidden_channels": args.hidden_channels,
|
||||||
|
"num_blocks": args.num_blocks,
|
||||||
|
"dropout_rate": args.dropout_rate,
|
||||||
|
"use_batch_norm": args.use_batch_norm,
|
||||||
|
"image_size": args.image_size,
|
||||||
|
# Training config
|
||||||
|
"epochs": args.epochs,
|
||||||
|
"batch_size": args.batch_size,
|
||||||
|
"learning_rate": args.lr,
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
"optimizer": args.optimizer,
|
||||||
|
"scheduler": args.scheduler,
|
||||||
|
"grad_clip": args.grad_clip,
|
||||||
|
# Loss config
|
||||||
|
"matrix_weight": args.matrix_weight,
|
||||||
|
"geometric_weight": args.geometric_weight,
|
||||||
|
"reg_weight": args.reg_weight,
|
||||||
|
"grid_size": 8,
|
||||||
|
# Data config
|
||||||
|
"data_dir": args.data_dir,
|
||||||
|
"train_split": args.train_split,
|
||||||
|
"num_workers": args.num_workers,
|
||||||
|
# Output config
|
||||||
|
"output_dir": args.output_dir,
|
||||||
|
"seed": args.seed,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create trainer
|
||||||
|
trainer = HomographyTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resume from checkpoint if specified
|
||||||
|
if args.resume:
|
||||||
|
print(f"Resuming from checkpoint: {args.resume}")
|
||||||
|
trainer.load_checkpoint(args.resume)
|
||||||
|
|
||||||
|
# Evaluate only mode
|
||||||
|
if args.eval_only:
|
||||||
|
trainer.evaluate()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train(num_epochs=args.epochs)
|
||||||
|
|
||||||
|
# Final evaluation
|
||||||
|
print("\nPerforming final evaluation...")
|
||||||
|
trainer.evaluate()
|
||||||
|
|
||||||
|
print("\nTraining completed successfully!")
|
||||||
|
print(f"Results saved to: {args.output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
54
position.py
54
position.py
@@ -42,22 +42,45 @@ class Position:
|
|||||||
f"roll={math.degrees(self.roll):.1f}°)"
|
f"roll={math.degrees(self.roll):.1f}°)"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_homography_matrix(self, K: np.ndarray = constants.K, sliding: bool = True) -> np.ndarray:
|
def __imul__(self, scalar: float):
|
||||||
|
self.x *= scalar
|
||||||
|
self.y *= scalar
|
||||||
|
self.z *= scalar
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __mul__(self, scalar: float) -> 'Position':
|
||||||
|
pos = self.copy()
|
||||||
|
pos *= scalar
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def __itruediv__(self, scalar: float):
|
||||||
|
self.x /= scalar
|
||||||
|
self.y /= scalar
|
||||||
|
self.z /= scalar
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __truediv__(self, scalar: float) -> 'Position':
|
||||||
|
pos = self.copy()
|
||||||
|
pos /= scalar
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def get_homography_matrix(self, K_in: np.ndarray = constants.K, K_out: np.ndarray | None = None, sliding: bool = True) -> np.ndarray:
|
||||||
""" Возвращает матрицу гомографии """
|
""" Возвращает матрицу гомографии """
|
||||||
R = self.get_rotation_matrix()
|
R = self.get_rotation_matrix()
|
||||||
T = self.get_translation_matrix()
|
T = self.get_translation_matrix(K_in)
|
||||||
if not sliding:
|
if not sliding:
|
||||||
T[0, 2] = T[1, 2] = 0
|
T[0, 2] = T[1, 2] = 0
|
||||||
return K @ R @ T @ np.linalg.inv(K)
|
if K_out is None: K_out = K_in
|
||||||
|
return K_out @ R @ T @ np.linalg.inv(K_in)
|
||||||
|
|
||||||
def copy(self) -> 'Position':
|
def copy(self) -> 'Position':
|
||||||
"""Создает полную копию объекта"""
|
"""Создает полную копию объекта"""
|
||||||
return Position(self.x, self.y, self.z, self.yaw, self.pitch, self.roll)
|
return Position(self.x, self.y, self.z, self.yaw, self.pitch, self.roll)
|
||||||
|
|
||||||
def get_translation_matrix(self) -> np.ndarray:
|
def get_translation_matrix(self, K: np.ndarray = constants.K) -> np.ndarray:
|
||||||
return np.array([
|
return np.array([
|
||||||
[1, 0, self.x / constants._K_FOCUS_DISTANCE],
|
[1, 0, self.x / K[0][0]],
|
||||||
[0, 1, self.y / constants._K_FOCUS_DISTANCE],
|
[0, 1, self.y / K[0][0]],
|
||||||
[0, 0, self.z]
|
[0, 0, self.z]
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -101,6 +124,13 @@ class Position:
|
|||||||
R = np.array(R)
|
R = np.array(R)
|
||||||
t = np.array(t)
|
t = np.array(t)
|
||||||
|
|
||||||
|
# print(cv2.decomposeHomographyMat(inv(H), K))
|
||||||
|
# _, _, z, _ = cv2.decomposeHomographyMat(inv(H), K)
|
||||||
|
# print(z)
|
||||||
|
# z = z.copy()
|
||||||
|
# z *= constants._K_FOCUS_DISTANCE
|
||||||
|
# print(z)
|
||||||
|
|
||||||
T = inv(R) @ inv(K) @ H @ K
|
T = inv(R) @ inv(K) @ H @ K
|
||||||
ind = np.array([A[2][0] ** 2 + A[2][1] ** 2 for A in T])
|
ind = np.array([A[2][0] ** 2 + A[2][1] ** 2 for A in T])
|
||||||
top_k = max(1, len(T) // 2)
|
top_k = max(1, len(T) // 2)
|
||||||
@@ -116,14 +146,16 @@ class Position:
|
|||||||
|
|
||||||
R = R[best_id]
|
R = R[best_id]
|
||||||
rot = Rotation.from_matrix(R).as_euler('XYZ').flatten()
|
rot = Rotation.from_matrix(R).as_euler('XYZ').flatten()
|
||||||
self.roll = rot[0]
|
self.roll = min(np.radians(5), max(np.radians(-5), rot[0]))
|
||||||
self.pitch = rot[1]
|
self.pitch = min(np.radians(5), max(np.radians(-5), rot[1]))
|
||||||
self.yaw = rot[2]
|
self.yaw = rot[2]
|
||||||
|
|
||||||
t = t[best_id].flatten()
|
t = t[best_id].flatten()
|
||||||
self.x += -T[0] * constants._K_FOCUS_DISTANCE * self.z
|
self.x -= T[0] * constants._K_FOCUS_DISTANCE
|
||||||
self.y += T[1] * constants._K_FOCUS_DISTANCE * self.z
|
self.y += T[1] * constants._K_FOCUS_DISTANCE
|
||||||
self.z = 1 + T[2]
|
self.z = max(0.7, min(1.3, 1 + T[2]))
|
||||||
|
T[0] *= constants._K_FOCUS_DISTANCE
|
||||||
|
T[1] *= constants._K_FOCUS_DISTANCE
|
||||||
|
|
||||||
def apply(self, homography_matrix: np.ndarray, K = constants.K) -> 'Position':
|
def apply(self, homography_matrix: np.ndarray, K = constants.K) -> 'Position':
|
||||||
"""Применяет матрицу трансформации для вычисления новой позиции и ориентации."""
|
"""Применяет матрицу трансформации для вычисления новой позиции и ориентации."""
|
||||||
|
|||||||
40
simulator.py
40
simulator.py
@@ -8,12 +8,13 @@ import numpy as np
|
|||||||
from position import Position
|
from position import Position
|
||||||
from vision_chunk import VisionChunk
|
from vision_chunk import VisionChunk
|
||||||
from yandex_map import YandexMap
|
from yandex_map import YandexMap
|
||||||
|
from google_map import GoogleMap
|
||||||
import constants
|
import constants
|
||||||
import utility
|
import utility
|
||||||
|
|
||||||
class Simulator:
|
class Simulator:
|
||||||
def __init__(self, yandex_map: YandexMap = None):
|
def __init__(self, online_map: YandexMap | GoogleMap = None):
|
||||||
self.yandex_map = yandex_map
|
self.online_map = online_map
|
||||||
# Используем новый конструктор с yaw, pitch, roll
|
# Используем новый конструктор с yaw, pitch, roll
|
||||||
self.pos = Position(x=0, y=0, z=1, yaw=0, pitch=0, roll=0)
|
self.pos = Position(x=0, y=0, z=1, yaw=0, pitch=0, roll=0)
|
||||||
|
|
||||||
@@ -35,24 +36,26 @@ class Simulator:
|
|||||||
Возвращает квадратное изображение 700x700.
|
Возвращает квадратное изображение 700x700.
|
||||||
"""
|
"""
|
||||||
img_array = np.array(image)
|
img_array = np.array(image)
|
||||||
print(img_array.shape)
|
|
||||||
h, w, _ = img_array.shape
|
h, w, _ = img_array.shape
|
||||||
|
|
||||||
# Применяем трансформацию
|
# Применяем трансформацию
|
||||||
pos = self.pos.copy()
|
pos = self.pos.copy()
|
||||||
pos.x = 0
|
pos.x = 0
|
||||||
pos.y = 0
|
pos.y = 0
|
||||||
K = utility.calc_camera_matrix(w, h)
|
|
||||||
K = constants.K
|
K_in = utility.calc_camera_matrix(w, h)
|
||||||
img_array = img_array[:constants.CHUNK_WIDTH, :constants.CHUNK_WIDTH]
|
K_out = constants.K
|
||||||
transformed = cv2.warpPerspective(img_array, pos.get_homography_matrix(K), (constants.CHUNK_WIDTH, constants.CHUNK_WIDTH))
|
H = pos.get_homography_matrix(K_in, K_out)
|
||||||
|
|
||||||
|
shape = (constants.CHUNK_WIDTH, constants.CHUNK_WIDTH)
|
||||||
|
transformed = cv2.warpPerspective(img_array, H, shape)
|
||||||
|
|
||||||
return Image.fromarray(transformed)
|
return Image.fromarray(transformed)
|
||||||
|
|
||||||
def update_trajectory(self, dx: float, dy: float):
|
def update_trajectory(self, dx: float, dy: float):
|
||||||
"""Обновляет координаты дрона"""
|
"""Обновляет координаты дрона"""
|
||||||
self.pos.x += dx * self.pos.z
|
self.pos.x += dx
|
||||||
self.pos.y += dy * self.pos.z
|
self.pos.y += dy
|
||||||
|
|
||||||
def handle(self, dangle: float, velocity: float = 50) -> None:
|
def handle(self, dangle: float, velocity: float = 50) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -60,27 +63,17 @@ class Simulator:
|
|||||||
dangle - изменение угла курса (радианы)
|
dangle - изменение угла курса (радианы)
|
||||||
velocity - скорость движения
|
velocity - скорость движения
|
||||||
"""
|
"""
|
||||||
from selenium.webdriver.common.by import By
|
|
||||||
from selenium.webdriver.common.action_chains import ActionChains
|
|
||||||
|
|
||||||
html = self.yandex_map.driver.find_element(By.TAG_NAME, 'html')
|
|
||||||
action = ActionChains(self.yandex_map.driver)
|
|
||||||
action.move_to_element_with_offset(html, 200, 200)
|
|
||||||
action.click_and_hold()
|
|
||||||
|
|
||||||
# Обновляем yaw в объекте Position
|
# Обновляем yaw в объекте Position
|
||||||
self.pos.yaw += dangle
|
self.pos.yaw += dangle
|
||||||
velocity = max(velocity, 10)
|
velocity = max(velocity, 10)
|
||||||
|
|
||||||
# Вычисляем смещение на основе текущего yaw
|
# Вычисляем смещение на основе текущего yaw
|
||||||
dx = math.cos(math.pi / 2 + self.pos.yaw) * velocity / self.pos.z
|
dx = int(math.cos(math.pi / 2 + self.pos.yaw) * velocity)
|
||||||
dy = math.sin(math.pi / 2 + self.pos.yaw) * velocity / self.pos.z
|
dy = int(math.sin(math.pi / 2 + self.pos.yaw) * velocity)
|
||||||
|
|
||||||
self.update_trajectory(dx, dy)
|
self.update_trajectory(dx, dy)
|
||||||
|
self.online_map.move(dx, dy)
|
||||||
action.move_by_offset(-dx, dy)
|
|
||||||
action.release()
|
|
||||||
action.perform()
|
|
||||||
|
|
||||||
def set_zoom(self, zoom_level: float):
|
def set_zoom(self, zoom_level: float):
|
||||||
"""Программное изменение масштаба"""
|
"""Программное изменение масштаба"""
|
||||||
@@ -88,8 +81,7 @@ class Simulator:
|
|||||||
|
|
||||||
def get_chunk(self) -> VisionChunk:
|
def get_chunk(self) -> VisionChunk:
|
||||||
"""Получить текущий снимок с камеры дрона"""
|
"""Получить текущий снимок с камеры дрона"""
|
||||||
png = self.yandex_map.driver.get_screenshot_as_png()
|
im = self.online_map.make_screenshot()
|
||||||
im = Image.open(BytesIO(png))
|
|
||||||
|
|
||||||
# Применяем перспективную трансформацию
|
# Применяем перспективную трансформацию
|
||||||
transformed_im = self._apply_perspective_transform(im)
|
transformed_im = self._apply_perspective_transform(im)
|
||||||
|
|||||||
189
test_chunks.ipynb
Normal file
189
test_chunks.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
6
timer.py
6
timer.py
@@ -29,3 +29,9 @@ class Timer:
|
|||||||
self.elapsed = 0.
|
self.elapsed = 0.
|
||||||
self.enabled = False
|
self.enabled = False
|
||||||
self.last_enabled = 0.
|
self.last_enabled = 0.
|
||||||
|
|
||||||
|
def loop(self) -> float:
|
||||||
|
v = self.get_diff()
|
||||||
|
self.stop()
|
||||||
|
self.start()
|
||||||
|
return v
|
||||||
24
todo.md
24
todo.md
@@ -2,16 +2,16 @@
|
|||||||
[!] Проверка корректности выявления ориентира на кадре
|
[!] Проверка корректности выявления ориентира на кадре
|
||||||
[!] Исправление коррекции координат на основе сопоставления с ориентиром
|
[!] Исправление коррекции координат на основе сопоставления с ориентиром
|
||||||
|
|
||||||
[-] FPS счетчик
|
[+] FPS счетчик
|
||||||
| [-] Оптимизация детекции точек
|
| [+] Оптимизация детекции точек
|
||||||
[-] Оформление статистики при тестовых запусках
|
[+] Оформление статистики при тестовых запусках
|
||||||
[-] Проведение тестовых запусков
|
[+] Проведение тестовых запусков
|
||||||
[-] Оформление отчета
|
[+] Оформление отчета
|
||||||
[-] Эксперименты с разными детекторами (SIFT, KAZE)
|
[+] Эксперименты с разными детекторами (SIFT, KAZE)
|
||||||
|
|
||||||
[?] Изменение масштаба во время полёта, обработка этой трансформации
|
[+] Изменение масштаба во время полёта, обработка этой трансформации
|
||||||
[?] Поворот ориентиров
|
[+] Поворот ориентиров
|
||||||
[?] Ограничение выбора точек при построении маршрута, чтобы ориентиры полностью попадали в кадр
|
[+] Ограничение выбора точек при построении маршрута, чтобы ориентиры полностью попадали в кадр
|
||||||
|
|
||||||
[+] График межкадрового смещения
|
[+] График межкадрового смещения
|
||||||
| [+] График межкадровых смещениях по версии матрицы гомографии
|
| [+] График межкадровых смещениях по версии матрицы гомографии
|
||||||
@@ -19,6 +19,6 @@
|
|||||||
| [+] Исследовать причину погрешности при развороте
|
| [+] Исследовать причину погрешности при развороте
|
||||||
[+] Устранение большой погрешности при повороте
|
[+] Устранение большой погрешности при повороте
|
||||||
|
|
||||||
[ ] Переделать ключевые точки -> Optical Flow
|
[+] Переделать ключевые точки -> Optical Flow
|
||||||
[ ] Добавить перспективу
|
[+] Добавить перспективу
|
||||||
[ ] Эталоны на Google Maps, полёт тот же
|
[+] Эталоны на Google Maps, полёт тот же
|
||||||
|
|||||||
143
utility.py
143
utility.py
@@ -1,7 +1,11 @@
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from datetime import datetime
|
||||||
|
from urllib.parse import parse_qs, urlparse, unquote
|
||||||
|
|
||||||
|
import constants
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import constants
|
import re
|
||||||
|
|
||||||
def cv2_to_pil(cv_image: np.ndarray) -> Image.Image:
|
def cv2_to_pil(cv_image: np.ndarray) -> Image.Image:
|
||||||
"""
|
"""
|
||||||
@@ -19,11 +23,10 @@ def get_normals(H: np.ndarray, R: np.ndarray, T: np.ndarray) -> np.ndarray:
|
|||||||
n = cv2.decomposeHomographyMat(H, constants.K, R, T)
|
n = cv2.decomposeHomographyMat(H, constants.K, R, T)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def estimate_transformation_matrix(src_pts: np.ndarray, dst_pts: np.ndarray) -> tuple[np.ndarray, float | None]:
|
def estimate_transformation_matrix(src_pts: np.ndarray, dst_pts: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""Оценивает матрицу трансформации на основе сопоставленных точек"""
|
"""Оценивает матрицу трансформации на основе сопоставленных точек"""
|
||||||
# Используем RANSAC для оценки матрицы гомографии
|
# Используем RANSAC для оценки матрицы гомографии
|
||||||
H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 3.0, maxIters=1000)
|
return cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 3.0, maxIters=1000)
|
||||||
return H
|
|
||||||
|
|
||||||
def calc_camera_matrix(w: float, h: float):
|
def calc_camera_matrix(w: float, h: float):
|
||||||
f = constants._K_FOCUS_DISTANCE
|
f = constants._K_FOCUS_DISTANCE
|
||||||
@@ -32,3 +35,135 @@ def calc_camera_matrix(w: float, h: float):
|
|||||||
[0, f, h / 2],
|
[0, f, h / 2],
|
||||||
[0, 0, 1]
|
[0, 0, 1]
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def generate_folder_name():
|
||||||
|
"""
|
||||||
|
Генерирует название для папки с текущей датой и временем до секунд.
|
||||||
|
Формат: YYYY-MM-DD_HH-MM-SS
|
||||||
|
"""
|
||||||
|
now = datetime.now()
|
||||||
|
folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
return folder_name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def parse_yandex_maps_url(url):
|
||||||
|
"""
|
||||||
|
Парсит URL Яндекс.Карт и извлекает lat, lon и zoom
|
||||||
|
Формат: ?ll=lon,lat&z=zoom
|
||||||
|
"""
|
||||||
|
# Декодируем URL (на случай %2C вместо запятых)
|
||||||
|
url = unquote(url)
|
||||||
|
|
||||||
|
# Парсим URL
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
params = parse_qs(parsed_url.query)
|
||||||
|
|
||||||
|
if 'll' in params and 'z' in params:
|
||||||
|
# ll содержит "lon,lat"
|
||||||
|
ll_value = params['ll'][0]
|
||||||
|
lat, lon = map(float, ll_value.split(','))
|
||||||
|
zoom = int(params['z'][0])
|
||||||
|
|
||||||
|
return {
|
||||||
|
'lat': lat,
|
||||||
|
'lon': lon,
|
||||||
|
'zoom': zoom
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse_google_maps_url(url):
|
||||||
|
"""
|
||||||
|
Парсит URL Google Maps и извлекает lat, lon и zoom
|
||||||
|
Формат: /@lat,lon,zoom[m|z]
|
||||||
|
"""
|
||||||
|
pattern = r'/@([-\d.]+),([-\d.]+),(\d+)([mz])'
|
||||||
|
match = re.search(pattern, url)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
lon = float(match.group(1))
|
||||||
|
lat = float(match.group(2))
|
||||||
|
zoom_value = int(match.group(3))
|
||||||
|
zoom_unit = match.group(4)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'lat': lat,
|
||||||
|
'lon': lon,
|
||||||
|
'zoom': zoom_value,
|
||||||
|
'zoom_unit': zoom_unit
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def google_map_js_move_script(dx, dy) -> str:
|
||||||
|
return """
|
||||||
|
async function sleep(ms) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
setTimeout(() => resolve(), ms);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async function simulateDrag(element, offsetX, offsetY) {
|
||||||
|
const rect = element.getBoundingClientRect();
|
||||||
|
const startX = rect.left + rect.width / 2;
|
||||||
|
const startY = rect.top + rect.height / 2;
|
||||||
|
const step = 10
|
||||||
|
const endX = startX + offsetX + step;
|
||||||
|
const endY = startY + offsetY + step;
|
||||||
|
|
||||||
|
element.dispatchEvent(new MouseEvent('mousedown', {
|
||||||
|
view: window,
|
||||||
|
bubbles: true,
|
||||||
|
cancelable: true,
|
||||||
|
clientX: startX,
|
||||||
|
clientY: startY,
|
||||||
|
button: 0
|
||||||
|
}));
|
||||||
|
|
||||||
|
let currentX = startX;
|
||||||
|
let currentY = startY;
|
||||||
|
|
||||||
|
function move(stepX, stepY) {
|
||||||
|
currentX += stepX;
|
||||||
|
currentY += stepY;
|
||||||
|
|
||||||
|
document.dispatchEvent(new MouseEvent('mousemove', {
|
||||||
|
view: window,
|
||||||
|
bubbles: true,
|
||||||
|
cancelable: false,
|
||||||
|
clientX: currentX,
|
||||||
|
clientY: currentY
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
await sleep(50);
|
||||||
|
move(step, step)
|
||||||
|
|
||||||
|
while (currentX != endX || currentY != endY) {
|
||||||
|
await sleep(50);
|
||||||
|
const stepX = Math.min(step, Math.max(-step, endX - currentX));
|
||||||
|
const stepY = Math.min(step, Math.max(-step, endY - currentY));
|
||||||
|
move(stepX, stepY);
|
||||||
|
}
|
||||||
|
|
||||||
|
await sleep(50)
|
||||||
|
document.dispatchEvent(new MouseEvent('mouseup', {
|
||||||
|
view: window,
|
||||||
|
bubbles: true,
|
||||||
|
cancelable: false,
|
||||||
|
clientX: endX,
|
||||||
|
clientY: endY,
|
||||||
|
button: 0
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const canvas = document.querySelector('canvas.H1VXrf.JRr1M.DnOnV');
|
||||||
|
""" + f"simulateDrag(canvas, {int(-dx)}, {int(dy)});" + """
|
||||||
|
}
|
||||||
|
"""
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
|
import constants
|
||||||
import cv2
|
import cv2
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from position import Position
|
||||||
|
from timer import Timer
|
||||||
from typing import Literal, Optional, Tuple
|
from typing import Literal, Optional, Tuple
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
FeatureMethod = Literal["orb", "sift", "akaze", "brisk"]
|
FeatureMethod = Literal["orb", "sift", "akaze", "brisk"]
|
||||||
DEFAULT_METHOD = "orb"
|
DEFAULT_METHOD = "orb"
|
||||||
@@ -14,6 +17,7 @@ class VisionChunk:
|
|||||||
image: Image.Image
|
image: Image.Image
|
||||||
feature_method: FeatureMethod = DEFAULT_METHOD
|
feature_method: FeatureMethod = DEFAULT_METHOD
|
||||||
|
|
||||||
|
pos: Optional[Position] = field(default=None, init=False)
|
||||||
keypoints: Optional[list] = field(default=None, init=False)
|
keypoints: Optional[list] = field(default=None, init=False)
|
||||||
descriptors: Optional[np.ndarray] = field(default=None, init=False)
|
descriptors: Optional[np.ndarray] = field(default=None, init=False)
|
||||||
_detector: Optional[cv2.Feature2D] = field(default=None, init=False, repr=False)
|
_detector: Optional[cv2.Feature2D] = field(default=None, init=False, repr=False)
|
||||||
@@ -27,12 +31,12 @@ class VisionChunk:
|
|||||||
self._detector = cv2.ORB_create(
|
self._detector = cv2.ORB_create(
|
||||||
nfeatures=1000,
|
nfeatures=1000,
|
||||||
scaleFactor=1.2,
|
scaleFactor=1.2,
|
||||||
nlevels=32,
|
nlevels=16,
|
||||||
edgeThreshold=31,
|
edgeThreshold=31,
|
||||||
firstLevel=0,
|
firstLevel=0,
|
||||||
WTA_K=2,
|
WTA_K=2,
|
||||||
patchSize=31,
|
patchSize=31,
|
||||||
fastThreshold=20,
|
fastThreshold=10,
|
||||||
)
|
)
|
||||||
elif self.feature_method == "sift":
|
elif self.feature_method == "sift":
|
||||||
self._detector = cv2.SIFT_create(
|
self._detector = cv2.SIFT_create(
|
||||||
@@ -70,30 +74,55 @@ class VisionChunk:
|
|||||||
return self._matcher
|
return self._matcher
|
||||||
|
|
||||||
def _preprocess(self, img_np: np.ndarray) -> np.ndarray:
|
def _preprocess(self, img_np: np.ndarray) -> np.ndarray:
|
||||||
"""CLAHE предобработка для улучшения контраста"""
|
"""Предобработка для улучшения сопоставления между снимками разного времени"""
|
||||||
if len(img_np.shape) == 3:
|
if len(img_np.shape) == 3:
|
||||||
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
|
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
|
||||||
else:
|
else:
|
||||||
gray = img_np
|
gray = img_np
|
||||||
|
|
||||||
|
# Гауссовское размытие для подавления шума и мелких различий
|
||||||
|
# blurred = cv2.GaussianBlur(gray, (5, 5), 1.0)
|
||||||
|
|
||||||
|
# CLAHE для выравнивания контраста между снимками
|
||||||
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
|
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
|
||||||
return clahe.apply(gray)
|
enhanced = clahe.apply(gray)
|
||||||
|
|
||||||
|
# Опционально: нормализация гистограммы для устранения различий в освещении
|
||||||
|
normalized = cv2.normalize(enhanced, None, 0, 255, cv2.NORM_MINMAX)
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
def compute_keypoints(self, force: bool = False) -> Tuple[list[cv2.KeyPoint], Optional[np.ndarray]]:
|
def compute_keypoints(self, force: bool = False) -> Tuple[list[cv2.KeyPoint], Optional[np.ndarray]]:
|
||||||
if self.keypoints is not None and self.descriptors is not None and not force:
|
if self.keypoints is not None and self.descriptors is not None and not force:
|
||||||
return self.keypoints, self.descriptors
|
return self.keypoints, self.descriptors
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
timer.start()
|
||||||
detector = self._get_detector()
|
detector = self._get_detector()
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-DETECTION]: get_detector: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
# PIL -> OpenCV (RGB->BGR)
|
# PIL -> OpenCV (RGB->BGR)
|
||||||
img_np = np.array(self.image)
|
img_np = np.array(self.image)
|
||||||
if img_np.ndim == 3 and img_np.shape[2] == 3:
|
if img_np.ndim == 3 and img_np.shape[2] == 3:
|
||||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-DETECTION]: converting: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
# CLAHE предобработка
|
# CLAHE предобработка
|
||||||
preprocessed = self._preprocess(img_np)
|
preprocessed = self._preprocess(img_np)
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-DETECTION]: preprocess: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
keypoints, descriptors = detector.detectAndCompute(preprocessed, None)
|
keypoints, descriptors = detector.detectAndCompute(preprocessed, None)
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-DETECTION]: detect and compute: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
# Получаем массив response для всех точек
|
# Получаем массив response для всех точек
|
||||||
responses = np.array([kp.response for kp in keypoints])
|
responses = np.array([kp.response for kp in keypoints])
|
||||||
|
|
||||||
@@ -104,6 +133,9 @@ class VisionChunk:
|
|||||||
best_keypoints = [keypoints[i] for i in top_indices]
|
best_keypoints = [keypoints[i] for i in top_indices]
|
||||||
best_descriptors = descriptors[top_indices]
|
best_descriptors = descriptors[top_indices]
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-DETECTION]: filtration: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
self.keypoints = best_keypoints
|
self.keypoints = best_keypoints
|
||||||
self.descriptors = best_descriptors
|
self.descriptors = best_descriptors
|
||||||
return self.keypoints, self.descriptors
|
return self.keypoints, self.descriptors
|
||||||
@@ -122,15 +154,29 @@ class VisionChunk:
|
|||||||
Возвращает: src_pts, dst_pts, good_matches, kp1, kp2 (отцентрированные координаты)
|
Возвращает: src_pts, dst_pts, good_matches, kp1, kp2 (отцентрированные координаты)
|
||||||
"""
|
"""
|
||||||
# Вычисляем keypoints для обоих
|
# Вычисляем keypoints для обоих
|
||||||
|
timer = Timer()
|
||||||
|
timer.start()
|
||||||
|
|
||||||
kp1, des1 = self.compute_keypoints()
|
kp1, des1 = self.compute_keypoints()
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-KEYPOINTS]: computing 1: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
kp2, des2 = other.compute_keypoints()
|
kp2, des2 = other.compute_keypoints()
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-KEYPOINTS]: computing 2: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
if des1 is None or des2 is None or len(kp1) < 4 or len(kp2) < 4:
|
if des1 is None or des2 is None or len(kp1) < 4 or len(kp2) < 4:
|
||||||
return None, None, None, None, None
|
return None, None, None, None, None
|
||||||
|
|
||||||
# kNN matching + Lowe ratio test
|
# kNN matching + Lowe ratio test
|
||||||
matcher = self._get_matcher()
|
matcher = self._get_matcher()
|
||||||
matches_knn = matcher.knnMatch(des1, des2, k=2)
|
matches_knn = matcher.knnMatch(des1, des2, k=2)
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-KEYPOINTS]: matching: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
good_matches: list[cv2.DMatch] = []
|
good_matches: list[cv2.DMatch] = []
|
||||||
|
|
||||||
for m_n in matches_knn:
|
for m_n in matches_knn:
|
||||||
@@ -147,15 +193,6 @@ class VisionChunk:
|
|||||||
if len(good_matches) < 4:
|
if len(good_matches) < 4:
|
||||||
return None, None, None, None, None
|
return None, None, None, None, None
|
||||||
|
|
||||||
# Центр изображений
|
|
||||||
img1_cv = self.to_cv2_gray()
|
|
||||||
img2_cv = other.to_cv2_gray()
|
|
||||||
h1, w1 = img1_cv.shape
|
|
||||||
h2, w2 = img2_cv.shape
|
|
||||||
cx1, cy1 = w1 // 2, h1 // 2
|
|
||||||
cx2, cy2 = w2 // 2, h2 // 2
|
|
||||||
|
|
||||||
# Отцентрированные координаты (x_rel, y_rel)
|
|
||||||
src_pts = []
|
src_pts = []
|
||||||
dst_pts = []
|
dst_pts = []
|
||||||
|
|
||||||
@@ -169,6 +206,9 @@ class VisionChunk:
|
|||||||
src_pts = np.float32(src_pts).reshape(-1, 1, 2)
|
src_pts = np.float32(src_pts).reshape(-1, 1, 2)
|
||||||
dst_pts = np.float32(dst_pts).reshape(-1, 1, 2)
|
dst_pts = np.float32(dst_pts).reshape(-1, 1, 2)
|
||||||
|
|
||||||
|
if constants.DEBUG_FPS:
|
||||||
|
print(f"[VC-KEYPOINTS]: filtration: {timer.loop() * 1000:.2f} ms")
|
||||||
|
|
||||||
return src_pts, dst_pts, good_matches, kp1, kp2
|
return src_pts, dst_pts, good_matches, kp1, kp2
|
||||||
|
|
||||||
def to_cv2_gray(self) -> np.ndarray:
|
def to_cv2_gray(self) -> np.ndarray:
|
||||||
|
|||||||
@@ -3,14 +3,16 @@
|
|||||||
Модуль для управления общим окном визуализации
|
Модуль для управления общим окном визуализации
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from enum import Enum
|
||||||
|
from scipy.interpolate import make_interp_spline
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib
|
||||||
import matplotlib.axes
|
import matplotlib.axes
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.patches as patches
|
import matplotlib.patches as patches
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from enum import Enum
|
|
||||||
import cv2
|
|
||||||
from PIL import Image
|
|
||||||
import matplotlib
|
|
||||||
|
|
||||||
# Настройки matplotlib
|
# Настройки matplotlib
|
||||||
matplotlib.use('TkAgg')
|
matplotlib.use('TkAgg')
|
||||||
@@ -93,14 +95,14 @@ class VisualizationManager:
|
|||||||
self.ax_matches.axis('off')
|
self.ax_matches.axis('off')
|
||||||
|
|
||||||
# Сопоставление точек (средний средний угол)
|
# Сопоставление точек (средний средний угол)
|
||||||
self.ax_chunk_matches = self.fig.add_subplot(gs[1, 2])
|
self.ax_chunk_matches = self.fig.add_subplot(gs[1, 1:3])
|
||||||
self.ax_chunk_matches.set_title('Chunk Matching')
|
self.ax_chunk_matches.set_title('Chunk Matching')
|
||||||
self.ax_chunk_matches.axis('off')
|
self.ax_chunk_matches.axis('off')
|
||||||
|
|
||||||
# Визуализация движения ключевых точек (левый нижний угол)
|
# Визуализация движения ключевых точек (левый нижний угол)
|
||||||
self.ax_motion_vectors = self.fig.add_subplot(gs[1, 1])
|
# self.ax_motion_vectors = self.fig.add_subplot(gs[1, 1])
|
||||||
self.ax_motion_vectors.set_title('Motion Vectors - Движение ключевых точек')
|
# self.ax_motion_vectors.set_title('Motion Vectors - Движение ключевых точек')
|
||||||
self.ax_motion_vectors.axis('off')
|
# self.ax_motion_vectors.axis('off')
|
||||||
|
|
||||||
# Визуализация движения ключевых точек на основе матрицы гомографии
|
# Визуализация движения ключевых точек на основе матрицы гомографии
|
||||||
self.ax_motion_gomography = self.fig.add_subplot(gs[0, 1])
|
self.ax_motion_gomography = self.fig.add_subplot(gs[0, 1])
|
||||||
@@ -157,7 +159,7 @@ class VisualizationManager:
|
|||||||
# Рисуем текущую целевую точку
|
# Рисуем текущую целевую точку
|
||||||
if self.target_idx < len(self.target_pts):
|
if self.target_idx < len(self.target_pts):
|
||||||
pt = self.target_pts[self.target_idx]
|
pt = self.target_pts[self.target_idx]
|
||||||
self.ax_global_map.plot(pt[0], pt[1], 'yo', markersize=8, label='Цель (0, 0)')
|
self.ax_global_map.plot(pt[0], pt[1], 'yo', markersize=8, label='Цель')
|
||||||
|
|
||||||
self.ax_global_map.legend()
|
self.ax_global_map.legend()
|
||||||
|
|
||||||
@@ -208,7 +210,37 @@ class VisualizationManager:
|
|||||||
self.ax_error_plot.grid(True, alpha=0.3)
|
self.ax_error_plot.grid(True, alpha=0.3)
|
||||||
|
|
||||||
if len(self.error_times) > 1:
|
if len(self.error_times) > 1:
|
||||||
self.ax_error_plot.plot(self.error_times, self.position_errors, 'b-', linewidth=2)
|
# Оригинальный график (более прозрачный)
|
||||||
|
self.ax_error_plot.plot(self.error_times, self.position_errors, 'b-',
|
||||||
|
linewidth=1, alpha=0.4, label='Погрешность данных')
|
||||||
|
|
||||||
|
if len(self.error_times) > 5:
|
||||||
|
# Сглаженный график
|
||||||
|
smoothed_times = np.linspace(self.error_times[0], self.error_times[-1], 300)
|
||||||
|
spl = make_interp_spline(self.error_times, self.position_errors, k=3)
|
||||||
|
smoothed_errors = spl(smoothed_times)
|
||||||
|
|
||||||
|
|
||||||
|
self.ax_error_plot.plot(smoothed_times, smoothed_errors, 'orange',
|
||||||
|
linewidth=2, label='Сглаженный тренд')
|
||||||
|
|
||||||
|
# if len(self.position_errors) > 5: # Достаточно данных для сглаживания
|
||||||
|
# window_size = min(11, len(self.position_errors) // 3) # Адаптивный размер окна
|
||||||
|
# if window_size % 2 == 0: # Должен быть нечетным
|
||||||
|
# window_size += 1
|
||||||
|
|
||||||
|
# # Метод скользящего среднего
|
||||||
|
# smoothed_errors = np.convolve(
|
||||||
|
# self.position_errors,
|
||||||
|
# np.ones(window_size) / window_size,
|
||||||
|
# mode='valid'
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Корректируем временную ось для сглаженных данных
|
||||||
|
# offset = (window_size - 1) // 2
|
||||||
|
# smoothed_times = self.error_times[offset:offset + len(smoothed_errors)]
|
||||||
|
|
||||||
|
self.ax_error_plot.legend(loc='upper right')
|
||||||
|
|
||||||
# Автоматически масштабируем оси
|
# Автоматически масштабируем оси
|
||||||
if len(self.position_errors) > 0:
|
if len(self.position_errors) > 0:
|
||||||
@@ -219,6 +251,7 @@ class VisualizationManager:
|
|||||||
else:
|
else:
|
||||||
self.ax_error_plot.set_ylim(0, 1)
|
self.ax_error_plot.set_ylim(0, 1)
|
||||||
|
|
||||||
|
|
||||||
def update_matches(self, img1: np.ndarray, img2: np.ndarray,
|
def update_matches(self, img1: np.ndarray, img2: np.ndarray,
|
||||||
kp1, kp2, matches, transformation_info=None):
|
kp1, kp2, matches, transformation_info=None):
|
||||||
"""Обновляет визуализацию сопоставления точек"""
|
"""Обновляет визуализацию сопоставления точек"""
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
import math
|
|
||||||
from io import BytesIO
|
|
||||||
from time import sleep
|
|
||||||
import os
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
from selenium.webdriver.common.actions.wheel_input import ScrollOrigin
|
from selenium.webdriver.common.actions.wheel_input import ScrollOrigin
|
||||||
from selenium import webdriver
|
from selenium import webdriver
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from selenium.webdriver.common.action_chains import ActionChains
|
from selenium.webdriver.common.action_chains import ActionChains
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
import constants
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import utility
|
||||||
|
|
||||||
def generateURL(lat: float, lon: float, zoom: int):
|
def generateURL(lat: float, lon: float, zoom: int):
|
||||||
return f"https://yandex.ru/maps/43/kazan/?l=sat&ll={lat}%2C{lon}&z={zoom}"
|
return f"https://yandex.ru/maps/43/kazan/?l=sat&ll={lat}%2C{lon}&z={zoom}"
|
||||||
@@ -24,6 +25,7 @@ class YandexMap:
|
|||||||
self.initial_lat = initial_lat
|
self.initial_lat = initial_lat
|
||||||
self.initial_lon = initial_lon
|
self.initial_lon = initial_lon
|
||||||
self.initial_zoom = initial_zoom
|
self.initial_zoom = initial_zoom
|
||||||
|
self.pixel_ratio = constants.YANDEX_PIXEL_RATIO[self.initial_zoom]
|
||||||
|
|
||||||
options = webdriver.ChromeOptions()
|
options = webdriver.ChromeOptions()
|
||||||
# options.add_experimental_option("detach", True)
|
# options.add_experimental_option("detach", True)
|
||||||
@@ -32,9 +34,8 @@ class YandexMap:
|
|||||||
self.driver.maximize_window()
|
self.driver.maximize_window()
|
||||||
sleep(2)
|
sleep(2)
|
||||||
|
|
||||||
action = ActionChains(self.driver)
|
|
||||||
|
|
||||||
# Закрытие левой панели
|
# Закрытие левой панели
|
||||||
|
action = ActionChains(self.driver)
|
||||||
action.click(self.driver.find_element(By.CLASS_NAME, 'sidebar-toggle-button'))
|
action.click(self.driver.find_element(By.CLASS_NAME, 'sidebar-toggle-button'))
|
||||||
action.perform()
|
action.perform()
|
||||||
|
|
||||||
@@ -43,8 +44,25 @@ class YandexMap:
|
|||||||
self.driver.execute_script("arguments[0].remove();", self.driver.find_element(By.XPATH, "//nav[@class='map-controls']"))
|
self.driver.execute_script("arguments[0].remove();", self.driver.find_element(By.XPATH, "//nav[@class='map-controls']"))
|
||||||
self.driver.execute_script("arguments[0].remove();", self.driver.find_element(By.XPATH, "//footer"))
|
self.driver.execute_script("arguments[0].remove();", self.driver.find_element(By.XPATH, "//footer"))
|
||||||
|
|
||||||
|
self.move(39, -9)
|
||||||
|
|
||||||
sleep(0.2)
|
sleep(0.2)
|
||||||
|
|
||||||
|
def open(self, lat, lon, zoom):
|
||||||
|
self.initial_lat = lat
|
||||||
|
self.initial_lon = lon
|
||||||
|
self.initial_zoom = zoom
|
||||||
|
self.pixel_ratio = constants.YANDEX_PIXEL_RATIO[self.initial_zoom]
|
||||||
|
self.driver.get(generateURL(lat, lon, zoom))
|
||||||
|
sleep(2)
|
||||||
|
|
||||||
|
# Закрытие левой панели
|
||||||
|
action = ActionChains(self.driver)
|
||||||
|
action.click(self.driver.find_element(By.CLASS_NAME, 'sidebar-toggle-button'))
|
||||||
|
action.perform()
|
||||||
|
|
||||||
|
self.move(39, -9)
|
||||||
|
|
||||||
def save_photo(self, filename: str) -> bytes:
|
def save_photo(self, filename: str) -> bytes:
|
||||||
return self.driver.save_screenshot(filename)
|
return self.driver.save_screenshot(filename)
|
||||||
|
|
||||||
@@ -68,51 +86,26 @@ class YandexMap:
|
|||||||
if i != count - 1:
|
if i != count - 1:
|
||||||
sleep(0.25)
|
sleep(0.25)
|
||||||
|
|
||||||
def make_as_center(self, x: float, y: float):
|
def move(self, dx: float, dy: float):
|
||||||
html = self.driver.find_element(By.TAG_NAME, 'html')
|
html = self.driver.find_element(By.TAG_NAME, 'html')
|
||||||
|
|
||||||
action = ActionChains(self.driver)
|
action = ActionChains(self.driver)
|
||||||
action.move_to_element_with_offset(html, 0, 0)
|
action.move_to_element_with_offset(html, 0, 0)
|
||||||
action.click_and_hold()
|
action.click_and_hold()
|
||||||
|
|
||||||
dx = (x - 0.5) * self.get_size()[0]
|
|
||||||
dy = (0.5 - y) * self.get_size()[1]
|
|
||||||
print(dx, dy)
|
|
||||||
action.move_by_offset(-dx, dy)
|
action.move_by_offset(-dx, dy)
|
||||||
action.release()
|
action.release()
|
||||||
action.perform()
|
action.perform()
|
||||||
|
|
||||||
def make_screenshot(self, x: float, y: float, width: float, height: float) -> cv2.typing.MatLike:
|
|
||||||
# Сохраняем скриншот
|
def make_as_center(self, x: float, y: float):
|
||||||
self.save_photo("temp.png")
|
dx = (x - 0.5) * self.get_size()[0]
|
||||||
|
dy = (0.5 - y) * self.get_size()[1]
|
||||||
# Загружаем изображение
|
self.move(dx, dy)
|
||||||
image = cv2.imread("temp.png")
|
|
||||||
|
def make_screenshot(self) -> Image.Image:
|
||||||
if image is None:
|
png = self.driver.get_screenshot_as_png()
|
||||||
raise ValueError("Не удалось загрузить изображение temp.png")
|
im = Image.open(BytesIO(png))
|
||||||
|
return utility.cv2_to_pil(np.array(im)[:, :])
|
||||||
# Получаем размеры исходного изображения
|
|
||||||
img_height, img_width = image.shape[:2]
|
def get_geolocation(self):
|
||||||
|
current_url = self.driver.current_url
|
||||||
# Преобразуем относительные координаты в абсолютные пиксели
|
return utility.parse_yandex_maps_url(current_url)
|
||||||
center_x = int(x * img_width)
|
|
||||||
center_y = int(y * img_height)
|
|
||||||
crop_width = int(width * img_width)
|
|
||||||
crop_height = int(height * img_height)
|
|
||||||
|
|
||||||
# Вычисляем координаты прямоугольника для кадрирования
|
|
||||||
x1 = max(0, center_x - crop_width // 2)
|
|
||||||
y1 = max(0, center_y - crop_height // 2)
|
|
||||||
x2 = min(img_width, center_x + crop_width // 2)
|
|
||||||
y2 = min(img_height, center_y + crop_height // 2)
|
|
||||||
|
|
||||||
# Проверяем, что прямоугольник имеет положительные размеры
|
|
||||||
if x2 <= x1 or y2 <= y1:
|
|
||||||
raise ValueError("Некорректные размеры для кадрирования")
|
|
||||||
|
|
||||||
# Кадрируем изображение
|
|
||||||
cropped_image = image[y1:y2, x1:x2]
|
|
||||||
|
|
||||||
# Если нужно вернуть изображение как результат функции:
|
|
||||||
return cropped_image
|
|
||||||
|
|||||||
Reference in New Issue
Block a user