Compare commits
13 Commits
31c0f13361
...
feat/pitch
| Author | SHA1 | Date | |
|---|---|---|---|
| 0cc210968f | |||
| 6040f3b253 | |||
| 6d4208d100 | |||
|
|
339e5f210c | ||
|
|
5385641d28 | ||
| 64c9215f5b | |||
| 7cd700c1fa | |||
| 05c249ed78 | |||
| fbd0d01b35 | |||
| e00878daad | |||
| ceca8a6e75 | |||
| 6456d18212 | |||
| 3ee3599b87 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
.venv
|
||||
__pycache__
|
||||
*.png
|
||||
images
|
||||
trajectories
|
||||
z
|
||||
|
||||
222
autopilot.py
222
autopilot.py
@@ -3,6 +3,7 @@ from pathlib import Path
|
||||
import math
|
||||
import random
|
||||
|
||||
import constants
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -56,14 +57,16 @@ class AutoPilot(Pilot):
|
||||
|
||||
# Положение на основе ориентира
|
||||
reserved_pos: Position | None
|
||||
proccessing_time: float
|
||||
|
||||
def __init__(self, points = [], chunks = [], viz_manager=None):
|
||||
def __init__(self, points = [], chunks = [], viz_manager=None, pixel_ratio: float = 1.):
|
||||
self.prev_chunk = None
|
||||
self.pos = Position(0, 0, 1, 0, 0, 0)
|
||||
self.chunks = chunks
|
||||
self.frame_count = 0
|
||||
self.vis_manager = viz_manager # Менеджер визуализации
|
||||
self.reserved_pos = None
|
||||
self.pixel_ratio = pixel_ratio
|
||||
|
||||
# Пороговые значения качества сопоставления/гомографии
|
||||
self.min_inliers: int = 12
|
||||
@@ -76,6 +79,7 @@ class AutoPilot(Pilot):
|
||||
self.target_idx = 0
|
||||
|
||||
self.timer = Timer()
|
||||
self.chunk_points = np.array([[chunk.pos.x, chunk.pos.y] for chunk in self.chunks])
|
||||
|
||||
def get_position(self) -> tuple[float, float]:
|
||||
return self.pos.x, self.pos.y
|
||||
@@ -91,7 +95,7 @@ class AutoPilot(Pilot):
|
||||
h, w = prev_gray.shape[:2]
|
||||
|
||||
# Создаем сетку точек для отслеживания (аналогично вашему step=20)
|
||||
step = 35
|
||||
step = 20
|
||||
grid_points = []
|
||||
for y in range(step, h - step, step):
|
||||
for x in range(step, w - step, step):
|
||||
@@ -133,9 +137,6 @@ class AutoPilot(Pilot):
|
||||
"""
|
||||
self.pos.iapply(homography_matrix)
|
||||
|
||||
if self.reserved_pos is not None:
|
||||
self.reserved_pos.iapply(homography_matrix)
|
||||
|
||||
def get_drone_state(self) -> dict:
|
||||
"""
|
||||
Возвращает текущее состояние БПЛА
|
||||
@@ -149,43 +150,122 @@ class AutoPilot(Pilot):
|
||||
}
|
||||
|
||||
def get_position_by_chunk(self) -> Position | None:
|
||||
# Пытаемся найти ориентир на картинке:
|
||||
landmark_timer = Timer()
|
||||
landmark_timer.start()
|
||||
|
||||
cur_pos = np.array([self.pos.x, self.pos.y])
|
||||
closest_chunk_idx = ((self.chunk_points - cur_pos) ** 2).sum(1).argmin()
|
||||
|
||||
current_chunk = self.prev_chunk
|
||||
landmark_chunk = self.chunks[self.target_idx]
|
||||
landmark_chunk = self.chunks[closest_chunk_idx]
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[LANDMARK]: Closest chunk finding: {landmark_timer.loop() * 1000:.2f} ms")
|
||||
|
||||
# Краевой случай: отсутствие чанков
|
||||
if current_chunk is None or landmark_chunk is None:
|
||||
return None
|
||||
|
||||
landmark_timer.start()
|
||||
src_pts, dst_pts, matches, kp1, kp2 = landmark_chunk.detect_and_match_keypoints(current_chunk)
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[LANDMARK]: detect and match keypoints: {landmark_timer.loop() * 1000:.2f} ms")
|
||||
|
||||
landmark_timer.stop()
|
||||
|
||||
# Визуализация (если нужна)
|
||||
if src_pts is not None and dst_pts is not None and self.vis_manager:
|
||||
was_enabled = self.timer.enabled
|
||||
if was_enabled: self.timer.stop()
|
||||
self.vis_manager.update_chunk_matches(landmark_chunk.to_numpy(), current_chunk.to_numpy(), kp1, kp2, matches)
|
||||
if was_enabled: self.timer.start()
|
||||
if was_enabled:
|
||||
self.timer.stop()
|
||||
self.vis_manager.update_chunk_matches(
|
||||
landmark_chunk.to_numpy(),
|
||||
current_chunk.to_numpy(),
|
||||
kp1, kp2, matches
|
||||
)
|
||||
if was_enabled:
|
||||
self.timer.start()
|
||||
|
||||
if src_pts is not None and dst_pts is not None:
|
||||
# Оцениваем матрицу трансформации
|
||||
landmark_transform = self.estimate_transformation_matrix(src_pts, dst_pts)
|
||||
# Если ориентир достоверно найден — скорректируем глобальные координаты и угол
|
||||
if landmark_transform is not None:
|
||||
ok_scale = (self.min_scale <= landmark_transform['scale'] <= self.max_scale)
|
||||
ok_inliers = (landmark_transform.get('inliers', 0) >= self.min_inliers)
|
||||
ratio = landmark_transform.get('inliers_ratio', 0.0)
|
||||
ok_ratio = (ratio >= self.min_inlier_ratio)
|
||||
rmse = landmark_transform.get('rmse', None)
|
||||
ok_rmse = (rmse is not None and rmse <= self.max_reproj_rmse)
|
||||
if ok_scale and ok_inliers and ok_ratio and ok_rmse:
|
||||
# print("[HELP]")
|
||||
# print("Matrix", landmark_transform['homography'])
|
||||
# print("Position", self.x, self.y)
|
||||
# print("Position of point", self.points[self.target_idx])
|
||||
# print("[PILOT]", rmse, ratio, ok_rmse)
|
||||
# if False:
|
||||
# Считаем абсолютную позу относительно координат ориентира
|
||||
landmark_world_x, landmark_world_y = self.points[self.target_idx]
|
||||
landmark = Position(landmark_world_x, landmark_world_y, 1, 0)
|
||||
homography = landmark_transform['homography']
|
||||
# homography = np.linalg.inv(homography)
|
||||
print(f" [Pilot] Landmark correction applied (inliers={landmark_transform['inliers']}, ratio={ratio:.2f}, rmse={rmse:.2f})")
|
||||
return landmark @ homography
|
||||
return None
|
||||
landmark_timer.start()
|
||||
# Краевой случай: нет точек или недостаточно матчей
|
||||
if src_pts is None or dst_pts is None:
|
||||
return None
|
||||
|
||||
num_matches = len(src_pts)
|
||||
if num_matches < 20:
|
||||
return None
|
||||
|
||||
# Оценка матрицы гомографии
|
||||
landmark_timer.loop()
|
||||
landmark_transform, mask = estimate_transformation_matrix(src_pts, dst_pts)
|
||||
num_inliers = int(np.sum(mask))
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[LANDMARK]: matrix estimation: {landmark_timer.loop() * 1000:.2f} ms")
|
||||
|
||||
# Краевой случай: матрица не найдена
|
||||
if landmark_transform is None or mask is None:
|
||||
return None
|
||||
|
||||
# === КРИТЕРИИ ПРИНЯТИЯ РЕШЕНИЯ ===
|
||||
|
||||
# 1. Минимальное количество инлайеров (абсолютное)
|
||||
MIN_INLIERS_ABSOLUTE = 6
|
||||
if num_inliers < MIN_INLIERS_ABSOLUTE:
|
||||
return None
|
||||
|
||||
# 2. Процент инлайеров от общего числа матчей
|
||||
inlier_ratio = num_inliers / num_matches
|
||||
|
||||
if constants.DEBUG_LANDMARK:
|
||||
print("[LANDMARK]: inlier_ratio=", inlier_ratio)
|
||||
|
||||
MIN_INLIER_RATIO = 0.6
|
||||
if inlier_ratio < MIN_INLIER_RATIO:
|
||||
return None
|
||||
|
||||
# 3. Проверка качества гомографии (детерминант для выявления вырожденных случаев)
|
||||
det = np.linalg.det(landmark_transform[:2, :2])
|
||||
# Детерминант должен быть близок к 1 (без сильного масштабирования)
|
||||
if abs(det) < 0.1 or abs(det) > 10.0:
|
||||
return None
|
||||
|
||||
# 4. Проверка на валидность перспективного преобразования
|
||||
# Элементы третьей строки не должны быть слишком большими
|
||||
if abs(landmark_transform[2, 0]) > 0.01 or abs(landmark_transform[2, 1]) > 0.01:
|
||||
return None
|
||||
|
||||
# 5. Дополнительная проверка: средняя ошибка репроекции для инлайеров
|
||||
inlier_src = src_pts[mask.ravel() == 1]
|
||||
inlier_dst = dst_pts[mask.ravel() == 1]
|
||||
|
||||
# Преобразуем точки через найденную гомографию
|
||||
transformed_pts = cv2.perspectiveTransform(inlier_src, landmark_transform)
|
||||
|
||||
# Вычисляем ошибку репроекции
|
||||
reprojection_errors = np.sqrt(np.sum((transformed_pts - inlier_dst) ** 2, axis=2))
|
||||
mean_error = np.mean(reprojection_errors)
|
||||
|
||||
MAX_MEAN_REPROJECTION_ERROR = 1.1 # пиксели
|
||||
|
||||
if constants.DEBUG_LANDMARK:
|
||||
print("[LANDMARK]: Mean_error=", mean_error)
|
||||
|
||||
if mean_error > MAX_MEAN_REPROJECTION_ERROR:
|
||||
return None
|
||||
|
||||
# 6. Проверка стабильности: если слишком много хороших совпадений, но мало инлайеров - подозрительно
|
||||
if num_matches > 50 and inlier_ratio < 0.15:
|
||||
return None
|
||||
|
||||
# === ВСЕ ПРОВЕРКИ ПРОЙДЕНЫ ===
|
||||
print("[LANDMARK]: Correction Applied")
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[LANDMARK]: time: {landmark_timer.get_elapsed() * 1000:.2f} ms")
|
||||
|
||||
return landmark_chunk.pos.apply(landmark_transform)
|
||||
|
||||
|
||||
def handle(self, current_chunk: VisionChunk) -> PilotCommand:
|
||||
@@ -197,17 +277,11 @@ class AutoPilot(Pilot):
|
||||
self.timer.stop()
|
||||
return PilotCommand(processing_time=self.timer.get_elapsed())
|
||||
|
||||
# Расстояние до цели
|
||||
distance_to_target = math.sqrt(
|
||||
(self.points[self.target_idx][0] - self.pos.x) ** 2 +
|
||||
(self.points[self.target_idx][1] - self.pos.y) ** 2
|
||||
)
|
||||
|
||||
# Вычисляем оптический поток для покадрового сравнения
|
||||
matching_timer = Timer()
|
||||
matching_timer.start()
|
||||
# src_pts, dst_pts = self.calculate_optical_flow(self.prev_chunk, current_chunk)
|
||||
src_pts, dst_pts, _, _, _ = self.prev_chunk.detect_and_match_keypoints(current_chunk)
|
||||
src_pts, dst_pts = self.calculate_optical_flow(self.prev_chunk, current_chunk)
|
||||
# src_pts, dst_pts, _, _, _ = self.prev_chunk.detect_and_match_keypoints(current_chunk)
|
||||
matching_timer.stop()
|
||||
print(f"Matching calculating: {matching_timer.get_elapsed() * 1000:.2f} ms")
|
||||
|
||||
@@ -217,7 +291,7 @@ class AutoPilot(Pilot):
|
||||
|
||||
matrix_estimation_timer = Timer()
|
||||
matrix_estimation_timer.start()
|
||||
homography_matrix = estimate_transformation_matrix(src_pts, dst_pts)
|
||||
homography_matrix, _ = estimate_transformation_matrix(src_pts, dst_pts)
|
||||
matrix_estimation_timer.stop()
|
||||
print(f"Transformation matrix updating: {matrix_estimation_timer.get_elapsed() * 1000:.2f} ms")
|
||||
|
||||
@@ -235,17 +309,24 @@ class AutoPilot(Pilot):
|
||||
|
||||
self.timer.start()
|
||||
|
||||
chunk_timer = Timer()
|
||||
chunk_timer.start()
|
||||
|
||||
# Пытаемся найти ориентир на картинке:
|
||||
self.prev_chunk = current_chunk
|
||||
# npos = self.get_position_by_chunk()
|
||||
# if npos is not None:
|
||||
# self.reserved_pos = npos
|
||||
# Для улучшения среднего FPS
|
||||
if self.frame_count % 5 == 0:
|
||||
pos_by_chunk = self.get_position_by_chunk()
|
||||
if pos_by_chunk is not None:
|
||||
self.pos = pos_by_chunk
|
||||
|
||||
chunk_timer.stop()
|
||||
print(f"Chunk timer: {chunk_timer.get_elapsed() * 1000:.2f} ms")
|
||||
command = self.make_command()
|
||||
self.timer.reset()
|
||||
return command
|
||||
|
||||
def make_command(self) -> PilotCommand:
|
||||
# Расстояние до цели
|
||||
distance_to_target = math.sqrt(
|
||||
(self.points[self.target_idx][0] - self.pos.x) ** 2 +
|
||||
(self.points[self.target_idx][1] - self.pos.y) ** 2
|
||||
) * self.pixel_ratio
|
||||
|
||||
if distance_to_target < 35:
|
||||
self.target_idx += 1
|
||||
@@ -258,16 +339,32 @@ class AutoPilot(Pilot):
|
||||
(self.points[self.target_idx][1] - self.pos.y) ** 2
|
||||
)
|
||||
|
||||
if self.reserved_pos is not None:
|
||||
self.pos = self.reserved_pos
|
||||
self.reserved_pos = None
|
||||
angle_trajectory = self.pos.yaw + math.pi / 2
|
||||
|
||||
# Проверка на слепую зону
|
||||
R = 120
|
||||
blind = np.array([
|
||||
[
|
||||
self.pos.x * self.pixel_ratio + R * np.cos(angle_trajectory - np.pi / 2),
|
||||
self.pos.y * self.pixel_ratio + R * np.sin(angle_trajectory - np.pi / 2),
|
||||
],
|
||||
[
|
||||
self.pos.x * self.pixel_ratio + R * np.cos(angle_trajectory + np.pi / 2),
|
||||
self.pos.y * self.pixel_ratio + R * np.sin(angle_trajectory + np.pi / 2),
|
||||
]
|
||||
])
|
||||
|
||||
blind -= self.points[self.target_idx] * self.pixel_ratio
|
||||
blind = np.hypot(blind[:, 0], blind[:, 1])
|
||||
|
||||
print("R: ", blind)
|
||||
if np.min(blind) < R:
|
||||
return PilotCommand(0, 10, False, self.timer.get_elapsed())
|
||||
|
||||
|
||||
|
||||
# Вычисляем угол к цели
|
||||
target_angle = math.atan2(self.points[self.target_idx][1] - self.pos.y, self.points[self.target_idx][0] - self.pos.x)
|
||||
|
||||
angle_trajectory = self.pos.yaw + math.pi / 2
|
||||
|
||||
# print("[ANGLE]", angle_trajectory, "->", target_angle)
|
||||
|
||||
# Вычисляем разность углов (направление поворота)
|
||||
angle_diff = target_angle - angle_trajectory
|
||||
@@ -277,14 +374,13 @@ class AutoPilot(Pilot):
|
||||
if angle_diff >= math.pi:
|
||||
angle_diff -= 2 * math.pi
|
||||
|
||||
d_r = max(10, min(35., distance_to_target / 2))
|
||||
d_a_limit = d_r / 10 * 0.01
|
||||
d_r = max(5, min(10., distance_to_target / 2))
|
||||
d_a_limit = np.radians(5)
|
||||
|
||||
command = PilotCommand(
|
||||
max(min(d_a_limit, angle_diff), -d_a_limit),
|
||||
d_r, False, self.timer.get_elapsed()
|
||||
)
|
||||
self.timer.reset()
|
||||
return command
|
||||
|
||||
def reset_position(self, x: float = 0.0, y: float = 0.0, angle: float = 0.0):
|
||||
|
||||
16
constants.py
16
constants.py
@@ -1,5 +1,18 @@
|
||||
import numpy as np
|
||||
|
||||
# Ширина 1 пикселя в метрах
|
||||
YANDEX_PIXEL_RATIO = {
|
||||
15: 2830 / 1049,
|
||||
18: 350 / 1049,
|
||||
}
|
||||
|
||||
GOOGLE_PIXEL_RATIO = {
|
||||
15: 2766 / 1031,
|
||||
18: 346 / 1031,
|
||||
}
|
||||
|
||||
WINDOW_HEIGHT = 1031
|
||||
|
||||
# Ширина и высота снимка в пикселях
|
||||
CHUNK_WIDTH = 700
|
||||
|
||||
@@ -17,3 +30,6 @@ K = np.array([
|
||||
[0, _K_FOCUS_DISTANCE, _K_CENTER],
|
||||
[0, 0, 1]
|
||||
])
|
||||
|
||||
DEBUG_FPS: bool = False
|
||||
DEBUG_LANDMARK: bool = False
|
||||
|
||||
2
datasets/ya_go_maps/.gitignore
vendored
Normal file
2
datasets/ya_go_maps/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
images/homography_cache
|
||||
*.zip
|
||||
43
datasets/ya_go_maps/generate_dataset.py
Normal file
43
datasets/ya_go_maps/generate_dataset.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from ...google_map import GoogleMap
|
||||
from ...simulator import Simulator
|
||||
from ...yandex_map import YandexMap
|
||||
|
||||
LAT_MIN, LAT_MAX = 44.960236, 54.967830
|
||||
LON_MIN, LON_MAX = 53.084167, 58.677977
|
||||
|
||||
def create_new_asset(yandex_map, google_map):
|
||||
folder = Path('dataset_ya_go_maps')
|
||||
|
||||
id = 0
|
||||
print(id)
|
||||
while (folder / f"{id:0{4}}_google.png").exists():
|
||||
id += 1
|
||||
|
||||
google_file = folder / f"{id:0{4}}_google.png"
|
||||
yandex_file = folder / f"{id:0{4}}_yandex.png"
|
||||
|
||||
lat = np.random.rand() * (LAT_MAX - LAT_MIN) + LAT_MIN
|
||||
lon = np.random.rand() * (LON_MAX - LON_MIN) + LON_MIN
|
||||
|
||||
yandex_map.open(lat, lon, 18)
|
||||
google_map.open(lat, lon, 18)
|
||||
|
||||
simulator = Simulator()
|
||||
simulator._apply_perspective_transform(yandex_map.make_screenshot()).save(yandex_file)
|
||||
simulator._apply_perspective_transform(google_map.make_screenshot()).save(google_file)
|
||||
|
||||
def main():
|
||||
folder = Path('dataset_ya_go_maps')
|
||||
if not folder.exists():
|
||||
folder.mkdir()
|
||||
|
||||
yandex_map = YandexMap(initial_zoom=15)
|
||||
google_map = GoogleMap(initial_zoom=15)
|
||||
|
||||
for i in range(4):
|
||||
create_new_asset(yandex_map, google_map)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
89
google_map.py
Normal file
89
google_map.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from selenium.webdriver.common.actions.wheel_input import ScrollOrigin
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.action_chains import ActionChains
|
||||
from time import sleep
|
||||
|
||||
import constants
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import utility
|
||||
|
||||
def generateURL(lat: float, lon: float, zoom: int):
|
||||
return f"https://www.google.com/maps/@{lon},{lat},{zoom}z"
|
||||
|
||||
class GoogleMap:
|
||||
initial_zoom: int
|
||||
initial_lat: float
|
||||
initial_lon: float
|
||||
pixel_ratio: float
|
||||
|
||||
def __init__(self, initial_lat=49.103814, initial_lon=55.794258, initial_zoom=18):
|
||||
self.initial_lat = initial_lat
|
||||
self.initial_lon = initial_lon
|
||||
self.initial_zoom = initial_zoom
|
||||
self.pixel_ratio = constants.GOOGLE_PIXEL_RATIO[self.initial_zoom]
|
||||
|
||||
options = webdriver.ChromeOptions()
|
||||
# options.add_experimental_option("detach", True)
|
||||
self.driver = webdriver.Chrome(options)
|
||||
self.driver.get(generateURL(initial_lat, initial_lon, initial_zoom))
|
||||
self.driver.maximize_window()
|
||||
|
||||
action = ActionChains(self.driver)
|
||||
sleep(5)
|
||||
self.driver.execute_script('document.querySelector(\'.yHc72.qk5Wte\').click()')
|
||||
sleep(5)
|
||||
|
||||
def open(self, lat, lon, zoom):
|
||||
self.initial_lat = lat
|
||||
self.initial_lon = lon
|
||||
self.initial_zoom = zoom
|
||||
self.pixel_ratio = constants.GOOGLE_PIXEL_RATIO[self.initial_zoom]
|
||||
self.driver.get(generateURL(lat, lon, zoom))
|
||||
|
||||
def save_photo(self, filename: str):
|
||||
im = self.make_screenshot()
|
||||
im.save(filename)
|
||||
|
||||
def destroy(self):
|
||||
self.driver.close()
|
||||
|
||||
def get_size(self) -> tuple[int, int]:
|
||||
html = self.driver.find_element(By.TAG_NAME, 'html')
|
||||
return (html.size['width'], html.size['height'])
|
||||
|
||||
def scroll(self, x: float, y: float, count: int = 1, inner_zoom: bool = True):
|
||||
html = self.driver.find_element(By.TAG_NAME, 'html')
|
||||
|
||||
x_offset = (x - 0.5) * (html.size['width'] - 72) + 72
|
||||
y_offset = (y - 0.5) * html.size['height']
|
||||
action = ActionChains(self.driver)
|
||||
|
||||
for i in range(count-1):
|
||||
action.scroll_from_origin(ScrollOrigin(html, int(x_offset), int(y_offset)), 0, -100 if inner_zoom else 100)
|
||||
action.perform()
|
||||
if i != count - 1:
|
||||
sleep(0.25)
|
||||
|
||||
def move(self, dx: float, dy: float):
|
||||
self.driver.execute_script(utility.google_map_js_move_script(dx, dy))
|
||||
|
||||
def make_as_center(self, x: float, y: float):
|
||||
dx = (x - 0.5) * self.get_size()[0]
|
||||
dy = (0.5 - y) * self.get_size()[1]
|
||||
self.move(dx, dy)
|
||||
sleep(1)
|
||||
|
||||
def make_screenshot(self) -> Image.Image:
|
||||
png = self.driver.get_screenshot_as_png()
|
||||
im = Image.open(BytesIO(png))
|
||||
return utility.cv2_to_pil(np.array(im)[:, 72:])
|
||||
|
||||
def get_geolocation(self):
|
||||
current_url = self.driver.current_url
|
||||
return utility.parse_google_maps_url(current_url)
|
||||
384
main.py
384
main.py
@@ -1,23 +1,33 @@
|
||||
from google_map import GoogleMap
|
||||
from pathlib import Path
|
||||
from position import Position
|
||||
from simulator import Simulator
|
||||
from time import sleep
|
||||
from trajectory_drawer import TrajectoryDrawer
|
||||
from utility import cv2_to_pil
|
||||
from vision_chunk import VisionChunk
|
||||
from visualization import VisualizationManager
|
||||
from yandex_map import YandexMap
|
||||
from vision_chunk import VisionChunk
|
||||
from utility import cv2_to_pil
|
||||
import random
|
||||
|
||||
import argparse
|
||||
import autopilot
|
||||
import constants
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pickle
|
||||
import random
|
||||
import utility
|
||||
|
||||
import autopilot
|
||||
def get_map(map_name: str = 'google', lat=49.103814, lon=55.794258, zoom=18):
|
||||
if map_name == 'google': return GoogleMap(lat, lon, zoom)
|
||||
if map_name == 'yandex': return YandexMap(lat, lon, zoom)
|
||||
return None
|
||||
|
||||
def make_global_photo(filename):
|
||||
yandexMap = YandexMap()
|
||||
yandexMap.save_photo(filename)
|
||||
yandexMap.destroy()
|
||||
def make_global_photo(filename, map_name: str = 'google', lat=49.103814, lon=55.794258, zoom=13):
|
||||
online_map: YandexMap | GoogleMap = get_map(map_name, lat, lon, zoom)
|
||||
online_map.save_photo(filename)
|
||||
online_map.destroy()
|
||||
|
||||
def get_trajectory_points(bg_img: str) -> list[(float, float)]:
|
||||
trajectoryDrawer = TrajectoryDrawer(bg_img)
|
||||
@@ -26,97 +36,132 @@ def get_trajectory_points(bg_img: str) -> list[(float, float)]:
|
||||
points = list(map(lambda p: [p[0] / trajectoryDrawer.img.shape[1], p[1] / trajectoryDrawer.img.shape[0]], trajectoryDrawer.points))
|
||||
return points
|
||||
|
||||
def main():
|
||||
# Скриншот местности
|
||||
# make_global_photo('map.jpg')
|
||||
def build(name: str, map_name: str, lat: float, lon: float):
|
||||
|
||||
# Получаем траекторию от пользователя
|
||||
# points = get_trajectory_points('map.jpg')
|
||||
# Создание папки с информацией о маршруте
|
||||
dir = Path('trajectories')
|
||||
if not dir.exists(): dir.mkdir()
|
||||
dir /= name
|
||||
assert not dir.exists()
|
||||
dir.mkdir()
|
||||
dir_chunks = dir / 'chunks'
|
||||
dir_chunks.mkdir()
|
||||
|
||||
# Trajectory #1
|
||||
# points = [[np.float64(0.5384504359393909), np.float64(0.4084520767967683)], [np.float64(0.4451750568707629), np.float64(0.38213330305374654)], [np.float64(0.49266070439660997), np.float64(0.2789637099811013)], [np.float64(0.36377108968359656), np.float64(0.3263375027185404)], [np.float64(0.3535955937852008), np.float64(0.4337180995900692)]]
|
||||
make_global_photo('map.jpg', map_name, lat, lon, 15)
|
||||
points = get_trajectory_points('map.jpg')
|
||||
online_map: YandexMap | GoogleMap = get_map(map_name, lat, lon, 15)
|
||||
|
||||
# Trajectory #2
|
||||
# points = [[np.float64(0.29197731306713737), np.float64(0.3452870198135161)], [np.float64(0.33494051797147517), np.float64(0.2010601397017569)], [np.float64(0.39768940934491587), np.float64(0.25369768718780034)], [np.float64(0.4027771572941138), np.float64(0.4158213334448144)], [np.float64(0.2914120077394487), np.float64(0.5547844588079692)]]
|
||||
|
||||
# Trajectory #3
|
||||
# points = [[np.float64(0.2755834585641664), np.float64(0.45687862048392835)], [np.float64(0.295934450360958), np.float64(0.5021469113219258)], [np.float64(0.32872215936689997), np.float64(0.4810918923275084)], [np.float64(0.3649017003389739), np.float64(0.5295184360146684)], [np.float64(0.3999506306556705), np.float64(0.49477765467387963)]]
|
||||
|
||||
# Trajectory #4
|
||||
# points = [[np.float64(0.42143223310783934), np.float64(0.6663760594783815)], [np.float64(0.4253893704016599), np.float64(0.5537317078582484)], [np.float64(0.5124463908657128), np.float64(0.5621537154560153)], [np.float64(0.5124463908657128), np.float64(0.6684815613778233)], [np.float64(0.42143223310783934), np.float64(0.6663760594783815)]]
|
||||
|
||||
# Trajectory #5
|
||||
# points = [[np.float64(0.5983728006743884), np.float64(0.7348048712102382)], [np.float64(0.5966768846913225), np.float64(0.5453097002604814)], [np.float64(0.6345523416464622), np.float64(0.7190136069644251)], [np.float64(0.6402053949233488), np.float64(0.5495207040593649)], [np.float64(0.5983728006743884), np.float64(0.7348048712102382)]]
|
||||
|
||||
# Trajectory #6
|
||||
# points = [[np.float64(0.4406526142492536), np.float64(0.28106921188054296)], [np.float64(0.38581799746345413), np.float64(0.2968604761263561)], [np.float64(0.3931669667234066), np.float64(0.353709027411283)], [np.float64(0.4248240650739713), np.float64(0.35265627646156217)], [np.float64(0.40616898926024564), np.float64(0.3179154951207735)]]
|
||||
|
||||
# Trajectory #7
|
||||
# points = [[np.float64(0.5491912371654754), np.float64(0.7505961354560512)], [np.float64(0.5537136797869846), np.float64(0.6863783275230781)], [np.float64(0.5017055896396284), np.float64(0.6653233085286606)], [np.float64(0.5520177638039186), np.float64(0.6042637534448503)], [np.float64(0.5593667330638712), np.float64(0.516885424618018)]]
|
||||
|
||||
# Trajectory #8
|
||||
# points =
|
||||
|
||||
# Trajectory #9
|
||||
# points =
|
||||
|
||||
# Trajectory #10
|
||||
# points =
|
||||
|
||||
print(points)
|
||||
|
||||
# Для каждой точки сделаем приближенный снимок
|
||||
yandexMap = YandexMap()
|
||||
chunks: list[VisionChunk] = []
|
||||
|
||||
plt.ion()
|
||||
for i in range(len(points)):
|
||||
point = points[i]
|
||||
yandexMap.scroll(point[0], point[1], 5, True)
|
||||
sleep(1)
|
||||
cv2_img = yandexMap.make_screenshot(point[0], point[1], 0.2, 0.2)
|
||||
img = cv2_to_pil(cv2_img)
|
||||
chunk = VisionChunk(img)
|
||||
Path('chunks').mkdir(exist_ok=True)
|
||||
chunk.save_image(Path('.') / 'chunks' / f'chunk_{i}.png')
|
||||
plt.subplot(1, len(points), i+1)
|
||||
plt.imshow(img)
|
||||
plt.pause(0.25)
|
||||
yandexMap.scroll(point[0], point[1], 5, False)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Выделим на каждой картинке ключевые точки
|
||||
for i in range(len(points)):
|
||||
chunk = VisionChunk.load_image(Path('chunks') / f'chunk_{i}.png')
|
||||
chunks.append(chunk)
|
||||
kp, des = chunk.compute_keypoints()
|
||||
|
||||
plt.subplot(1, len(points), i+1)
|
||||
plt.imshow(chunk.image)
|
||||
kp_coords = np.array([j.pt for j in kp])
|
||||
if len(kp_coords) > 0:
|
||||
plt.scatter(kp_coords[:, 0], kp_coords[:, 1], c='red', s=20, alpha=0.7, marker='o')
|
||||
plt.pause(0.2)
|
||||
plt.ioff()
|
||||
|
||||
plt.show(block=True)
|
||||
|
||||
# Начнём симуляцию полёта с первой точки
|
||||
yandexMap.scroll(points[0][0], points[0][1], 5, True)
|
||||
sleep(0.2)
|
||||
yandexMap.make_as_center(*points[0])
|
||||
sleep(1)
|
||||
|
||||
vis_manager = VisualizationManager()
|
||||
width, height = yandexMap.get_size()
|
||||
# print(width, height)
|
||||
width, height = online_map.get_size()
|
||||
points_coords = np.array(list(map(lambda p: [
|
||||
(p[0] - points[0][0]) * width, (points[0][1] - p[1]) * height
|
||||
], points)))
|
||||
points_coords *= 2 ** 4
|
||||
pilot = autopilot.AutoPilot(points_coords, chunks, vis_manager)
|
||||
simulator = Simulator(yandexMap)
|
||||
|
||||
points_coords *= online_map.pixel_ratio
|
||||
|
||||
# Начнём симуляцию полёта с первой точки
|
||||
online_map.make_as_center(*points[0])
|
||||
sleep(3)
|
||||
online_map.scroll(0.5, 0.5, 10, True)
|
||||
sleep(2)
|
||||
geo = online_map.get_geolocation()
|
||||
online_map.open(geo['lat'], geo['lon'], 18)
|
||||
sleep(5)
|
||||
|
||||
points_coords_pixel = points_coords.copy() / online_map.pixel_ratio
|
||||
|
||||
pilot = autopilot.AutoPilot(points_coords_pixel, pixel_ratio=online_map.pixel_ratio)
|
||||
simulator = Simulator(online_map)
|
||||
pilot.target_idx = 0
|
||||
|
||||
pilot.pos = simulator.pos.copy()
|
||||
command = pilot.make_command()
|
||||
|
||||
positions: list[Position] = []
|
||||
|
||||
print("points_coords_pixel:", points_coords_pixel)
|
||||
for i in range(5000):
|
||||
|
||||
if command.stop:
|
||||
break
|
||||
|
||||
chunk = simulator.get_chunk()
|
||||
pilot.pos = simulator.pos.copy()
|
||||
command = pilot.make_command()
|
||||
command.velocity /= online_map.pixel_ratio
|
||||
|
||||
print("Position:", simulator.pos)
|
||||
|
||||
# Save Image
|
||||
chunk.save_image(dir_chunks / f"chunk_{len(positions)}.png")
|
||||
|
||||
positions.append(simulator.pos.copy() * online_map.pixel_ratio)
|
||||
|
||||
simulator.handle(command.dangle, command.velocity)
|
||||
if i == 0 and map_name == 'google':
|
||||
simulator.pos.x = 0
|
||||
simulator.pos.y = 0
|
||||
sleep(1.5)
|
||||
|
||||
data = {
|
||||
'points': points_coords,
|
||||
'chunk_positions': positions,
|
||||
'initial_geolocation': geo
|
||||
}
|
||||
|
||||
print(points_coords)
|
||||
|
||||
file_positions = dir / 'positions.pkl'
|
||||
with file_positions.open('wb') as file:
|
||||
pickle.dump(data, file)
|
||||
|
||||
print("WRITE POINTS:", points)
|
||||
|
||||
sleep(15)
|
||||
online_map.destroy()
|
||||
|
||||
def run(name: str, map_name: str, ref_min_distance: float):
|
||||
dir = Path('trajectories')
|
||||
assert dir.exists()
|
||||
dir /= name
|
||||
assert dir.exists(), "Укажите корректное название маршрута"
|
||||
dir_chunks = dir / 'chunks'
|
||||
file_positions = dir / 'positions.pkl'
|
||||
|
||||
with file_positions.open('rb') as file:
|
||||
data = pickle.load(file)
|
||||
|
||||
initial_geolocation = data['initial_geolocation']
|
||||
|
||||
lat = initial_geolocation['lat']
|
||||
lon = initial_geolocation['lon']
|
||||
online_map: YandexMap | GoogleMap = get_map(map_name, lat, lon, 18)
|
||||
sleep(2)
|
||||
|
||||
chunks: list[VisionChunk] = []
|
||||
for i in range(len(data['chunk_positions'])):
|
||||
pos = data['chunk_positions'][i]
|
||||
if len(chunks) == 0 or np.hypot(chunks[-1].pos.x - pos.x, chunks[-1].pos.y - pos.y) > ref_min_distance:
|
||||
chunk = VisionChunk.load_image(dir_chunks / f"chunk_{i}.png")
|
||||
chunk.pos = data['chunk_positions'][i] / online_map.pixel_ratio
|
||||
chunks.append(chunk)
|
||||
|
||||
r = 0
|
||||
for i in range(len(data['points']) - 1):
|
||||
r += np.hypot(
|
||||
data['points'][i][0] - data['points'][i+1][0],
|
||||
data['points'][i][1] - data['points'][i+1][1]
|
||||
)
|
||||
print("R: ", r)
|
||||
|
||||
return
|
||||
|
||||
points = data['points'] / online_map.pixel_ratio
|
||||
print("READ POINTS:", points)
|
||||
|
||||
vis_manager = VisualizationManager()
|
||||
pilot = autopilot.AutoPilot(points, chunks, vis_manager, online_map.pixel_ratio)
|
||||
simulator = Simulator(online_map)
|
||||
pilot.target_idx = 0
|
||||
|
||||
chunk = simulator.get_chunk()
|
||||
@@ -125,38 +170,26 @@ def main():
|
||||
vis_manager.update_display()
|
||||
vis_manager.pause(1)
|
||||
|
||||
vis_manager.set_target_points(points_coords)
|
||||
vis_manager.set_target_points(data['points'])
|
||||
|
||||
proc_time = np.array([])
|
||||
|
||||
zoom_next_event = random.randint(5, 10)
|
||||
|
||||
errors = []
|
||||
chunk_errors = []
|
||||
chunk_improves = []
|
||||
|
||||
last_chunk_index = 0
|
||||
|
||||
sleep(1)
|
||||
|
||||
for i in range(10000000000):
|
||||
print(f"Image #{i}")
|
||||
if i == zoom_next_event:
|
||||
r = random.randint(0, 1)
|
||||
direction = ['up', 'down'][r]
|
||||
# simulator.change_zoom(direction)
|
||||
zoom_next_event = i + random.randint(20, 40)
|
||||
|
||||
|
||||
# if i > 0:
|
||||
# simulator.set_pitch(np.sin(i / 10) * 5)
|
||||
# simulator.set_roll(np.sin(i / 15) * 5)
|
||||
# simulator.set_zoom(1.0 + np.sin(i / 10) * 0.3)
|
||||
if i > 0:
|
||||
simulator.set_pitch(np.sin(i / 10) * 5)
|
||||
simulator.set_roll(np.sin(i / 15) * 5)
|
||||
simulator.set_zoom(1.0 + np.sin(i / 10) * 0.3)
|
||||
|
||||
if command.stop:
|
||||
break
|
||||
|
||||
# simulator.handle(command.dangle, command.velocity)
|
||||
chunk = simulator.get_chunk()
|
||||
command = pilot.handle(chunk)
|
||||
command.velocity /= online_map.pixel_ratio
|
||||
|
||||
proc_time = np.append(proc_time, command.proccessing_time)
|
||||
|
||||
@@ -167,34 +200,133 @@ def main():
|
||||
vis_manager.pause(0.2)
|
||||
|
||||
vis_manager.set_target_index(pilot.target_idx)
|
||||
vis_manager.update_drone_trajectory(pilot.pos.x, pilot.pos.y)
|
||||
vis_manager.update_global_map(simulator.pos.x, simulator.pos.y)
|
||||
vis_manager.update_error_plot(i, pilot.pos.x, pilot.pos.y, simulator.pos.x, simulator.pos.y)
|
||||
vis_manager.update_drone_trajectory(pilot.pos.x * online_map.pixel_ratio, pilot.pos.y * online_map.pixel_ratio)
|
||||
vis_manager.update_global_map(simulator.pos.x * online_map.pixel_ratio, simulator.pos.y * online_map.pixel_ratio)
|
||||
vis_manager.update_error_plot(i, pilot.pos.x * online_map.pixel_ratio, pilot.pos.y * online_map.pixel_ratio, simulator.pos.x * online_map.pixel_ratio, simulator.pos.y * online_map.pixel_ratio)
|
||||
|
||||
errors.append(np.hypot(pilot.pos.x - simulator.pos.x, pilot.pos.y - simulator.pos.y))
|
||||
if last_chunk_index != pilot.target_idx:
|
||||
last_chunk_index = pilot.target_idx
|
||||
chunk_errors.append(errors[-1])
|
||||
chunk_improves.append(errors[-1] - errors[max(len(errors) - 2, 0)])
|
||||
errors.append(np.hypot((pilot.pos.x - simulator.pos.x) * online_map.pixel_ratio, (pilot.pos.y - simulator.pos.y) * online_map.pixel_ratio))
|
||||
|
||||
vis_manager.update_display()
|
||||
vis_manager.pause(0.2)
|
||||
|
||||
last_proc_times = proc_time[-10:]
|
||||
last_proc_times = proc_time[-30:]
|
||||
print(F"\nImage #{i}")
|
||||
print("Average FPS:", 1 / last_proc_times.mean())
|
||||
print("Pilot coords:", pilot.pos)
|
||||
print("Simulator coords:", simulator.pos)
|
||||
sleep(0.5)
|
||||
simulator.handle(command.dangle, command.velocity)
|
||||
if i == 0 and map_name == 'google':
|
||||
simulator.pos.x = 0
|
||||
simulator.pos.y = 0
|
||||
|
||||
print("Errors:", errors)
|
||||
print("MSE:", (np.array(errors) ** 2).mean())
|
||||
print("RMSE:", (np.array(errors) ** 2).mean() ** 0.5)
|
||||
print("Chunk errors:", chunk_errors)
|
||||
print("Chunk error improves:", chunk_improves)
|
||||
print("Average FPS:", 1 / proc_time.mean())
|
||||
vis_manager.show_final()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
def parse_args():
|
||||
"""Парсер аргументов командной строки"""
|
||||
parser = argparse.ArgumentParser(description='Обработка траекторий')
|
||||
|
||||
#TODO
|
||||
# Добавляем обязательный аргумент --mode
|
||||
parser.add_argument(
|
||||
'--mode',
|
||||
type=str,
|
||||
required=True,
|
||||
choices=['standalone', 'build', 'run'],
|
||||
help='Режим работы: standalone, build или run'
|
||||
)
|
||||
|
||||
# Добавляем опциональный аргумент --name
|
||||
parser.add_argument(
|
||||
'--name',
|
||||
type=str,
|
||||
default=utility.generate_folder_name(),
|
||||
help='Название траектории'
|
||||
)
|
||||
|
||||
# Координаты
|
||||
parser.add_argument(
|
||||
'--lat',
|
||||
type=float,
|
||||
default=49.103814,
|
||||
help='Широта (по умолчанию 49.103814)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--lon',
|
||||
type=float,
|
||||
default=55.794258,
|
||||
help='Долгота (по умолчанию 55.794258)'
|
||||
)
|
||||
|
||||
# Источник эталонных изображений (ориентиров)
|
||||
parser.add_argument(
|
||||
'--reference',
|
||||
type=str,
|
||||
default='google',
|
||||
choices=['google', 'yandex'],
|
||||
help='Откуда берутся эталонные изображения (ориентиры): google или yandex (по умолчанию google)'
|
||||
)
|
||||
|
||||
# Место проведения симуляции
|
||||
parser.add_argument(
|
||||
'--simulation',
|
||||
type=str,
|
||||
default='yandex',
|
||||
choices=['google', 'yandex'],
|
||||
help='Где проводится симуляция: google или yandex (по умолчанию yandex)'
|
||||
)
|
||||
|
||||
# Место проведения симуляции
|
||||
parser.add_argument(
|
||||
'--ref-min-distance',
|
||||
type=float,
|
||||
default=100,
|
||||
help='Минимальное расстояние между эталонами'
|
||||
)
|
||||
|
||||
# Место проведения симуляции
|
||||
parser.add_argument(
|
||||
'--debug-fps',
|
||||
action='store_true',
|
||||
help='Включить отладку FPS'
|
||||
)
|
||||
|
||||
# Место проведения симуляции
|
||||
parser.add_argument(
|
||||
'--debug-landmark',
|
||||
action='store_true',
|
||||
help='Включить отладку эталонов'
|
||||
)
|
||||
|
||||
# Парсим аргументы
|
||||
args = parser.parse_args()
|
||||
|
||||
# Проверяем, что для build и run указан --name
|
||||
if args.mode in ['build', 'run'] and not args.name:
|
||||
parser.error(f"--name обязателен для режима {args.mode}")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
name = args.name
|
||||
mode = args.mode
|
||||
sim: str = args.simulation
|
||||
ref: str = args.reference
|
||||
lat: float = args.lat
|
||||
lon: float = args.lon
|
||||
rmd: float = args.ref_min_distance
|
||||
|
||||
constants.DEBUG_FPS = args.debug_fps
|
||||
constants.DEBUG_LANDMARK = args.debug_landmark
|
||||
|
||||
if mode == 'build' or mode == 'standalone':
|
||||
build(name, ref, lat, lon)
|
||||
|
||||
if mode == 'run' or mode == 'standalone':
|
||||
run(name, sim, rmd)
|
||||
|
||||
BIN
map.jpg
BIN
map.jpg
Binary file not shown.
|
Before Width: | Height: | Size: 3.4 MiB After Width: | Height: | Size: 3.3 MiB |
1
models/GAN/.gitignore
vendored
Normal file
1
models/GAN/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
runs
|
||||
254
models/GAN/README.md
Normal file
254
models/GAN/README.md
Normal file
@@ -0,0 +1,254 @@
|
||||
# GAN Trainer для преобразования изображений Yandex → Google
|
||||
|
||||
Этот модуль содержит реализацию тренера для GAN (Generative Adversarial Network) модели, предназначенной для преобразования изображений карт Yandex в стиль Google Maps.
|
||||
|
||||
## Структура проекта
|
||||
|
||||
```
|
||||
autopilot/models/GAN/
|
||||
├── gan.py # Основная реализация GAN модели
|
||||
├── trainer.py # Тренер для обучения GAN
|
||||
├── test_trainer.py # Тесты для тренера
|
||||
├── train_example.py # Пример использования тренера
|
||||
└── README.md # Этот файл
|
||||
```
|
||||
|
||||
## Модель GAN
|
||||
|
||||
Модель состоит из двух основных компонентов:
|
||||
|
||||
### 1. Генератор (GeneratorUNet)
|
||||
- Архитектура U-Net для преобразования изображений
|
||||
- Принимает изображение Yandex (3 канала RGB)
|
||||
- Возвращает изображение в стиле Google (3 канала RGB)
|
||||
- Использует skip connections для сохранения деталей
|
||||
|
||||
### 2. Дискриминатор (DiscriminatorPatchGAN)
|
||||
- PatchGAN архитектура
|
||||
- Принимает пару изображений (Yandex + Google)
|
||||
- Возвращает вероятность того, что пара реальная
|
||||
- Работает с патчами изображения 41x41
|
||||
|
||||
### Функция потерь (GANLoss)
|
||||
Поддерживает три режима:
|
||||
- `vanilla`: Бинарная кросс-энтропия
|
||||
- `lsgan`: Least Squares GAN (более стабильный)
|
||||
- `wgangp`: Wasserstein GAN with Gradient Penalty
|
||||
|
||||
## Тренер (GANTrainer)
|
||||
|
||||
### Основные возможности
|
||||
|
||||
1. **Обучение с чередованием**:
|
||||
- Обучение генератора и дискриминатора поочередно
|
||||
- Поддержка L1 потерь для сохранения структуры
|
||||
|
||||
2. **Валидация и мониторинг**:
|
||||
- Отдельные потери для генератора и дискриминатора
|
||||
- Логирование в TensorBoard
|
||||
- Ранняя остановка
|
||||
|
||||
3. **Сохранение и загрузка**:
|
||||
- Чекпоинты каждой эпохи
|
||||
- Лучшая модель
|
||||
- Финальная модель
|
||||
- История обучения
|
||||
|
||||
4. **Оценка модели**:
|
||||
- Метрики на тестовом наборе
|
||||
- Генерация примеров
|
||||
|
||||
### Быстрый старт
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
# Конфигурация
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"batch_size": 4,
|
||||
"output_dir": "runs/gan_training",
|
||||
"gan_mode": "vanilla",
|
||||
"lambda_L1": 100.0,
|
||||
"early_stopping_patience": 20,
|
||||
}
|
||||
|
||||
# Устройство
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Создание модели
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode=config["gan_mode"],
|
||||
lambda_L1=config["lambda_L1"],
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
# Создание даталоадеров (замените на свои данные)
|
||||
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)
|
||||
|
||||
# Создание тренера
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Обучение
|
||||
trainer.train(num_epochs=100)
|
||||
|
||||
# Оценка
|
||||
metrics = trainer.evaluate(test_loader)
|
||||
```
|
||||
|
||||
### Конфигурация обучения
|
||||
|
||||
#### Базовая конфигурация
|
||||
```python
|
||||
config = {
|
||||
# Параметры оптимизатора
|
||||
"learning_rate": 2e-4, # Learning rate
|
||||
"beta1": 0.5, # Adam beta1
|
||||
"beta2": 0.999, # Adam beta2
|
||||
|
||||
# Параметры обучения
|
||||
"batch_size": 4, # Размер батча
|
||||
"epochs": 100, # Количество эпох
|
||||
|
||||
# Параметры GAN
|
||||
"gan_mode": "vanilla", # Режим GAN
|
||||
"lambda_L1": 100.0, # Вес L1 потерь
|
||||
|
||||
# Регуляризация
|
||||
"grad_clip": 1.0, # Gradient clipping
|
||||
|
||||
# Ранняя остановка
|
||||
"early_stopping_patience": 20,
|
||||
|
||||
# Выходные данные
|
||||
"output_dir": "runs/gan",
|
||||
}
|
||||
```
|
||||
|
||||
#### Расширенная конфигурация
|
||||
```python
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"batch_size": 8,
|
||||
"epochs": 200,
|
||||
"gan_mode": "lsgan", # Более стабильный LSGAN
|
||||
"lambda_L1": 100.0,
|
||||
"grad_clip": 1.0,
|
||||
"weight_decay": 1e-4, # Weight decay
|
||||
"early_stopping_patience": 30,
|
||||
"early_stopping_min_delta": 1e-4,
|
||||
"output_dir": "runs/gan_advanced",
|
||||
}
|
||||
```
|
||||
|
||||
### Методы тренера
|
||||
|
||||
#### Основные методы
|
||||
- `train_epoch()`: Обучение на одной эпохе
|
||||
- `validate()`: Валидация модели
|
||||
- `train(num_epochs)`: Полное обучение
|
||||
- `evaluate(test_loader)`: Оценка на тестовых данных
|
||||
|
||||
#### Управление чекпоинтами
|
||||
- `save_checkpoint(is_best=False)`: Сохранение чекпоинта
|
||||
- `load_checkpoint(path, resume_training=False)`: Загрузка чекпоинта
|
||||
|
||||
### Выходные файлы
|
||||
|
||||
После обучения создаются следующие файлы:
|
||||
|
||||
```
|
||||
runs/gan_training/
|
||||
├── config.json # Конфигурация обучения
|
||||
├── training_history.json # История потерь
|
||||
├── model_best.pth # Лучшая модель
|
||||
├── model_final.pth # Финальная модель
|
||||
├── checkpoint_epoch_1.pth # Чекпоинты каждой эпохи
|
||||
├── checkpoint_epoch_2.pth
|
||||
├── ...
|
||||
└── tensorboard/ # Логи TensorBoard
|
||||
├── events.out.tfevents...
|
||||
└── ...
|
||||
```
|
||||
|
||||
### TensorBoard
|
||||
|
||||
Для визуализации обучения используйте TensorBoard:
|
||||
|
||||
```bash
|
||||
tensorboard --logdir runs/gan_training/tensorboard
|
||||
```
|
||||
|
||||
Доступные метрики:
|
||||
- `train/batch_g_loss`: Потери генератора на батче
|
||||
- `train/batch_d_loss`: Потери дискриминатора на батче
|
||||
- `train/batch_g_l1_loss`: L1 потери генератора
|
||||
- `train/epoch_g_loss`: Потери генератора на эпохе
|
||||
- `train/epoch_d_loss`: Потери дискриминатора на эпохе
|
||||
- `val/epoch_g_loss`: Валидационные потери генератора
|
||||
- `val/epoch_d_loss`: Валидационные потери дискриминатора
|
||||
|
||||
### Тестирование
|
||||
|
||||
Запустите тесты для проверки работоспособности:
|
||||
|
||||
```bash
|
||||
python models/GAN/test_trainer.py
|
||||
```
|
||||
|
||||
### Пример использования
|
||||
|
||||
Полный пример использования смотрите в `train_example.py`.
|
||||
|
||||
### Советы по обучению
|
||||
|
||||
1. **Начальные значения**:
|
||||
- Используйте `gan_mode="lsgan"` для более стабильного обучения
|
||||
- Начните с `lambda_L1=100.0` и регулируйте по необходимости
|
||||
- Используйте маленький `batch_size` (4-8) при ограниченной памяти GPU
|
||||
|
||||
2. **Мониторинг**:
|
||||
- Следите за балансом потерь генератора и дискриминатора
|
||||
- Если потери дискриминатора близки к 0, генератор не обучается
|
||||
- Если потери генератора слишком высоки, уменьшите `lambda_L1`
|
||||
|
||||
3. **Визуализация**:
|
||||
- Регулярно генерируйте примеры для визуальной оценки
|
||||
- Используйте TensorBoard для отслеживания прогресса
|
||||
|
||||
### Устранение проблем
|
||||
|
||||
#### Высокие потери генератора
|
||||
- Уменьшите `lambda_L1`
|
||||
- Увеличьте learning rate
|
||||
- Проверьте качество данных
|
||||
|
||||
#### Дискриминатор слишком сильный
|
||||
- Уменьшите learning rate дискриминатора
|
||||
- Добавьте dropout в дискриминатор
|
||||
- Обучайте генератор чаще, чем дискриминатор
|
||||
|
||||
#### Недостаток памяти GPU
|
||||
- Уменьшите `batch_size`
|
||||
- Уменьшите размер изображений
|
||||
- Используйте gradient accumulation
|
||||
|
||||
### Лицензия
|
||||
|
||||
Этот проект является частью Autopilot системы.
|
||||
1777
models/GAN/gan.ipynb
Normal file
1777
models/GAN/gan.ipynb
Normal file
File diff suppressed because one or more lines are too long
393
models/GAN/gan.py
Normal file
393
models/GAN/gan.py
Normal file
@@ -0,0 +1,393 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UNetDownBlock(nn.Module):
|
||||
"""Блок downsampling для U-Net генератора"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
normalize: bool = True,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
layers = [
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
]
|
||||
if normalize:
|
||||
layers.append(nn.BatchNorm2d(out_channels))
|
||||
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
||||
if dropout > 0:
|
||||
layers.append(nn.Dropout2d(dropout))
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class UNetUpBlock(nn.Module):
|
||||
"""Блок upsampling для U-Net генератора"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
layers = [
|
||||
nn.ConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
if dropout > 0:
|
||||
layers.append(nn.Dropout2d(dropout))
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
|
||||
x = self.model(x)
|
||||
# Обрезаем skip connection до размера x, если необходимо
|
||||
if x.shape != skip_input.shape:
|
||||
diffY = skip_input.size(2) - x.size(2)
|
||||
diffX = skip_input.size(3) - x.size(3)
|
||||
x = F.pad(
|
||||
x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
|
||||
)
|
||||
x = torch.cat([x, skip_input], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class GeneratorUNet(nn.Module):
|
||||
"""Генератор на основе U-Net архитектуры для преобразования Yandex → Google"""
|
||||
|
||||
def __init__(self, in_channels: int = 3, out_channels: int = 3):
|
||||
super().__init__()
|
||||
|
||||
# Downsampling path
|
||||
self.down1 = UNetDownBlock(in_channels, 64, normalize=False)
|
||||
self.down2 = UNetDownBlock(64, 128)
|
||||
self.down3 = UNetDownBlock(128, 256)
|
||||
self.down4 = UNetDownBlock(256, 512)
|
||||
self.down5 = UNetDownBlock(512, 512)
|
||||
self.down6 = UNetDownBlock(512, 512)
|
||||
self.down7 = UNetDownBlock(512, 512)
|
||||
|
||||
# Bottleneck
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
# Upsampling path
|
||||
self.up1 = UNetUpBlock(512, 512, dropout=0.5)
|
||||
self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||
self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
|
||||
self.up4 = UNetUpBlock(1024, 512)
|
||||
self.up5 = UNetUpBlock(1024, 256)
|
||||
self.up6 = UNetUpBlock(512, 128)
|
||||
self.up7 = UNetUpBlock(256, 64)
|
||||
|
||||
# Final layer
|
||||
self.final = nn.Sequential(
|
||||
nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Downsampling
|
||||
d1 = self.down1(x) # 350x350
|
||||
d2 = self.down2(d1) # 175x175
|
||||
d3 = self.down3(d2) # 88x88
|
||||
d4 = self.down4(d3) # 44x44
|
||||
d5 = self.down5(d4) # 22x22
|
||||
d6 = self.down6(d5) # 11x11
|
||||
d7 = self.down7(d6) # 6x6
|
||||
|
||||
# Bottleneck
|
||||
u = self.bottleneck(d7) # 3x3
|
||||
|
||||
# Upsampling with skip connections
|
||||
u = self.up1(u, d7) # 6x6
|
||||
u = self.up2(u, d6) # 11x11
|
||||
u = self.up3(u, d5) # 22x22
|
||||
u = self.up4(u, d4) # 44x44
|
||||
u = self.up5(u, d3) # 88x88
|
||||
u = self.up6(u, d2) # 175x175
|
||||
u = self.up7(u, d1) # 350x350
|
||||
|
||||
# Final output
|
||||
return self.final(u) # 700x700
|
||||
|
||||
|
||||
class DiscriminatorPatchGAN(nn.Module):
|
||||
"""Дискриминатор PatchGAN для изображений 700x700"""
|
||||
|
||||
def __init__(
|
||||
self, in_channels: int = 6
|
||||
): # 3 для реального + 3 для сгенерированного
|
||||
super().__init__()
|
||||
|
||||
def discriminator_block(
|
||||
in_filters: int, out_filters: int, normalization: bool = True
|
||||
):
|
||||
"""Блок дискриминатора"""
|
||||
layers = [
|
||||
nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
|
||||
]
|
||||
if normalization:
|
||||
layers.append(nn.BatchNorm2d(out_filters))
|
||||
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
||||
return layers
|
||||
|
||||
self.model = nn.Sequential(
|
||||
*discriminator_block(in_channels, 64, normalization=False), # 350x350
|
||||
*discriminator_block(64, 128), # 175x175
|
||||
*discriminator_block(128, 256), # 88x88
|
||||
*discriminator_block(256, 512), # 44x44
|
||||
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1), # 41x41
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, img_A: torch.Tensor, img_B: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Принимает пару изображений (реальное и сгенерированное)
|
||||
и возвращает вероятность того, что пара реальная
|
||||
"""
|
||||
# Объединяем два изображения по каналам
|
||||
img_input = torch.cat((img_A, img_B), 1)
|
||||
return self.model(img_input)
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
|
||||
"""Функция потерь для GAN"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gan_mode: str = "vanilla",
|
||||
target_real_label: float = 1.0,
|
||||
target_fake_label: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_buffer("real_label", torch.tensor(target_real_label))
|
||||
self.register_buffer("fake_label", torch.tensor(target_fake_label))
|
||||
self.gan_mode = gan_mode
|
||||
|
||||
if gan_mode == "vanilla":
|
||||
self.loss = nn.BCEWithLogitsLoss()
|
||||
elif gan_mode == "lsgan":
|
||||
self.loss = nn.MSELoss()
|
||||
elif gan_mode == "wgangp":
|
||||
self.loss = None
|
||||
else:
|
||||
raise NotImplementedError(f"GAN mode {gan_mode} not implemented")
|
||||
|
||||
def get_target_tensor(
|
||||
self, prediction: torch.Tensor, target_is_real: bool
|
||||
) -> torch.Tensor:
|
||||
"""Создает тензор меток"""
|
||||
if target_is_real:
|
||||
target_tensor = self.real_label
|
||||
else:
|
||||
target_tensor = self.fake_label
|
||||
return target_tensor.expand_as(prediction)
|
||||
|
||||
def __call__(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
|
||||
"""Вычисляет потери"""
|
||||
if self.gan_mode in ["vanilla", "lsgan"]:
|
||||
target_tensor = self.get_target_tensor(prediction, target_is_real)
|
||||
loss = self.loss(prediction, target_tensor)
|
||||
elif self.gan_mode == "wgangp":
|
||||
if target_is_real:
|
||||
loss = -prediction.mean()
|
||||
else:
|
||||
loss = prediction.mean()
|
||||
return loss
|
||||
|
||||
|
||||
class ImageGAN(nn.Module):
|
||||
"""Основной класс GAN для преобразования изображений Yandex → Google"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 3,
|
||||
output_channels: int = 3,
|
||||
gan_mode: str = "vanilla",
|
||||
lambda_L1: float = 100.0,
|
||||
use_cuda: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.generator = GeneratorUNet(input_channels, output_channels)
|
||||
self.discriminator = DiscriminatorPatchGAN(input_channels + output_channels)
|
||||
self.gan_loss = GANLoss(gan_mode)
|
||||
self.l1_loss = nn.L1Loss()
|
||||
self.lambda_L1 = lambda_L1
|
||||
|
||||
self.device = torch.device(
|
||||
"cuda" if use_cuda and torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, yandex_image: torch.Tensor) -> torch.Tensor:
|
||||
"""Генерация изображения Google из Yandex"""
|
||||
return self.generator(yandex_image)
|
||||
|
||||
def generator_step(
|
||||
self, yandex_image: torch.Tensor, real_google_image: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Шаг обучения генератора
|
||||
|
||||
Returns:
|
||||
total_loss: общие потери генератора
|
||||
gan_loss: потери GAN
|
||||
l1_loss: потери L1
|
||||
"""
|
||||
# Генерируем изображение
|
||||
fake_google_image = self.generator(yandex_image)
|
||||
|
||||
# Оцениваем дискриминатором
|
||||
fake_pred = self.discriminator(yandex_image, fake_google_image)
|
||||
|
||||
# Потери GAN (пытаемся обмануть дискриминатор)
|
||||
gan_loss = self.gan_loss(fake_pred, True)
|
||||
|
||||
# Потери L1 для сохранения структуры
|
||||
l1_loss = self.l1_loss(fake_google_image, real_google_image) * self.lambda_L1
|
||||
|
||||
# Общие потери
|
||||
total_loss = gan_loss + l1_loss
|
||||
|
||||
return total_loss, gan_loss, l1_loss
|
||||
|
||||
def discriminator_step(
|
||||
self,
|
||||
yandex_image: torch.Tensor,
|
||||
real_google_image: torch.Tensor,
|
||||
fake_google_image: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Шаг обучения дискриминатора
|
||||
|
||||
Returns:
|
||||
total_loss: общие потери дискриминатора
|
||||
real_loss: потери на реальных изображениях
|
||||
fake_loss: потери на сгенерированных изображениях
|
||||
"""
|
||||
# Предсказания для реальных пар
|
||||
real_pred = self.discriminator(yandex_image, real_google_image)
|
||||
real_loss = self.gan_loss(real_pred, True)
|
||||
|
||||
# Предсказания для сгенерированных пар
|
||||
fake_pred = self.discriminator(yandex_image, fake_google_image.detach())
|
||||
fake_loss = self.gan_loss(fake_pred, False)
|
||||
|
||||
# Общие потери дискриминатора
|
||||
total_loss = (real_loss + fake_loss) * 0.5
|
||||
|
||||
return total_loss, real_loss, fake_loss
|
||||
|
||||
def to(self, device):
|
||||
"""Перемещает модель на устройство"""
|
||||
self.generator.to(device)
|
||||
self.discriminator.to(device)
|
||||
return self
|
||||
|
||||
def train_mode(self):
|
||||
"""Переключает модель в режим обучения"""
|
||||
self.generator.train()
|
||||
self.discriminator.train()
|
||||
|
||||
def eval_mode(self):
|
||||
"""Переключает модель в режим оценки"""
|
||||
self.generator.eval()
|
||||
self.discriminator.eval()
|
||||
|
||||
def save_checkpoint(self, path: str):
|
||||
"""Сохраняет чекпоинт модели"""
|
||||
checkpoint = {
|
||||
"generator_state_dict": self.generator.state_dict(),
|
||||
"discriminator_state_dict": self.discriminator.state_dict(),
|
||||
"generator_optimizer_state_dict": getattr(
|
||||
self.generator, "optimizer_state_dict", None
|
||||
),
|
||||
"discriminator_optimizer_state_dict": getattr(
|
||||
self.discriminator, "optimizer_state_dict", None
|
||||
),
|
||||
}
|
||||
torch.save(checkpoint, path)
|
||||
|
||||
def load_checkpoint(self, path: str):
|
||||
"""Загружает чекпоинт модели"""
|
||||
checkpoint = torch.load(path, map_location=self.device)
|
||||
self.generator.load_state_dict(checkpoint["generator_state_dict"])
|
||||
self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
|
||||
|
||||
if checkpoint["generator_optimizer_state_dict"] is not None:
|
||||
self.generator.optimizer_state_dict = checkpoint[
|
||||
"generator_optimizer_state_dict"
|
||||
]
|
||||
if checkpoint["discriminator_optimizer_state_dict"] is not None:
|
||||
self.discriminator.optimizer_state_dict = checkpoint[
|
||||
"discriminator_optimizer_state_dict"
|
||||
]
|
||||
|
||||
|
||||
def create_image_gan(
|
||||
input_channels: int = 3,
|
||||
output_channels: int = 3,
|
||||
gan_mode: str = "vanilla",
|
||||
lambda_L1: float = 100.0,
|
||||
use_cuda: bool = True,
|
||||
) -> ImageGAN:
|
||||
"""
|
||||
Создает и возвращает модель GAN для преобразования изображений
|
||||
|
||||
Args:
|
||||
input_channels: количество входных каналов (обычно 3 для RGB)
|
||||
output_channels: количество выходных каналов (обычно 3 для RGB)
|
||||
gan_mode: режим GAN ('vanilla', 'lsgan', 'wgangp')
|
||||
lambda_L1: вес L1 потерь
|
||||
use_cuda: использовать ли CUDA если доступно
|
||||
|
||||
Returns:
|
||||
ImageGAN: модель GAN
|
||||
"""
|
||||
return ImageGAN(
|
||||
input_channels=input_channels,
|
||||
output_channels=output_channels,
|
||||
gan_mode=gan_mode,
|
||||
lambda_L1=lambda_L1,
|
||||
use_cuda=use_cuda,
|
||||
)
|
||||
|
||||
|
||||
# Вспомогательные функции для инициализации весов
|
||||
def weights_init_normal(m):
|
||||
"""Инициализация весов с нормальным распределением"""
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.batch_norm.bias.data, 0.0)
|
||||
|
||||
|
||||
def initialize_gan_weights(generator: nn.Module, discriminator: nn.Module):
|
||||
"""Инициализирует веса генератора и дискриминатора"""
|
||||
generator.apply(weights_init_normal)
|
||||
discriminator.apply(weights_init_normal)
|
||||
136
models/GAN/minimal_example.py
Normal file
136
models/GAN/minimal_example.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Минимальный пример использования GAN trainer для преобразования Yandex → Google карт.
|
||||
|
||||
Этот пример показывает самый простой способ использования тренера.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
class SimpleMapDataset(Dataset):
|
||||
"""Простой датасет с фиктивными данными для примера."""
|
||||
|
||||
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||
self.num_samples = num_samples
|
||||
self.image_size = image_size
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Создаем фиктивные изображения
|
||||
# В реальном коде замените на загрузку реальных изображений
|
||||
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
|
||||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция минимального примера."""
|
||||
print("Минимальный пример использования GAN trainer")
|
||||
print("=" * 50)
|
||||
|
||||
# 1. Конфигурация (минимальный набор параметров)
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"batch_size": 4,
|
||||
"output_dir": "runs/gan_minimal",
|
||||
}
|
||||
|
||||
# 2. Устройство (CPU или GPU)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Используемое устройство: {device}")
|
||||
|
||||
# 3. Создание модели
|
||||
print("\nСоздание GAN модели...")
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode="vanilla", # Простейший режим
|
||||
lambda_L1=100.0, # Стандартный вес L1 потерь
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
# 4. Создание даталоадеров
|
||||
print("Создание даталоадеров...")
|
||||
train_dataset = SimpleMapDataset(num_samples=50)
|
||||
val_dataset = SimpleMapDataset(num_samples=10)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
print(f" Обучающих примеров: {len(train_dataset)}")
|
||||
print(f" Валидационных примеров: {len(val_dataset)}")
|
||||
|
||||
# 5. Создание тренера
|
||||
print("\nСоздание тренера...")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 6. Обучение на небольшом количестве эпох
|
||||
print("\nЗапуск обучения (3 эпохи для примера)...")
|
||||
print("=" * 50)
|
||||
|
||||
trainer.train(num_epochs=3)
|
||||
|
||||
# 7. Генерация примеров
|
||||
print("\nГенерация примеров преобразования...")
|
||||
model.eval()
|
||||
|
||||
# Создаем тестовые данные
|
||||
test_yandex = torch.randn(2, 3, 256, 256).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_google = model(test_yandex)
|
||||
|
||||
print(f"Входные изображения: {test_yandex.shape}")
|
||||
print(f"Сгенерированные изображения: {generated_google.shape}")
|
||||
print(
|
||||
f"Диапазон значений: [{generated_google.min():.3f}, {generated_google.max():.3f}]"
|
||||
)
|
||||
|
||||
# 8. Сохранение финальной модели
|
||||
print("\nСохранение модели...")
|
||||
model_save_path = "gan_model_minimal.pth"
|
||||
torch.save(model.state_dict(), model_save_path)
|
||||
print(f"Модель сохранена в: {model_save_path}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Минимальный пример завершен!")
|
||||
print("\nДля реального использования:")
|
||||
print("1. Замените SimpleMapDataset на ваш реальный датасет")
|
||||
print("2. Настройте параметры в config")
|
||||
print("3. Увеличьте количество эпох (например, до 100)")
|
||||
print("4. Используйте реальные изображения карт")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
349
models/GAN/test_gan.py
Normal file
349
models/GAN/test_gan.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Добавляем путь к модулю
|
||||
sys.path.append(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
|
||||
from gan import (
|
||||
DiscriminatorPatchGAN,
|
||||
GeneratorUNet,
|
||||
ImageGAN,
|
||||
create_image_gan,
|
||||
initialize_gan_weights,
|
||||
)
|
||||
|
||||
|
||||
def test_generator():
|
||||
"""Тестирование генератора"""
|
||||
print("=" * 60)
|
||||
print("Тестирование генератора...")
|
||||
print("=" * 60)
|
||||
|
||||
# Создаем генератор
|
||||
generator = GeneratorUNet(in_channels=3, out_channels=3)
|
||||
|
||||
# Инициализируем веса
|
||||
generator.apply(
|
||||
lambda m: (
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d)
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
# Создаем тестовый входной тензор (Yandex изображение)
|
||||
batch_size = 2
|
||||
height, width = 700, 700
|
||||
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||
|
||||
print(f"Размер входного изображения: {yandex_image.shape}")
|
||||
print(
|
||||
f"Количество параметров генератора: {sum(p.numel() for p in generator.parameters()):,}"
|
||||
)
|
||||
|
||||
# Прямой проход
|
||||
with torch.no_grad():
|
||||
generated_image = generator(yandex_image)
|
||||
|
||||
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||||
print(
|
||||
f"Диапазон значений сгенерированного изображения: [{generated_image.min():.3f}, {generated_image.max():.3f}]"
|
||||
)
|
||||
|
||||
# Проверка размеров
|
||||
assert generated_image.shape == (batch_size, 3, height, width), (
|
||||
f"Ожидался размер {(batch_size, 3, height, width)}, получен {generated_image.shape}"
|
||||
)
|
||||
|
||||
print("✓ Генератор работает корректно!")
|
||||
return generator
|
||||
|
||||
|
||||
def test_discriminator():
|
||||
"""Тестирование дискриминатора"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Тестирование дискриминатора...")
|
||||
print("=" * 60)
|
||||
|
||||
# Создаем дискриминатор
|
||||
discriminator = DiscriminatorPatchGAN(in_channels=6) # 3 + 3 канала
|
||||
|
||||
# Инициализируем веса
|
||||
discriminator.apply(
|
||||
lambda m: (
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
if isinstance(m, nn.Conv2d)
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
# Создаем тестовые тензоры
|
||||
batch_size = 2
|
||||
height, width = 700, 700
|
||||
|
||||
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||
google_image = torch.randn(batch_size, 3, height, width)
|
||||
|
||||
print(f"Размер Yandex изображения: {yandex_image.shape}")
|
||||
print(f"Размер Google изображения: {google_image.shape}")
|
||||
print(
|
||||
f"Количество параметров дискриминатора: {sum(p.numel() for p in discriminator.parameters()):,}"
|
||||
)
|
||||
|
||||
# Прямой проход
|
||||
with torch.no_grad():
|
||||
prediction = discriminator(yandex_image, google_image)
|
||||
|
||||
print(f"Размер выхода дискриминатора: {prediction.shape}")
|
||||
print(
|
||||
f"Диапазон значений предсказания: [{prediction.min():.3f}, {prediction.max():.3f}]"
|
||||
)
|
||||
|
||||
# Проверка размеров (PatchGAN выдает карту вероятностей)
|
||||
expected_height = 43 # Для изображения 700x700 после 4 downsampling блоков
|
||||
expected_width = 43
|
||||
assert prediction.shape == (batch_size, 1, expected_height, expected_width), (
|
||||
f"Ожидался размер {(batch_size, 1, expected_height, expected_width)}, получен {prediction.shape}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"✓ Дискриминатор работает корректно! Выходной размер: {prediction.shape[2]}x{prediction.shape[3]}"
|
||||
)
|
||||
return discriminator
|
||||
|
||||
|
||||
def test_gan_model():
|
||||
"""Тестирование полной GAN модели"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Тестирование полной GAN модели...")
|
||||
print("=" * 60)
|
||||
|
||||
# Создаем GAN модель
|
||||
gan = ImageGAN(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode="vanilla",
|
||||
lambda_L1=100.0,
|
||||
use_cuda=False, # Тестируем на CPU для простоты
|
||||
)
|
||||
|
||||
print(f"Устройство модели: {gan.device}")
|
||||
print(
|
||||
f"Количество параметров генератора: {sum(p.numel() for p in gan.generator.parameters()):,}"
|
||||
)
|
||||
print(
|
||||
f"Количество параметров дискриминатора: {sum(p.numel() for p in gan.discriminator.parameters()):,}"
|
||||
)
|
||||
print(f"Общее количество параметров: {sum(p.numel() for p in gan.parameters()):,}")
|
||||
|
||||
# Создаем тестовые данные
|
||||
batch_size = 2
|
||||
height, width = 700, 700
|
||||
|
||||
yandex_image = torch.randn(batch_size, 3, height, width)
|
||||
real_google_image = torch.randn(batch_size, 3, height, width)
|
||||
|
||||
print(f"\nТестирование прямого прохода...")
|
||||
with torch.no_grad():
|
||||
generated_image = gan(yandex_image)
|
||||
|
||||
print(f"Размер сгенерированного изображения: {generated_image.shape}")
|
||||
|
||||
print(f"\nТестирование шага генератора...")
|
||||
gan.train_mode()
|
||||
|
||||
# Тестируем шаг генератора
|
||||
total_loss, gan_loss, l1_loss = gan.generator_step(yandex_image, real_google_image)
|
||||
|
||||
print(f"Общие потери генератора: {total_loss.item():.6f}")
|
||||
print(f"Потери GAN: {gan_loss.item():.6f}")
|
||||
print(f"Потери L1: {l1_loss.item():.6f}")
|
||||
|
||||
print(f"\nТестирование шага дискриминатора...")
|
||||
# Создаем сгенерированное изображение для дискриминатора
|
||||
with torch.no_grad():
|
||||
fake_google_image = gan.generator(yandex_image)
|
||||
|
||||
total_d_loss, real_loss, fake_loss = gan.discriminator_step(
|
||||
yandex_image, real_google_image, fake_google_image
|
||||
)
|
||||
|
||||
print(f"Общие потери дискриминатора: {total_d_loss.item():.6f}")
|
||||
print(f"Потери на реальных изображениях: {real_loss.item():.6f}")
|
||||
print(f"Потери на сгенерированных изображениях: {fake_loss.item():.6f}")
|
||||
|
||||
print(f"\nТестирование режимов обучения/оценки...")
|
||||
gan.eval_mode()
|
||||
print(f"Генератор в режиме eval: {not gan.generator.training}")
|
||||
print(f"Дискриминатор в режиме eval: {not gan.discriminator.training}")
|
||||
|
||||
gan.train_mode()
|
||||
print(f"Генератор в режиме train: {gan.generator.training}")
|
||||
print(f"Дискриминатор в режиме train: {gan.discriminator.training}")
|
||||
|
||||
print("\n✓ Полная GAN модель работает корректно!")
|
||||
return gan
|
||||
|
||||
|
||||
def test_factory_function():
|
||||
"""Тестирование фабричной функции"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Тестирование фабричной функции...")
|
||||
print("=" * 60)
|
||||
|
||||
# Тестируем разные режимы GAN
|
||||
for gan_mode in ["vanilla", "lsgan"]:
|
||||
print(f"\nСоздание GAN в режиме '{gan_mode}'...")
|
||||
gan = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode=gan_mode,
|
||||
lambda_L1=100.0,
|
||||
use_cuda=False,
|
||||
)
|
||||
|
||||
print(f" Режим GAN: {gan.gan_loss.gan_mode}")
|
||||
print(f" Вес L1 потерь: {gan.lambda_L1}")
|
||||
print(f" Устройство: {gan.device}")
|
||||
|
||||
# Быстрая проверка прямого прохода
|
||||
batch_size = 1
|
||||
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||||
|
||||
with torch.no_grad():
|
||||
generated = gan(yandex_image)
|
||||
|
||||
print(f" Размер выхода: {generated.shape}")
|
||||
print(f" ✓ GAN в режиме '{gan_mode}' создан успешно")
|
||||
|
||||
print("\n✓ Фабричная функция работает корректно!")
|
||||
|
||||
|
||||
def test_weights_initialization():
|
||||
"""Тестирование инициализации весов"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Тестирование инициализации весов...")
|
||||
print("=" * 60)
|
||||
|
||||
# Создаем модели
|
||||
generator = GeneratorUNet(3, 3)
|
||||
discriminator = DiscriminatorPatchGAN(6)
|
||||
|
||||
# Инициализируем веса
|
||||
initialize_gan_weights(generator, discriminator)
|
||||
|
||||
# Проверяем средние значения весов
|
||||
def check_weights_mean(model, model_name):
|
||||
conv_weights = []
|
||||
for name, param in model.named_parameters():
|
||||
if "weight" in name and (
|
||||
"conv" in name.lower() or "Conv" in str(param.__class__)
|
||||
):
|
||||
conv_weights.append(param.data.mean().item())
|
||||
|
||||
if conv_weights:
|
||||
avg_mean = sum(conv_weights) / len(conv_weights)
|
||||
print(f" Среднее значение весов Conv слоев в {model_name}: {avg_mean:.6f}")
|
||||
# Проверяем, что веса инициализированы около 0
|
||||
assert abs(avg_mean) < 0.1, f"Веса {model_name} не инициализированы около 0"
|
||||
|
||||
check_weights_mean(generator, "генераторе")
|
||||
check_weights_mean(discriminator, "дискриминаторе")
|
||||
|
||||
print("✓ Инициализация весов работает корректно!")
|
||||
|
||||
|
||||
def test_memory_usage():
|
||||
"""Тестирование использования памяти"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Тестирование использования памяти...")
|
||||
print("=" * 60)
|
||||
|
||||
import os
|
||||
|
||||
import psutil
|
||||
|
||||
# Получаем текущее использование памяти
|
||||
process = psutil.Process(os.getpid())
|
||||
memory_before = process.memory_info().rss / 1024 / 1024 # в MB
|
||||
|
||||
print(f"Память до создания моделей: {memory_before:.2f} MB")
|
||||
|
||||
# Создаем несколько моделей
|
||||
models = []
|
||||
for i in range(3):
|
||||
gan = create_image_gan(use_cuda=False)
|
||||
models.append(gan)
|
||||
|
||||
# Делаем тестовый проход
|
||||
batch_size = 1
|
||||
yandex_image = torch.randn(batch_size, 3, 700, 700)
|
||||
real_google_image = torch.randn(batch_size, 3, 700, 700)
|
||||
|
||||
with torch.no_grad():
|
||||
_ = gan(yandex_image)
|
||||
_ = gan.generator_step(yandex_image, real_google_image)
|
||||
|
||||
memory_after = process.memory_info().rss / 1024 / 1024 # в MB
|
||||
memory_used = memory_after - memory_before
|
||||
|
||||
print(f"Память после создания моделей: {memory_after:.2f} MB")
|
||||
print(f"Использовано памяти: {memory_used:.2f} MB")
|
||||
|
||||
# Очищаем модели
|
||||
del models
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
memory_final = process.memory_info().rss / 1024 / 1024
|
||||
print(f"Память после очистки: {memory_final:.2f} MB")
|
||||
|
||||
print("✓ Тестирование памяти завершено!")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция тестирования"""
|
||||
print("Начало тестирования GAN архитектуры для преобразования Yandex → Google")
|
||||
print("Размер изображения: 700x700 пикселей")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Запускаем все тесты
|
||||
test_generator()
|
||||
test_discriminator()
|
||||
test_gan_model()
|
||||
test_factory_function()
|
||||
test_weights_initialization()
|
||||
test_memory_usage()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО! 🎉")
|
||||
print("=" * 60)
|
||||
print("\nАрхитектура GAN готова к использованию для преобразования")
|
||||
print("изображений из стиля Yandex в стиль Google.")
|
||||
print("\nОсновные характеристики:")
|
||||
print(" • Генератор: U-Net архитектура")
|
||||
print(" • Дискриминатор: PatchGAN (43x43 патчей)")
|
||||
print(" • Размер входных/выходных изображений: 700x700")
|
||||
print(" • Поддержка режимов: vanilla, lsgan")
|
||||
print(" • L1 регуляризация для сохранения структуры")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Ошибка при тестировании: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
||||
342
models/GAN/test_trainer.py
Normal file
342
models/GAN/test_trainer.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Тестовый скрипт для проверки GAN trainer.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
class SimpleDataset(Dataset):
|
||||
"""Простой датасет для тестирования."""
|
||||
|
||||
def __init__(self, num_samples=100, image_size=(256, 256)):
|
||||
self.num_samples = num_samples
|
||||
self.image_size = image_size
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Создаем случайные изображения для тестирования
|
||||
yandex_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
google_img = torch.randn(3, self.image_size[0], self.image_size[1])
|
||||
|
||||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||
|
||||
|
||||
def test_gan_model():
|
||||
"""Тестирование GAN модели."""
|
||||
print("Тестирование GAN модели...")
|
||||
|
||||
# Создаем модель
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode="vanilla",
|
||||
lambda_L1=100.0,
|
||||
use_cuda=False, # Используем CPU для тестирования
|
||||
)
|
||||
|
||||
# Тестируем forward pass
|
||||
batch_size = 2
|
||||
image_size = (256, 256)
|
||||
yandex_input = torch.randn(batch_size, 3, *image_size)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(yandex_input)
|
||||
|
||||
print(f"Входной размер: {yandex_input.shape}")
|
||||
print(f"Выходной размер: {output.shape}")
|
||||
print(f"Диапазон выходных значений: [{output.min():.3f}, {output.max():.3f}]")
|
||||
|
||||
# Проверяем, что выход в диапазоне [-1, 1] (из-за Tanh)
|
||||
assert output.min() >= -1.0 and output.max() <= 1.0, "Выход не в диапазоне [-1, 1]"
|
||||
print("✓ Forward pass работает корректно")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def test_generator_step():
|
||||
"""Тестирование шага генератора."""
|
||||
print("\nТестирование шага генератора...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
model.train_mode()
|
||||
|
||||
batch_size = 2
|
||||
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||
|
||||
# Тестируем generator_step
|
||||
total_loss, gan_loss, l1_loss = model.generator_step(yandex_img, google_img)
|
||||
|
||||
print(f"Total loss: {total_loss.item():.6f}")
|
||||
print(f"GAN loss: {gan_loss.item():.6f}")
|
||||
print(f"L1 loss: {l1_loss.item():.6f}")
|
||||
|
||||
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||
print("✓ Шаг генератора работает корректно")
|
||||
|
||||
|
||||
def test_discriminator_step():
|
||||
"""Тестирование шага дискриминатора."""
|
||||
print("\nТестирование шага дискриминатора...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
model.train_mode()
|
||||
|
||||
batch_size = 2
|
||||
yandex_img = torch.randn(batch_size, 3, 256, 256)
|
||||
google_img = torch.randn(batch_size, 3, 256, 256)
|
||||
|
||||
# Генерируем fake изображение
|
||||
with torch.no_grad():
|
||||
fake_google_img = model.generator(yandex_img)
|
||||
|
||||
# Тестируем discriminator_step
|
||||
total_loss, real_loss, fake_loss = model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
|
||||
print(f"Total loss: {total_loss.item():.6f}")
|
||||
print(f"Real loss: {real_loss.item():.6f}")
|
||||
print(f"Fake loss: {fake_loss.item():.6f}")
|
||||
|
||||
assert total_loss.item() > 0, "Потери должны быть положительными"
|
||||
print("✓ Шаг дискриминатора работает корректно")
|
||||
|
||||
|
||||
def test_trainer_initialization():
|
||||
"""Тестирование инициализации тренера."""
|
||||
print("\nТестирование инициализации тренера...")
|
||||
|
||||
# Создаем модель
|
||||
model = create_image_gan(use_cuda=False)
|
||||
|
||||
# Создаем даталоадеры
|
||||
train_dataset = SimpleDataset(num_samples=50)
|
||||
val_dataset = SimpleDataset(num_samples=10)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
|
||||
|
||||
# Конфигурация
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan",
|
||||
"early_stopping_patience": 10,
|
||||
}
|
||||
|
||||
# Создаем тренер
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
print(f"Тренер создан успешно")
|
||||
print(f"Оптимизатор генератора: {type(trainer.optimizer_G).__name__}")
|
||||
print(f"Оптимизатор дискриминатора: {type(trainer.optimizer_D).__name__}")
|
||||
print(f"Выходная директория: {trainer.output_dir}")
|
||||
|
||||
assert trainer.output_dir.exists(), "Выходная директория не создана"
|
||||
print("✓ Тренер инициализирован корректно")
|
||||
|
||||
return trainer, train_loader, val_loader
|
||||
|
||||
|
||||
def test_train_epoch():
|
||||
"""Тестирование одной эпохи обучения."""
|
||||
print("\nТестирование одной эпохи обучения...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
train_dataset = SimpleDataset(num_samples=20)
|
||||
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
|
||||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan",
|
||||
}
|
||||
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Запускаем одну эпоху обучения
|
||||
avg_g_loss, avg_d_loss = trainer.train_epoch()
|
||||
|
||||
print(f"Средние потери за эпоху:")
|
||||
print(f" Генератор: {avg_g_loss:.6f}")
|
||||
print(f" Дискриминатор: {avg_d_loss:.6f}")
|
||||
|
||||
assert avg_g_loss > 0, "Потери генератора должны быть положительными"
|
||||
assert avg_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||
print("✓ Эпоха обучения завершена успешно")
|
||||
|
||||
|
||||
def test_validation():
|
||||
"""Тестирование валидации."""
|
||||
print("\nТестирование валидации...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan",
|
||||
}
|
||||
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Запускаем валидацию
|
||||
val_g_loss, val_d_loss = trainer.validate()
|
||||
|
||||
print(f"Потери на валидации:")
|
||||
print(f" Генератор: {val_g_loss:.6f}")
|
||||
print(f" Дискриминатор: {val_d_loss:.6f}")
|
||||
|
||||
assert val_g_loss > 0, "Потери генератора должны быть положительными"
|
||||
assert val_d_loss > 0, "Потери дискриминатора должны быть положительными"
|
||||
print("✓ Валидация завершена успешно")
|
||||
|
||||
|
||||
def test_checkpoint_saving():
|
||||
"""Тестирование сохранения чекпоинтов."""
|
||||
print("\nТестирование сохранения чекпоинтов...")
|
||||
|
||||
model = create_image_gan(use_cuda=False)
|
||||
train_loader = DataLoader(SimpleDataset(num_samples=10), batch_size=2, shuffle=True)
|
||||
val_loader = DataLoader(SimpleDataset(num_samples=5), batch_size=2, shuffle=False)
|
||||
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"output_dir": "test_runs/gan_checkpoint",
|
||||
}
|
||||
|
||||
device = torch.device("cpu")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Сохраняем чекпоинт
|
||||
trainer.save_checkpoint(is_best=True)
|
||||
|
||||
# Проверяем, что файлы созданы
|
||||
checkpoint_files = list(trainer.output_dir.glob("*.pth"))
|
||||
print(f"Создано файлов чекпоинтов: {len(checkpoint_files)}")
|
||||
|
||||
for file in checkpoint_files:
|
||||
print(f" - {file.name}")
|
||||
|
||||
assert len(checkpoint_files) > 0, "Файлы чекпоинтов не созданы"
|
||||
print("✓ Чекпоинты сохранены успешно")
|
||||
|
||||
# Тестируем загрузку чекпоинта
|
||||
checkpoint_path = checkpoint_files[0]
|
||||
print(f"\nТестируем загрузку чекпоинта: {checkpoint_path}")
|
||||
|
||||
# Создаем новую модель и тренер
|
||||
new_model = create_image_gan(use_cuda=False)
|
||||
new_trainer = GANTrainer(
|
||||
model=new_model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Загружаем чекпоинт
|
||||
new_trainer.load_checkpoint(str(checkpoint_path))
|
||||
|
||||
print(f"Загружен чекпоинт эпохи: {new_trainer.current_epoch + 1}")
|
||||
print("✓ Чекпоинт загружен успешно")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция тестирования."""
|
||||
print("=" * 60)
|
||||
print("Начало тестирования GAN trainer")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Запускаем все тесты
|
||||
test_gan_model()
|
||||
test_generator_step()
|
||||
test_discriminator_step()
|
||||
test_trainer_initialization()
|
||||
test_train_epoch()
|
||||
test_validation()
|
||||
test_checkpoint_saving()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Все тесты пройдены успешно! ✓")
|
||||
print("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nОшибка при тестировании: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Очищаем тестовые директории
|
||||
import shutil
|
||||
|
||||
test_dirs = ["test_runs/gan", "test_runs/gan_checkpoint"]
|
||||
for dir_path in test_dirs:
|
||||
if Path(dir_path).exists():
|
||||
shutil.rmtree(dir_path)
|
||||
|
||||
# Запускаем тесты
|
||||
exit_code = main()
|
||||
|
||||
# Очищаем после тестов
|
||||
for dir_path in test_dirs:
|
||||
if Path(dir_path).exists():
|
||||
shutil.rmtree(dir_path)
|
||||
|
||||
exit(exit_code)
|
||||
347
models/GAN/train_example.py
Normal file
347
models/GAN/train_example.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Пример обучения GAN модели для преобразования Yandex → Google карт.
|
||||
|
||||
Этот скрипт показывает, как использовать GANTrainer для обучения модели.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Добавляем путь к модулям
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from models.GAN.gan import create_image_gan
|
||||
from models.GAN.trainer import GANTrainer
|
||||
|
||||
|
||||
def create_simple_config():
|
||||
"""Создает простую конфигурацию для обучения."""
|
||||
config = {
|
||||
# Параметры оптимизатора
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
# Параметры обучения
|
||||
"batch_size": 4,
|
||||
"epochs": 100,
|
||||
# Параметры GAN
|
||||
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
|
||||
"lambda_L1": 100.0, # Вес L1 потерь
|
||||
# Регуляризация
|
||||
"grad_clip": 1.0,
|
||||
# Ранняя остановка
|
||||
"early_stopping_patience": 20,
|
||||
# Выходные данные
|
||||
"output_dir": "runs/gan_training",
|
||||
# Логирование
|
||||
"log_interval": 10, # Логировать каждые N батчей
|
||||
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_advanced_config():
|
||||
"""Создает расширенную конфигурацию для обучения."""
|
||||
config = {
|
||||
# Параметры оптимизатора
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
# Планировщик learning rate
|
||||
"use_scheduler": True,
|
||||
"scheduler_type": "linear", # "linear", "cosine", или "plateau"
|
||||
"scheduler_start_epoch": 50,
|
||||
"scheduler_end_epoch": 100,
|
||||
# Параметры обучения
|
||||
"batch_size": 8,
|
||||
"epochs": 200,
|
||||
# Параметры GAN
|
||||
"gan_mode": "lsgan", # LSGAN обычно более стабилен
|
||||
"lambda_L1": 100.0,
|
||||
# Аугментация данных
|
||||
"augmentation": {
|
||||
"random_crop": True,
|
||||
"crop_size": 256,
|
||||
"random_flip": True,
|
||||
"color_jitter": True,
|
||||
"brightness": 0.2,
|
||||
"contrast": 0.2,
|
||||
"saturation": 0.2,
|
||||
"hue": 0.1,
|
||||
},
|
||||
# Регуляризация
|
||||
"grad_clip": 1.0,
|
||||
"weight_decay": 1e-4,
|
||||
# Ранняя остановка
|
||||
"early_stopping_patience": 30,
|
||||
"early_stopping_min_delta": 1e-4,
|
||||
# Выходные данные
|
||||
"output_dir": "runs/gan_advanced",
|
||||
# Логирование
|
||||
"log_interval": 20,
|
||||
"save_interval": 10,
|
||||
"save_best_only": True, # Сохранять только лучшую модель
|
||||
# Визуализация
|
||||
"visualize_samples": True,
|
||||
"num_visualize": 4,
|
||||
"visualize_interval": 5, # Визуализировать каждые N эпох
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def print_config_summary(config):
|
||||
"""Печатает сводку конфигурации."""
|
||||
print("=" * 60)
|
||||
print("Конфигурация обучения GAN")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nПараметры модели:")
|
||||
print(f" Режим GAN: {config.get('gan_mode', 'vanilla')}")
|
||||
print(f" Вес L1 потерь: {config.get('lambda_L1', 100.0)}")
|
||||
|
||||
print(f"\nПараметры обучения:")
|
||||
print(f" Learning rate: {config.get('learning_rate', 2e-4)}")
|
||||
print(f" Batch size: {config.get('batch_size', 4)}")
|
||||
print(f" Эпох: {config.get('epochs', 100)}")
|
||||
print(f" Beta1: {config.get('beta1', 0.5)}")
|
||||
print(f" Beta2: {config.get('beta2', 0.999)}")
|
||||
|
||||
if config.get("use_scheduler", False):
|
||||
print(f" Планировщик LR: {config.get('scheduler_type', 'linear')}")
|
||||
|
||||
print(f"\nРегуляризация:")
|
||||
print(f" Gradient clipping: {config.get('grad_clip', 1.0)}")
|
||||
if "weight_decay" in config:
|
||||
print(f" Weight decay: {config['weight_decay']}")
|
||||
|
||||
print(f"\nРанняя остановка:")
|
||||
if config.get("early_stopping_patience", 0) > 0:
|
||||
print(f" Patience: {config['early_stopping_patience']} эпох")
|
||||
if "early_stopping_min_delta" in config:
|
||||
print(f" Min delta: {config['early_stopping_min_delta']}")
|
||||
|
||||
print(f"\nВыходные данные:")
|
||||
print(f" Директория: {config.get('output_dir', 'runs/gan')}")
|
||||
print(f" Интервал сохранения: {config.get('save_interval', 5)} эпох")
|
||||
|
||||
print(f"\nЛогирование:")
|
||||
print(f" Интервал логирования: {config.get('log_interval', 10)} батчей")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def setup_training():
|
||||
"""Настраивает обучение."""
|
||||
print("Настройка обучения GAN...")
|
||||
|
||||
# Выбираем конфигурацию
|
||||
use_advanced = False # Измените на True для расширенной конфигурации
|
||||
|
||||
if use_advanced:
|
||||
config = create_advanced_config()
|
||||
else:
|
||||
config = create_simple_config()
|
||||
|
||||
# Печатаем сводку конфигурации
|
||||
print_config_summary(config)
|
||||
|
||||
# Устройство
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"\nИспользуемое устройство: {device}")
|
||||
|
||||
if device.type == "cuda":
|
||||
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||||
print(
|
||||
f" Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
|
||||
)
|
||||
|
||||
# Создаем модель
|
||||
print("\nСоздание модели...")
|
||||
model = create_image_gan(
|
||||
input_channels=3,
|
||||
output_channels=3,
|
||||
gan_mode=config.get("gan_mode", "vanilla"),
|
||||
lambda_L1=config.get("lambda_L1", 100.0),
|
||||
use_cuda=(device.type == "cuda"),
|
||||
)
|
||||
|
||||
# Создаем даталоадеры
|
||||
print("\nСоздание даталоадеров...")
|
||||
# ЗАМЕНИТЕ ЭТО НА ВАШИ РЕАЛЬНЫЕ ДАННЫЕ
|
||||
# Пример:
|
||||
# from your_dataset_module import create_data_loaders
|
||||
# train_loader, val_loader = create_data_loaders(
|
||||
# data_dir="ваш/путь/к/данным",
|
||||
# batch_size=config["batch_size"],
|
||||
# image_size=(256, 256),
|
||||
# augment=config.get("augmentation", None),
|
||||
# )
|
||||
|
||||
# Для примера создаем фиктивные даталоадеры
|
||||
# ВАЖНО: Замените это на реальные данные!
|
||||
print(" ВНИМАНИЕ: Используются фиктивные данные!")
|
||||
print(" Замените на реальные даталоадеры!")
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
def __init__(self, num_samples=100):
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Фиктивные данные для примера
|
||||
yandex_img = torch.randn(3, 256, 256)
|
||||
google_img = torch.randn(3, 256, 256)
|
||||
return {"yandex_img": yandex_img, "google_img": google_img}
|
||||
|
||||
train_dataset = DummyDataset(num_samples=100)
|
||||
val_dataset = DummyDataset(num_samples=20)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.get("batch_size", 4),
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config.get("batch_size", 4),
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
print(f" Размер обучающего набора: {len(train_dataset)}")
|
||||
print(f" Размер валидационного набора: {len(val_dataset)}")
|
||||
print(f" Батчей в эпохе: {len(train_loader)}")
|
||||
|
||||
# Создаем тренер
|
||||
print("\nСоздание тренера...")
|
||||
trainer = GANTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
return trainer, config
|
||||
|
||||
|
||||
def train_model(trainer, config):
|
||||
"""Запускает обучение модели."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Начало обучения")
|
||||
print("=" * 60)
|
||||
|
||||
epochs = config.get("epochs", 100)
|
||||
|
||||
try:
|
||||
trainer.train(num_epochs=epochs)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Обучение завершено успешно!")
|
||||
print("=" * 60)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nОбучение прервано пользователем.")
|
||||
print("Сохранение текущего состояния...")
|
||||
trainer.save_checkpoint(is_best=False)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\nОшибка при обучении: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Пытаемся сохранить чекпоинт при ошибке
|
||||
try:
|
||||
trainer.save_checkpoint(is_best=False)
|
||||
print("Текущее состояние сохранено.")
|
||||
except:
|
||||
print("Не удалось сохранить состояние.")
|
||||
|
||||
|
||||
def evaluate_model(trainer, test_loader=None):
|
||||
"""Оценивает обученную модель."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Оценка модели")
|
||||
print("=" * 60)
|
||||
|
||||
if test_loader is None:
|
||||
print("Тестовый даталоадер не предоставлен.")
|
||||
print("Используется валидационный даталоадер для оценки.")
|
||||
test_loader = trainer.val_loader
|
||||
|
||||
metrics = trainer.evaluate(test_loader)
|
||||
|
||||
print("\nМетрики оценки:")
|
||||
for key, value in metrics.items():
|
||||
print(f" {key}: {value:.6f}")
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def generate_examples(model, device, num_examples=4):
|
||||
"""Генерирует примеры преобразования."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Генерация примеров")
|
||||
print("=" * 60)
|
||||
|
||||
model.eval()
|
||||
|
||||
# Создаем фиктивные входные данные
|
||||
yandex_input = torch.randn(num_examples, 3, 256, 256).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
google_output = model(yandex_input)
|
||||
|
||||
print(f"Сгенерировано {num_examples} примеров")
|
||||
print(f"Размер входных данных: {yandex_input.shape}")
|
||||
print(f"Размер выходных данных: {google_output.shape}")
|
||||
|
||||
# Сохраняем примеры (в реальном коде сохраняйте как изображения)
|
||||
print("\nПримеры сгенерированы.")
|
||||
print("В реальном коде сохраняйте их как изображения для визуализации.")
|
||||
|
||||
return yandex_input, google_output
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция."""
|
||||
print("=" * 60)
|
||||
print("Пример обучения GAN для преобразования Yandex → Google")
|
||||
print("=" * 60)
|
||||
|
||||
# Настройка
|
||||
trainer, config = setup_training()
|
||||
|
||||
# Обучение
|
||||
train_model(trainer, config)
|
||||
|
||||
# Оценка (требует реальных тестовых данных)
|
||||
# evaluate_model(trainer)
|
||||
|
||||
# Генерация примеров
|
||||
# generate_examples(trainer.model, trainer.device)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Скрипт завершен.")
|
||||
print("=" * 60)
|
||||
print("\nСледующие шаги:")
|
||||
print("1. Замените фиктивные даталоадеры на реальные данные")
|
||||
print("2. Настройте параметры в create_simple_config()")
|
||||
print("3. Запустите обучение с реальными данными")
|
||||
print("4. Визуализируйте результаты")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
415
models/GAN/trainer.py
Normal file
415
models/GAN/trainer.py
Normal file
@@ -0,0 +1,415 @@
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
# Type aliases
|
||||
ModuleType = nn.Module
|
||||
TensorType = torch.Tensor
|
||||
|
||||
|
||||
class GANTrainer:
|
||||
"""Trainer class for GAN model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: ModuleType,
|
||||
train_loader: DataLoader,
|
||||
val_loader: DataLoader,
|
||||
device: torch.device,
|
||||
config: Dict[str, Any],
|
||||
):
|
||||
"""
|
||||
Initialize the GAN trainer.
|
||||
|
||||
Args:
|
||||
model: GAN model (ImageGAN)
|
||||
train_loader: Training data loader
|
||||
val_loader: Validation data loader
|
||||
device: Device to run training on
|
||||
config: Training configuration dictionary
|
||||
"""
|
||||
self.model = model.to(device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
# Optimizers
|
||||
lr = config.get("learning_rate", 2e-4)
|
||||
beta1 = config.get("beta1", 0.5)
|
||||
beta2 = config.get("beta2", 0.999)
|
||||
|
||||
# Separate optimizers for generator and discriminator
|
||||
# Note: self.model is expected to have .generator and .discriminator attributes
|
||||
self.optimizer_G = optim.Adam(
|
||||
self.model.generator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||
)
|
||||
self.optimizer_D = optim.Adam(
|
||||
self.model.discriminator.parameters(), lr=lr, betas=(beta1, beta2)
|
||||
)
|
||||
|
||||
# Training state
|
||||
self.current_epoch = 0
|
||||
self.best_val_loss = float("inf")
|
||||
self.g_losses: List[float] = []
|
||||
self.d_losses: List[float] = []
|
||||
self.val_g_losses: List[float] = []
|
||||
self.val_d_losses: List[float] = []
|
||||
|
||||
# Create output directory
|
||||
self.output_dir = Path(config.get("output_dir", "runs/gan"))
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TensorBoard writer
|
||||
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||
|
||||
# Save configuration
|
||||
config_path = self.output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Training configuration saved to {config_path}")
|
||||
# Access parameters through the model's generator and discriminator
|
||||
generator_params = sum(p.numel() for p in self.model.generator.parameters())
|
||||
discriminator_params = sum(
|
||||
p.numel() for p in self.model.discriminator.parameters()
|
||||
)
|
||||
|
||||
print(f"Generator has {generator_params:,} parameters")
|
||||
print(f"Discriminator has {discriminator_params:,} parameters")
|
||||
|
||||
def train_epoch(self) -> Tuple[float, float]:
|
||||
"""
|
||||
Train for one epoch.
|
||||
|
||||
Returns:
|
||||
Tuple of (average generator loss, average discriminator loss)
|
||||
"""
|
||||
self.model.train()
|
||||
total_g_loss = 0.0
|
||||
total_d_loss = 0.0
|
||||
num_batches = len(self.train_loader)
|
||||
|
||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||
for batch_idx, batch in enumerate(progress_bar):
|
||||
# Move data to device
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
|
||||
# ========== Train Discriminator ==========
|
||||
self.optimizer_D.zero_grad()
|
||||
|
||||
# Generate fake image
|
||||
with torch.no_grad():
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Discriminator loss - returns tuple of tensors
|
||||
d_loss_tuple = self.model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||
|
||||
# Backward and optimize discriminator
|
||||
d_loss.backward()
|
||||
self.optimizer_D.step()
|
||||
|
||||
# ========== Train Generator ==========
|
||||
self.optimizer_G.zero_grad()
|
||||
|
||||
# Generate fake image
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Generator loss - returns tuple of tensors
|
||||
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||
|
||||
# Backward and optimize generator
|
||||
g_loss.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
# Update statistics
|
||||
total_g_loss += g_loss.item()
|
||||
total_d_loss += d_loss.item()
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix(
|
||||
{
|
||||
"g_loss": g_loss.item(),
|
||||
"d_loss": d_loss.item(),
|
||||
"g_l1": g_l1_loss.item(),
|
||||
"d_real": d_real_loss.item(),
|
||||
"d_fake": d_fake_loss.item(),
|
||||
}
|
||||
)
|
||||
|
||||
# Log batch losses to TensorBoard
|
||||
global_step = self.current_epoch * num_batches + batch_idx
|
||||
self.writer.add_scalar("train/batch_g_loss", g_loss.item(), global_step)
|
||||
self.writer.add_scalar("train/batch_d_loss", d_loss.item(), global_step)
|
||||
self.writer.add_scalar(
|
||||
"train/batch_g_l1_loss", g_l1_loss.item(), global_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/batch_d_real_loss", d_real_loss.item(), global_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/batch_d_fake_loss", d_fake_loss.item(), global_step
|
||||
)
|
||||
|
||||
avg_g_loss = total_g_loss / num_batches
|
||||
avg_d_loss = total_d_loss / num_batches
|
||||
self.g_losses.append(avg_g_loss)
|
||||
self.d_losses.append(avg_d_loss)
|
||||
|
||||
return avg_g_loss, avg_d_loss
|
||||
|
||||
def validate(self) -> Tuple[float, float]:
|
||||
"""
|
||||
Validate the model.
|
||||
|
||||
Returns:
|
||||
Tuple of (average generator validation loss, average discriminator validation loss)
|
||||
"""
|
||||
self.model.eval()
|
||||
total_g_loss = 0.0
|
||||
total_d_loss = 0.0
|
||||
|
||||
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||
for batch in progress_bar:
|
||||
# Move data to device
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Generate fake image
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Generator loss - returns tuple of tensors
|
||||
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||
|
||||
# Discriminator loss - returns tuple of tensors
|
||||
d_loss_tuple = self.model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||
|
||||
# Update statistics
|
||||
total_g_loss += g_loss.item()
|
||||
total_d_loss += d_loss.item()
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({"g_loss": g_loss.item(), "d_loss": d_loss.item()})
|
||||
|
||||
avg_g_loss = total_g_loss / len(self.val_loader)
|
||||
avg_d_loss = total_d_loss / len(self.val_loader)
|
||||
self.val_g_losses.append(avg_g_loss)
|
||||
self.val_d_losses.append(avg_d_loss)
|
||||
|
||||
return avg_g_loss, avg_d_loss
|
||||
|
||||
def save_checkpoint(self, is_best: bool = False):
|
||||
"""
|
||||
Save training checkpoint.
|
||||
|
||||
Args:
|
||||
is_best: Whether this is the best model so far
|
||||
"""
|
||||
checkpoint = {
|
||||
"epoch": self.current_epoch,
|
||||
"model_state_dict": self.model.state_dict(),
|
||||
"optimizer_G_state_dict": self.optimizer_G.state_dict(),
|
||||
"optimizer_D_state_dict": self.optimizer_D.state_dict(),
|
||||
"g_losses": self.g_losses,
|
||||
"d_losses": self.d_losses,
|
||||
"val_g_losses": self.val_g_losses,
|
||||
"val_d_losses": self.val_d_losses,
|
||||
"best_val_loss": self.best_val_loss,
|
||||
"config": self.config,
|
||||
}
|
||||
|
||||
# Save regular checkpoint
|
||||
checkpoint_path = (
|
||||
self.output_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
|
||||
)
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
|
||||
# Save best model
|
||||
if is_best:
|
||||
best_path = self.output_dir / "model_best.pth"
|
||||
torch.save(checkpoint, best_path)
|
||||
print(f"Best model saved to {best_path}")
|
||||
|
||||
def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False):
|
||||
"""
|
||||
Load training checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_path: Path to checkpoint file
|
||||
resume_training: Если True, продолжить обучение с сохраненной эпохи
|
||||
"""
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
|
||||
self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])
|
||||
|
||||
self.current_epoch = checkpoint["epoch"]
|
||||
self.g_losses = checkpoint["g_losses"]
|
||||
self.d_losses = checkpoint["d_losses"]
|
||||
self.val_g_losses = checkpoint["val_g_losses"]
|
||||
self.val_d_losses = checkpoint["val_d_losses"]
|
||||
self.best_val_loss = checkpoint["best_val_loss"]
|
||||
|
||||
if resume_training:
|
||||
print(f"Resuming training from epoch {self.current_epoch + 1}")
|
||||
else:
|
||||
print(f"Loaded checkpoint from epoch {self.current_epoch + 1}")
|
||||
|
||||
def train(self, num_epochs: int, start_epoch: int = 0):
|
||||
"""
|
||||
Train the model for specified number of epochs.
|
||||
|
||||
Args:
|
||||
num_epochs: Number of epochs to train
|
||||
start_epoch: Starting epoch (useful when resuming training)
|
||||
"""
|
||||
print(
|
||||
f"Starting GAN training for {num_epochs} epochs from epoch {start_epoch + 1}..."
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + num_epochs):
|
||||
self.current_epoch = epoch
|
||||
|
||||
# Train for one epoch
|
||||
train_g_loss, train_d_loss = self.train_epoch()
|
||||
|
||||
# Validate
|
||||
val_g_loss, val_d_loss = self.validate()
|
||||
|
||||
# Log to TensorBoard
|
||||
self.writer.add_scalar("train/epoch_g_loss", train_g_loss, epoch)
|
||||
self.writer.add_scalar("train/epoch_d_loss", train_d_loss, epoch)
|
||||
self.writer.add_scalar("val/epoch_g_loss", val_g_loss, epoch)
|
||||
self.writer.add_scalar("val/epoch_d_loss", val_d_loss, epoch)
|
||||
|
||||
# Print epoch summary
|
||||
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||
print(" Generator:")
|
||||
print(f" Train Loss: {train_g_loss:.6f}")
|
||||
print(f" Val Loss: {val_g_loss:.6f}")
|
||||
print(" Discriminator:")
|
||||
print(f" Train Loss: {train_d_loss:.6f}")
|
||||
print(f" Val Loss: {val_d_loss:.6f}")
|
||||
|
||||
# Save checkpoint
|
||||
val_total_loss = val_g_loss + val_d_loss
|
||||
is_best = val_total_loss < self.best_val_loss
|
||||
if is_best:
|
||||
self.best_val_loss = val_total_loss
|
||||
|
||||
self.save_checkpoint(is_best=is_best)
|
||||
|
||||
# Early stopping
|
||||
if self.config.get("early_stopping_patience", 0) > 0:
|
||||
val_losses = [
|
||||
g + d for g, d in zip(self.val_g_losses, self.val_d_losses)
|
||||
]
|
||||
if (
|
||||
epoch - np.argmin(val_losses)
|
||||
>= self.config["early_stopping_patience"]
|
||||
):
|
||||
print(f"Early stopping at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
# Training completed
|
||||
training_time = time.time() - start_time
|
||||
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||
print(f"Best validation total loss: {self.best_val_loss:.6f}")
|
||||
|
||||
# Save final model
|
||||
final_model_path = self.output_dir / "model_final.pth"
|
||||
torch.save(self.model.state_dict(), final_model_path)
|
||||
print(f"Final model saved to {final_model_path}")
|
||||
|
||||
# Save training history
|
||||
history_path = self.output_dir / "training_history.json"
|
||||
history = {
|
||||
"g_losses": self.g_losses,
|
||||
"d_losses": self.d_losses,
|
||||
"val_g_losses": self.val_g_losses,
|
||||
"val_d_losses": self.val_d_losses,
|
||||
"best_val_loss": self.best_val_loss,
|
||||
"total_epochs": self.current_epoch + 1,
|
||||
}
|
||||
with open(history_path, "w") as f:
|
||||
json.dump(history, f, indent=2)
|
||||
print(f"Training history saved to {history_path}")
|
||||
|
||||
# Close TensorBoard writer
|
||||
self.writer.close()
|
||||
|
||||
def evaluate(self, test_loader: DataLoader) -> Dict:
|
||||
"""
|
||||
Evaluate the model on test data.
|
||||
|
||||
Args:
|
||||
test_loader: Test data loader
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation metrics
|
||||
"""
|
||||
self.model.eval()
|
||||
total_g_loss = 0.0
|
||||
total_d_loss = 0.0
|
||||
|
||||
print("Evaluating model on test set...")
|
||||
|
||||
for batch in tqdm(test_loader, desc="Evaluation"):
|
||||
# Move data to device
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Generate fake image
|
||||
fake_google_img = self.model.generator(yandex_img)
|
||||
|
||||
# Generator loss - returns tuple of tensors
|
||||
g_loss_tuple = self.model.generator_step(yandex_img, google_img)
|
||||
g_loss, g_gan_loss, g_l1_loss = g_loss_tuple
|
||||
|
||||
# Discriminator loss - returns tuple of tensors
|
||||
d_loss_tuple = self.model.discriminator_step(
|
||||
yandex_img, google_img, fake_google_img
|
||||
)
|
||||
d_loss, d_real_loss, d_fake_loss = d_loss_tuple
|
||||
|
||||
# Update statistics
|
||||
total_g_loss += g_loss.item()
|
||||
total_d_loss += d_loss.item()
|
||||
|
||||
avg_g_loss = total_g_loss / len(test_loader)
|
||||
avg_d_loss = total_d_loss / len(test_loader)
|
||||
|
||||
metrics = {
|
||||
"generator_loss": avg_g_loss,
|
||||
"discriminator_loss": avg_d_loss,
|
||||
"total_loss": avg_g_loss + avg_d_loss,
|
||||
}
|
||||
|
||||
print("\nTest Results:")
|
||||
print(f" Generator Loss: {avg_g_loss:.6f}")
|
||||
print(f" Discriminator Loss: {avg_d_loss:.6f}")
|
||||
print(f" Total Loss: {avg_g_loss + avg_d_loss:.6f}")
|
||||
|
||||
return metrics
|
||||
1
models/SiaN/.gitignore
vendored
Normal file
1
models/SiaN/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
runs
|
||||
295
models/SiaN/README_homography.md
Normal file
295
models/SiaN/README_homography.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# Homography Estimation System
|
||||
|
||||
This system provides a complete pipeline for estimating homography matrices between Google and Yandex map images using deep learning.
|
||||
|
||||
## Overview
|
||||
|
||||
Homography estimation is crucial for aligning images from different sources (Google Maps and Yandex Maps in this case). The system includes:
|
||||
|
||||
1. **Dataset handling** - Loading and preprocessing image pairs
|
||||
2. **Data augmentation** - Homography-based augmentation for robust training
|
||||
3. **CNN model** - Deep learning model for homography estimation
|
||||
4. **Training pipeline** - Complete training and evaluation workflow
|
||||
5. **Inference tools** - Tools for using trained models on new data
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
- Python 3.8+
|
||||
- PyTorch 1.9+
|
||||
- OpenCV
|
||||
- PIL/Pillow
|
||||
- NumPy
|
||||
|
||||
### Install dependencies
|
||||
```bash
|
||||
pip install torch torchvision opencv-python pillow numpy matplotlib tqdm tensorboard
|
||||
```
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
The system expects image pairs in the following format:
|
||||
```
|
||||
dataset/
|
||||
├── 0000_google.png
|
||||
├── 0000_yandex.png
|
||||
├── 0001_google.png
|
||||
├── 0001_yandex.png
|
||||
└── ...
|
||||
```
|
||||
|
||||
Each pair consists of:
|
||||
- `{idx:04d}_google.png` - Google map image
|
||||
- `{idx:04d}_yandex.png` - Yandex map image
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Explore the dataset
|
||||
```python
|
||||
from models.homography import HomographyDataset
|
||||
|
||||
dataset = HomographyDataset(
|
||||
root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
augment=True,
|
||||
image_size=(256, 256)
|
||||
)
|
||||
|
||||
print(f"Found {len(dataset)} image pairs")
|
||||
sample = dataset[0]
|
||||
print(f"Sample homography matrix:\n{sample['homography'].numpy()}")
|
||||
```
|
||||
|
||||
### 2. Train a model
|
||||
```bash
|
||||
python models/train_homography.py \
|
||||
--data_dir "C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images" \
|
||||
--epochs 50 \
|
||||
--batch_size 16 \
|
||||
--lr 1e-3 \
|
||||
--output_dir "runs/my_experiment"
|
||||
```
|
||||
|
||||
### 3. Perform inference
|
||||
```bash
|
||||
python models/infer_homography.py \
|
||||
--model_path "runs/my_experiment/checkpoint_best.pth" \
|
||||
--mode single \
|
||||
--google_path "path/to/google.png" \
|
||||
--yandex_path "path/to/yandex.png" \
|
||||
--output_vis "alignment_result.png"
|
||||
```
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
models/
|
||||
├── homography.py # Dataset class and data loaders
|
||||
├── homography_cnn.py # CNN model architecture
|
||||
├── train_homography.py # Training script
|
||||
├── infer_homography.py # Inference script
|
||||
├── example_homography.py # Example usage
|
||||
├── homography.ipynb # Jupyter notebook (empty)
|
||||
└── README_homography.md # This file
|
||||
```
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The homography estimation model (`HomographyCNN`) consists of:
|
||||
|
||||
1. **Dual encoders** - Separate feature extraction for Google and Yandex images
|
||||
2. **Residual blocks** - For deep feature learning
|
||||
3. **Fusion layers** - Combine features from both images
|
||||
4. **Regression head** - Predict 3x3 homography matrix
|
||||
|
||||
### Key Features:
|
||||
- Residual connections for stable training
|
||||
- Batch normalization options
|
||||
- Dropout for regularization
|
||||
- Geometric consistency loss
|
||||
|
||||
## Training Configuration
|
||||
|
||||
Default training parameters:
|
||||
- **Optimizer**: Adam with learning rate 1e-3
|
||||
- **Loss function**: Combined matrix + geometric + regularization loss
|
||||
- **Batch size**: 32
|
||||
- **Image size**: 256x256
|
||||
- **Train/val split**: 80/20
|
||||
|
||||
## Inference Modes
|
||||
|
||||
The inference script supports three modes:
|
||||
|
||||
### 1. Single image pair
|
||||
```bash
|
||||
python infer_homography.py --mode single \
|
||||
--google_path google.png --yandex_path yandex.png
|
||||
```
|
||||
|
||||
### 2. Dataset evaluation
|
||||
```bash
|
||||
python infer_homography.py --mode dataset \
|
||||
--dataset_dir path/to/dataset --num_samples 100
|
||||
```
|
||||
|
||||
### 3. Batch processing
|
||||
```bash
|
||||
python infer_homography.py --mode batch \
|
||||
--input_dir path/to/input --output_dir path/to/output
|
||||
```
|
||||
|
||||
## Evaluation Metrics
|
||||
|
||||
The system computes several metrics:
|
||||
- **Matrix MSE**: Mean squared error of homography matrix elements
|
||||
- **Corner error**: Average pixel error at image corners
|
||||
- **Geometric consistency**: Warping error across grid points
|
||||
|
||||
## Data Augmentation
|
||||
|
||||
The dataset applies homography-based augmentation:
|
||||
- Random rotation (-30° to 30°)
|
||||
- Random scaling (0.8x to 1.2x)
|
||||
- Random translation (-50 to 50 pixels)
|
||||
- Small perspective distortion
|
||||
|
||||
## Integration with Autopilot System
|
||||
|
||||
The homography estimation can be integrated into the autopilot system:
|
||||
|
||||
```python
|
||||
from models.infer_homography import HomographyInference
|
||||
|
||||
# Initialize inference
|
||||
inference = HomographyInference(model_path="path/to/model.pth")
|
||||
|
||||
# During flight loop
|
||||
homography = inference.predict(google_img, yandex_img)
|
||||
# Use homography to update drone position
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Out of memory**
|
||||
- Reduce batch size
|
||||
- Use smaller image size
|
||||
- Enable gradient checkpointing
|
||||
|
||||
2. **Poor convergence**
|
||||
- Adjust learning rate
|
||||
- Increase model capacity
|
||||
- Add more data augmentation
|
||||
|
||||
3. **Inference errors**
|
||||
- Check image formats (must be RGB)
|
||||
- Verify model was trained with same image size
|
||||
- Ensure proper normalization
|
||||
|
||||
### Debugging Tips
|
||||
|
||||
- Use `example_homography.py` to test components
|
||||
- Enable TensorBoard for training visualization
|
||||
- Check homography matrix normalization (last element should be ~1)
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Model Architecture
|
||||
```python
|
||||
from models.homography_cnn import create_homography_model
|
||||
|
||||
model = create_homography_model(
|
||||
model_type="cnn",
|
||||
input_size=(512, 512),
|
||||
hidden_channels=128,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.5
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Loss Function
|
||||
```python
|
||||
from models.homography_cnn import HomographyLoss
|
||||
|
||||
loss_fn = HomographyLoss(
|
||||
matrix_weight=0.7,
|
||||
geometric_weight=0.3,
|
||||
reg_weight=0.05,
|
||||
grid_size=16
|
||||
)
|
||||
```
|
||||
|
||||
### Transfer Learning
|
||||
```python
|
||||
# Load pretrained model
|
||||
checkpoint = torch.load("pretrained.pth")
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
# Fine-tune on new data
|
||||
for param in model.google_encoder.parameters():
|
||||
param.requires_grad = False # Freeze encoder
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Training speed**
|
||||
- Use multiple GPU workers for data loading
|
||||
- Enable mixed precision training
|
||||
- Use gradient accumulation for larger effective batch sizes
|
||||
|
||||
2. **Memory efficiency**
|
||||
- Use gradient checkpointing
|
||||
- Implement progressive resizing
|
||||
- Use memory-efficient optimizers
|
||||
|
||||
3. **Inference speed**
|
||||
- Use TensorRT or ONNX for deployment
|
||||
- Implement model quantization
|
||||
- Use batch inference when possible
|
||||
|
||||
## Future Improvements
|
||||
|
||||
1. **Model enhancements**
|
||||
- Transformer-based architecture
|
||||
- Multi-scale feature fusion
|
||||
- Uncertainty estimation
|
||||
|
||||
2. **Training improvements**
|
||||
- Self-supervised pre-training
|
||||
- Curriculum learning
|
||||
- Adversarial training
|
||||
|
||||
3. **Deployment features**
|
||||
- Real-time inference optimization
|
||||
- Mobile deployment
|
||||
- Web API
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this system in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@software{homography_estimation_2024,
|
||||
title = {Homography Estimation for Map Alignment},
|
||||
author = {Autopilot Team},
|
||||
year = {2024},
|
||||
url = {https://github.com/your-repo/homography-estimation}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the LICENSE file for details.
|
||||
|
||||
## Support
|
||||
|
||||
For questions and issues:
|
||||
1. Check the troubleshooting section
|
||||
2. Review the example scripts
|
||||
3. Open an issue on GitHub
|
||||
4. Contact the development team
|
||||
|
||||
---
|
||||
|
||||
*Last updated: October 2024*
|
||||
345
models/SiaN/example_homography.py
Normal file
345
models/SiaN/example_homography.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Example script demonstrating the complete homography estimation workflow.
|
||||
|
||||
This script shows how to:
|
||||
1. Load the ya_go_maps dataset
|
||||
2. Create and train a homography estimation model
|
||||
3. Perform inference on new image pairs
|
||||
4. Visualize results
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from models.homography import HomographyDataset, create_data_loaders
|
||||
from models.homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from models.infer_homography import HomographyInference
|
||||
from models.train_homography import HomographyTrainer
|
||||
|
||||
|
||||
def example_dataset_loading():
|
||||
"""Example 1: Loading and exploring the dataset."""
|
||||
print("=" * 60)
|
||||
print("Example 1: Loading and exploring the dataset")
|
||||
print("=" * 60)
|
||||
|
||||
# Path to the dataset
|
||||
dataset_path = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images"
|
||||
|
||||
# Create dataset
|
||||
dataset = HomographyDataset(
|
||||
root_dir=dataset_path,
|
||||
augment=True,
|
||||
image_size=(256, 256),
|
||||
cache_homographies=True,
|
||||
)
|
||||
|
||||
print(f"Dataset size: {len(dataset)} image pairs")
|
||||
|
||||
# Get a sample
|
||||
sample = dataset[0]
|
||||
print(f"\nSample keys: {list(sample.keys())}")
|
||||
print(f"Google image shape: {sample['google_img'].shape}")
|
||||
print(f"Yandex image shape: {sample['yandex_img'].shape}")
|
||||
print(f"Homography shape: {sample['homography'].shape}")
|
||||
|
||||
# Show sample homography matrix
|
||||
print(f"\nSample homography matrix:")
|
||||
print(sample["homography"].numpy())
|
||||
|
||||
# Get sample without augmentation for visualization
|
||||
raw_sample = dataset.get_sample_without_augmentation(0)
|
||||
|
||||
# Visualize sample
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
|
||||
|
||||
axes[0].imshow(raw_sample["google_img"])
|
||||
axes[0].set_title("Google Map")
|
||||
axes[0].axis("off")
|
||||
|
||||
axes[1].imshow(raw_sample["yandex_img"])
|
||||
axes[1].set_title("Yandex Map")
|
||||
axes[1].axis("off")
|
||||
|
||||
plt.suptitle("Sample Image Pair (without augmentation)")
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def example_model_creation():
|
||||
"""Example 2: Creating and testing the model."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 2: Creating and testing the model")
|
||||
print("=" * 60)
|
||||
|
||||
# Set device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Create model
|
||||
model = HomographyCNN(
|
||||
input_channels=3,
|
||||
hidden_channels=64,
|
||||
num_blocks=4,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
).to(device)
|
||||
|
||||
print(
|
||||
f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||
)
|
||||
|
||||
# Create dummy input
|
||||
batch_size = 2
|
||||
height, width = 256, 256
|
||||
|
||||
google_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||
yandex_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||
|
||||
# Test forward pass
|
||||
print("\nTesting forward pass...")
|
||||
output = model(google_img, yandex_img, return_matrix=True)
|
||||
print(f"Output shape: {output.shape}") # Should be (2, 3, 3)
|
||||
print(f"Sample output matrix:")
|
||||
print(output[0].cpu().detach().numpy())
|
||||
|
||||
# Test loss function
|
||||
print("\nTesting loss function...")
|
||||
target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
|
||||
loss_fn = HomographyLoss(
|
||||
matrix_weight=1.0,
|
||||
geometric_weight=0.5,
|
||||
reg_weight=0.1,
|
||||
grid_size=8,
|
||||
).to(device)
|
||||
|
||||
loss = loss_fn(output, target_homography, google_img, yandex_img)
|
||||
print(f"Loss value: {loss.item():.6f}")
|
||||
|
||||
# Test metrics
|
||||
print("\nTesting metrics...")
|
||||
metrics = loss_fn.compute_metrics(output, target_homography)
|
||||
for key, value in metrics.items():
|
||||
print(f"{key}: {value:.6f}")
|
||||
|
||||
return model, loss_fn
|
||||
|
||||
|
||||
def example_data_loaders():
|
||||
"""Example 3: Creating data loaders for training."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 3: Creating data loaders for training")
|
||||
print("=" * 60)
|
||||
|
||||
# Path to the dataset
|
||||
dataset_path = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images"
|
||||
|
||||
# Create data loaders
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=dataset_path,
|
||||
batch_size=16,
|
||||
train_split=0.8,
|
||||
num_workers=0, # Use 0 for debugging, increase for training
|
||||
image_size=(256, 256),
|
||||
augment_train=True,
|
||||
augment_val=False,
|
||||
)
|
||||
|
||||
print(f"Train batches: {len(train_loader)}")
|
||||
print(f"Val batches: {len(val_loader)}")
|
||||
|
||||
# Get a batch from train loader
|
||||
batch = next(iter(train_loader))
|
||||
print(f"\nBatch keys: {list(batch.keys())}")
|
||||
print(f"Batch size: {batch['google_img'].shape[0]}")
|
||||
print(f"Image shape: {batch['google_img'].shape[1:]}")
|
||||
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
def example_training_config():
|
||||
"""Example 4: Setting up training configuration."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 4: Training configuration")
|
||||
print("=" * 60)
|
||||
|
||||
# Training configuration
|
||||
config = {
|
||||
# Model config
|
||||
"model_type": "cnn",
|
||||
"hidden_channels": 64,
|
||||
"num_blocks": 4,
|
||||
"dropout_rate": 0.3,
|
||||
"use_batch_norm": True,
|
||||
"image_size": [256, 256],
|
||||
# Training config
|
||||
"epochs": 50,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer": "adam",
|
||||
"scheduler": "plateau",
|
||||
"grad_clip": 1.0,
|
||||
# Loss config
|
||||
"matrix_weight": 1.0,
|
||||
"geometric_weight": 0.5,
|
||||
"reg_weight": 0.1,
|
||||
"grid_size": 8,
|
||||
# Data config
|
||||
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
"train_split": 0.8,
|
||||
"num_workers": 0,
|
||||
# Output config
|
||||
"output_dir": "runs/example_training",
|
||||
"seed": 42,
|
||||
}
|
||||
|
||||
print("Training configuration:")
|
||||
for key, value in config.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def example_inference():
|
||||
"""Example 5: Performing inference with a trained model."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 5: Inference example")
|
||||
print("=" * 60)
|
||||
|
||||
# Note: This example assumes you have a trained model
|
||||
# For demonstration, we'll show the code structure
|
||||
|
||||
print("Inference workflow:")
|
||||
print("1. Load trained model")
|
||||
print("2. Preprocess input images")
|
||||
print("3. Predict homography matrix")
|
||||
print("4. Visualize alignment")
|
||||
|
||||
# Example code structure (commented out since we don't have a trained model yet)
|
||||
"""
|
||||
# Create inference object
|
||||
inference = HomographyInference(
|
||||
model_path="runs/homography/checkpoint_best.pth",
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
)
|
||||
|
||||
# Load images
|
||||
google_img = Image.open("path/to/google.png")
|
||||
yandex_img = Image.open("path/to/yandex.png")
|
||||
|
||||
# Predict homography
|
||||
homography = inference.predict(google_img, yandex_img)
|
||||
|
||||
print(f"Predicted homography matrix:")
|
||||
print(homography.cpu().numpy())
|
||||
|
||||
# Visualize alignment
|
||||
inference.visualize_alignment(
|
||||
google_img,
|
||||
yandex_img,
|
||||
homography.cpu().numpy(),
|
||||
save_path="alignment_visualization.png",
|
||||
show=True,
|
||||
)
|
||||
"""
|
||||
|
||||
print(
|
||||
"\nNote: To run actual inference, first train a model using train_homography.py"
|
||||
)
|
||||
|
||||
|
||||
def example_quick_start():
|
||||
"""Example 6: Quick start guide."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 6: Quick Start Guide")
|
||||
print("=" * 60)
|
||||
|
||||
print("\nTo get started with homography estimation:")
|
||||
print("\n1. First, explore the dataset:")
|
||||
print(
|
||||
' python -c "from models.homography import HomographyDataset; '
|
||||
"dataset = HomographyDataset(r'C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images'); "
|
||||
"print(f'Found {len(dataset)} image pairs')\""
|
||||
)
|
||||
|
||||
print("\n2. Train a model:")
|
||||
print(" python models/train_homography.py \\")
|
||||
print(
|
||||
' --data_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\'
|
||||
)
|
||||
print(" --epochs 50 \\")
|
||||
print(" --batch_size 16 \\")
|
||||
print(' --output_dir "runs/my_experiment"')
|
||||
|
||||
print("\n3. Perform inference on a single image pair:")
|
||||
print(" python models/infer_homography.py \\")
|
||||
print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\')
|
||||
print(" --mode single \\")
|
||||
print(' --google_path "path/to/google.png" \\')
|
||||
print(' --yandex_path "path/to/yandex.png" \\')
|
||||
print(' --output_vis "alignment_result.png"')
|
||||
|
||||
print("\n4. Evaluate on the entire dataset:")
|
||||
print(" python models/infer_homography.py \\")
|
||||
print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\')
|
||||
print(" --mode dataset \\")
|
||||
print(
|
||||
' --dataset_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\'
|
||||
)
|
||||
print(' --save_results "evaluation_results.json"')
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all examples."""
|
||||
print("Homography Estimation Workflow Examples")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Example 1: Dataset loading
|
||||
dataset = example_dataset_loading()
|
||||
|
||||
# Example 2: Model creation
|
||||
model, loss_fn = example_model_creation()
|
||||
|
||||
# Example 3: Data loaders
|
||||
train_loader, val_loader = example_data_loaders()
|
||||
|
||||
# Example 4: Training configuration
|
||||
config = example_training_config()
|
||||
|
||||
# Example 5: Inference
|
||||
example_inference()
|
||||
|
||||
# Example 6: Quick start
|
||||
example_quick_start()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All examples completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
print("\nNext steps:")
|
||||
print("1. Run the training script to train a model")
|
||||
print("2. Use the inference script to test the trained model")
|
||||
print("3. Integrate the homography estimation into your autopilot system")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nError running examples: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2104
models/SiaN/homography.ipynb
Normal file
2104
models/SiaN/homography.ipynb
Normal file
File diff suppressed because one or more lines are too long
434
models/SiaN/homography.py
Normal file
434
models/SiaN/homography.py
Normal file
@@ -0,0 +1,434 @@
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
class HomographyDataset(Dataset):
|
||||
"""
|
||||
Dataset for homography estimation between Yandex and Google map image pairs.
|
||||
|
||||
This dataset loads pairs of images (Yandex and Google maps) and provides
|
||||
homography matrices for data augmentation and training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
transform=None,
|
||||
augment: bool = True,
|
||||
max_samples: Optional[int] = None,
|
||||
image_size: Tuple[int, int] = (700, 700),
|
||||
cache_homographies: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the HomographyDataset.
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png)
|
||||
transform: Optional torchvision transforms to apply
|
||||
augment: Whether to apply homography-based data augmentation
|
||||
max_samples: Maximum number of samples to load (None for all)
|
||||
image_size: Target size for images (height, width)
|
||||
cache_homographies: Whether to cache generated homography matrices to disk
|
||||
"""
|
||||
self.root_dir = root_dir
|
||||
self.transform = transform
|
||||
self.augment = augment
|
||||
self.image_size = image_size
|
||||
self.cache_homographies = cache_homographies
|
||||
|
||||
# Find all image pairs
|
||||
self.image_pairs = self._discover_image_pairs()
|
||||
|
||||
if max_samples is not None:
|
||||
self.image_pairs = self.image_pairs[:max_samples]
|
||||
|
||||
print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")
|
||||
|
||||
# Create directory for cached homographies if needed
|
||||
if cache_homographies:
|
||||
self.homography_cache_dir = os.path.join(root_dir, "homography_cache")
|
||||
os.makedirs(self.homography_cache_dir, exist_ok=True)
|
||||
|
||||
def _discover_image_pairs(self) -> List[Dict[str, Any]]:
|
||||
"""Discover all Google-Yandex image pairs in the dataset directory."""
|
||||
image_pairs = []
|
||||
|
||||
# Get all Google images
|
||||
google_files = [
|
||||
f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
|
||||
]
|
||||
|
||||
for google_file in sorted(google_files):
|
||||
# Extract index from filename
|
||||
idx_str = google_file.split("_")[0]
|
||||
try:
|
||||
idx = int(idx_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Check if corresponding Yandex image exists
|
||||
yandex_file = f"{idx:04d}_yandex.png"
|
||||
yandex_path = os.path.join(self.root_dir, yandex_file)
|
||||
|
||||
if os.path.exists(yandex_path):
|
||||
image_pairs.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"google_path": os.path.join(self.root_dir, google_file),
|
||||
"yandex_path": yandex_path,
|
||||
}
|
||||
)
|
||||
|
||||
return image_pairs
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of image pairs in the dataset."""
|
||||
return len(self.image_pairs)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Get a sample from the dataset.
|
||||
|
||||
Returns a dictionary with:
|
||||
- 'google_img': Google map image tensor
|
||||
- 'yandex_img': Yandex map image tensor
|
||||
- 'homography': Ground truth homography matrix (3x3)
|
||||
- 'idx': Sample index
|
||||
"""
|
||||
pair_info = self.image_pairs[idx]
|
||||
|
||||
# Load images
|
||||
google_img = Image.open(pair_info["google_path"]).convert("RGB")
|
||||
yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB")
|
||||
|
||||
# Resize images to target size
|
||||
google_img = google_img.resize(
|
||||
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||
)
|
||||
yandex_img = yandex_img.resize(
|
||||
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||
)
|
||||
|
||||
# Get or generate homography matrix
|
||||
homography_matrix = self._get_homography_matrix(pair_info["idx"])
|
||||
|
||||
# Apply data augmentation if enabled
|
||||
if self.augment:
|
||||
google_img, yandex_img, homography_matrix = self._apply_augmentation(
|
||||
google_img, yandex_img, homography_matrix
|
||||
)
|
||||
|
||||
# Convert images to tensors
|
||||
if self.transform:
|
||||
google_img = self.transform(google_img)
|
||||
yandex_img = self.transform(yandex_img)
|
||||
else:
|
||||
# Default conversion to tensor
|
||||
google_img = (
|
||||
torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
|
||||
)
|
||||
yandex_img = (
|
||||
torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
|
||||
)
|
||||
|
||||
# Convert homography to tensor
|
||||
homography_tensor = torch.from_numpy(homography_matrix).float()
|
||||
|
||||
return {
|
||||
"google_img": google_img,
|
||||
"yandex_img": yandex_img,
|
||||
"homography": homography_tensor,
|
||||
"idx": torch.tensor(pair_info["idx"], dtype=torch.long),
|
||||
}
|
||||
|
||||
def _get_homography_matrix(self, idx: int) -> np.ndarray:
|
||||
"""
|
||||
Get homography matrix for a given index.
|
||||
|
||||
If cached homography exists, load it. Otherwise generate a new one.
|
||||
"""
|
||||
if self.cache_homographies:
|
||||
cache_path = os.path.join(
|
||||
self.homography_cache_dir, f"{idx:04d}_homography.npy"
|
||||
)
|
||||
if os.path.exists(cache_path):
|
||||
return np.load(cache_path)
|
||||
|
||||
# Generate new homography matrix
|
||||
homography_matrix = self.generate_random_homography()
|
||||
|
||||
# Cache if enabled
|
||||
if self.cache_homographies:
|
||||
np.save(cache_path, homography_matrix)
|
||||
|
||||
return homography_matrix
|
||||
|
||||
def generate_random_homography(self) -> np.ndarray:
|
||||
"""
|
||||
Generate a random homography matrix for data augmentation.
|
||||
|
||||
Returns:
|
||||
np.ndarray: 3x3 homography matrix.
|
||||
"""
|
||||
# Generate random affine transformation parameters
|
||||
angle = np.random.uniform(-30, 30) # rotation in degrees
|
||||
scale = np.random.uniform(0.8, 1.2) # scaling factor
|
||||
tx = np.random.uniform(-50, 50) # translation in x
|
||||
ty = np.random.uniform(-50, 50) # translation in y
|
||||
|
||||
# Convert angle to radians
|
||||
theta = np.radians(angle)
|
||||
|
||||
# Create affine transformation matrix
|
||||
affine_matrix = np.array(
|
||||
[
|
||||
[scale * np.cos(theta), -scale * np.sin(theta), tx],
|
||||
[scale * np.sin(theta), scale * np.cos(theta), ty],
|
||||
[0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Add small perspective distortion
|
||||
perspective = np.random.uniform(-0.001, 0.001, (2, 3))
|
||||
perspective = np.vstack([perspective, [0, 0, 0]])
|
||||
|
||||
homography_matrix = affine_matrix + perspective
|
||||
|
||||
return homography_matrix
|
||||
|
||||
def _apply_augmentation(
|
||||
self,
|
||||
google_img: Image.Image,
|
||||
yandex_img: Image.Image,
|
||||
base_homography: np.ndarray,
|
||||
) -> Tuple[Image.Image, Image.Image, np.ndarray]:
|
||||
"""
|
||||
Apply homography-based data augmentation to image pair.
|
||||
|
||||
Args:
|
||||
google_img: Google map image
|
||||
yandex_img: Yandex map image
|
||||
base_homography: Base homography matrix
|
||||
|
||||
Returns:
|
||||
Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography)
|
||||
"""
|
||||
# Generate augmentation homography
|
||||
aug_homography = self.generate_random_homography()
|
||||
|
||||
# Combine with base homography
|
||||
combined_homography = aug_homography @ base_homography
|
||||
|
||||
# Apply augmentation to both images
|
||||
google_aug = self._apply_homography_to_image(google_img, aug_homography)
|
||||
yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography)
|
||||
|
||||
return google_aug, yandex_aug, combined_homography
|
||||
|
||||
def _apply_homography_to_image(
|
||||
self, img: Image.Image, homography: np.ndarray
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Apply homography transformation to a single image.
|
||||
|
||||
Args:
|
||||
img: PIL Image to transform
|
||||
homography: 3x3 homography matrix
|
||||
|
||||
Returns:
|
||||
Transformed PIL Image
|
||||
"""
|
||||
# Convert to numpy array
|
||||
img_np = np.array(img)
|
||||
|
||||
# Get image dimensions
|
||||
h, w = img_np.shape[:2]
|
||||
|
||||
# Apply homography transformation
|
||||
transformed = cv2.warpPerspective(
|
||||
img_np,
|
||||
homography,
|
||||
(w, h),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderMode=cv2.BORDER_REFLECT,
|
||||
)
|
||||
|
||||
# Convert back to PIL Image
|
||||
return Image.fromarray(transformed)
|
||||
|
||||
def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a sample without data augmentation.
|
||||
|
||||
Useful for visualization and evaluation.
|
||||
"""
|
||||
pair_info = self.image_pairs[idx]
|
||||
|
||||
# Load images
|
||||
google_img = Image.open(pair_info["google_path"]).convert("RGB")
|
||||
yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB")
|
||||
|
||||
# Resize
|
||||
google_img = google_img.resize(
|
||||
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||
)
|
||||
yandex_img = yandex_img.resize(
|
||||
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||
)
|
||||
|
||||
# Get homography matrix
|
||||
homography_matrix = self._get_homography_matrix(pair_info["idx"])
|
||||
|
||||
return {
|
||||
"google_img": google_img,
|
||||
"yandex_img": yandex_img,
|
||||
"homography": homography_matrix,
|
||||
"idx": pair_info["idx"],
|
||||
"google_path": pair_info["google_path"],
|
||||
"yandex_path": pair_info["yandex_path"],
|
||||
}
|
||||
|
||||
|
||||
def create_data_loaders(
|
||||
root_dir: str,
|
||||
batch_size: int = 32,
|
||||
train_split: float = 0.8,
|
||||
num_workers: int = 4,
|
||||
image_size: Tuple[int, int] = (256, 256),
|
||||
augment_train: bool = True,
|
||||
augment_val: bool = False,
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
"""
|
||||
Create train and validation data loaders for homography estimation.
|
||||
|
||||
Args:
|
||||
root_dir: Directory containing image pairs
|
||||
batch_size: Batch size for data loaders
|
||||
train_split: Fraction of data to use for training
|
||||
num_workers: Number of worker processes for data loading
|
||||
image_size: Target image size (height, width)
|
||||
augment_train: Whether to augment training data
|
||||
augment_val: Whether to augment validation data
|
||||
|
||||
Returns:
|
||||
Tuple of (train_loader, val_loader)
|
||||
"""
|
||||
from torchvision import transforms
|
||||
|
||||
# Define transforms
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
|
||||
# Create full dataset
|
||||
full_dataset = HomographyDataset(
|
||||
root_dir=root_dir,
|
||||
transform=transform,
|
||||
augment=False, # We'll handle augmentation separately
|
||||
image_size=image_size,
|
||||
cache_homographies=True,
|
||||
)
|
||||
|
||||
# Split dataset
|
||||
dataset_size = len(full_dataset)
|
||||
train_size = int(train_split * dataset_size)
|
||||
val_size = dataset_size - train_size
|
||||
|
||||
# Create indices for splitting
|
||||
indices = list(range(dataset_size))
|
||||
random.shuffle(indices)
|
||||
train_indices = indices[:train_size]
|
||||
val_indices = indices[train_size:]
|
||||
|
||||
# Create subset samplers
|
||||
from torch.utils.data import Subset
|
||||
|
||||
train_dataset = Subset(full_dataset, train_indices)
|
||||
val_dataset = Subset(full_dataset, val_indices)
|
||||
|
||||
# Apply augmentation by overriding __getitem__ for train dataset
|
||||
if augment_train:
|
||||
|
||||
class AugmentedSubset(Subset):
|
||||
def __getitem__(self, idx):
|
||||
sample = self.dataset[self.indices[idx]]
|
||||
# Apply augmentation
|
||||
google_img = sample["google_img"]
|
||||
yandex_img = sample["yandex_img"]
|
||||
homography = sample["homography"]
|
||||
|
||||
# Generate augmentation homography
|
||||
aug_homography = torch.from_numpy(
|
||||
full_dataset.generate_random_homography()
|
||||
).float()
|
||||
|
||||
# Combine homographies
|
||||
combined_homography = aug_homography @ homography
|
||||
|
||||
# Apply augmentation (simplified - in practice would warp images)
|
||||
# For now, we just return the combined homography
|
||||
return {
|
||||
"google_img": google_img,
|
||||
"yandex_img": yandex_img,
|
||||
"homography": combined_homography,
|
||||
"idx": sample["idx"],
|
||||
}
|
||||
|
||||
train_dataset = AugmentedSubset(full_dataset, train_indices)
|
||||
|
||||
# Create data loaders
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
dataset = HomographyDataset(
|
||||
root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
augment=True,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Get a sample
|
||||
sample = dataset[0]
|
||||
print(f"Sample keys: {list(sample.keys())}")
|
||||
print(f"Google image shape: {sample['google_img'].shape}")
|
||||
print(f"Yandex image shape: {sample['yandex_img'].shape}")
|
||||
print(f"Homography shape: {sample['homography'].shape}")
|
||||
|
||||
# Create data loaders
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
batch_size=16,
|
||||
train_split=0.8,
|
||||
)
|
||||
|
||||
print(f"Train batches: {len(train_loader)}")
|
||||
print(f"Val batches: {len(val_loader)}")
|
||||
551
models/SiaN/homography_cnn.py
Normal file
551
models/SiaN/homography_cnn.py
Normal file
@@ -0,0 +1,551 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class HomographyCNN(nn.Module):
|
||||
"""
|
||||
CNN model for homography estimation between two images.
|
||||
|
||||
This model takes two images (Google and Yandex maps) as input and
|
||||
outputs a 3x3 homography matrix that transforms one image to align with the other.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 3,
|
||||
hidden_channels: int = 64,
|
||||
num_blocks: int = 4,
|
||||
dropout_rate: float = 0.3,
|
||||
use_batch_norm: bool = True,
|
||||
output_size: int = 9, # Flattened 3x3 homography matrix
|
||||
):
|
||||
"""
|
||||
Initialize the HomographyCNN model.
|
||||
|
||||
Args:
|
||||
input_channels: Number of input channels per image (3 for RGB)
|
||||
hidden_channels: Base number of channels in the network
|
||||
num_blocks: Number of convolutional blocks
|
||||
dropout_rate: Dropout rate for regularization
|
||||
use_batch_norm: Whether to use batch normalization
|
||||
output_size: Size of output vector (9 for flattened 3x3 matrix)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_channels = input_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.dropout_rate = dropout_rate
|
||||
self.use_batch_norm = use_batch_norm
|
||||
|
||||
# Feature extraction for each image separately
|
||||
self.google_encoder = self._build_encoder()
|
||||
self.yandex_encoder = self._build_encoder()
|
||||
|
||||
# Fusion layers to combine features from both images
|
||||
self.fusion_layers = self._build_fusion_layers()
|
||||
|
||||
# Regression head for homography estimation
|
||||
self.regression_head = self._build_regression_head(output_size)
|
||||
|
||||
# Initialize weights
|
||||
self._initialize_weights()
|
||||
|
||||
def _build_encoder(self) -> nn.Module:
|
||||
"""Build the encoder network for a single image."""
|
||||
layers = []
|
||||
in_channels = self.input_channels
|
||||
out_channels = self.hidden_channels
|
||||
|
||||
# First convolutional block
|
||||
layers.append(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)
|
||||
)
|
||||
if self.use_batch_norm:
|
||||
layers.append(nn.BatchNorm2d(out_channels))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
|
||||
# Additional convolutional blocks
|
||||
for i in range(self.num_blocks):
|
||||
block_in_channels = out_channels
|
||||
block_out_channels = out_channels * 2 if i < 2 else out_channels
|
||||
|
||||
layers.append(
|
||||
ResidualBlock(
|
||||
in_channels=block_in_channels,
|
||||
out_channels=block_out_channels,
|
||||
stride=1 if i == 0 else 2,
|
||||
dropout_rate=self.dropout_rate,
|
||||
use_batch_norm=self.use_batch_norm,
|
||||
)
|
||||
)
|
||||
|
||||
if i < 2:
|
||||
out_channels = block_out_channels
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _build_fusion_layers(self) -> nn.Module:
|
||||
"""Build layers to fuse features from both images."""
|
||||
# After encoding, each image has hidden_channels * 4 features
|
||||
fused_channels = (
|
||||
self.hidden_channels * 8
|
||||
) # Concatenated features from both images
|
||||
|
||||
layers = [
|
||||
# Reduce dimensionality
|
||||
nn.Conv2d(
|
||||
fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
|
||||
),
|
||||
nn.BatchNorm2d(self.hidden_channels * 4)
|
||||
if self.use_batch_norm
|
||||
else nn.Identity(),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout2d(self.dropout_rate),
|
||||
# Further processing
|
||||
nn.Conv2d(
|
||||
self.hidden_channels * 4,
|
||||
self.hidden_channels * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
nn.BatchNorm2d(self.hidden_channels * 2)
|
||||
if self.use_batch_norm
|
||||
else nn.Identity(),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout2d(self.dropout_rate),
|
||||
# Global pooling
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _build_regression_head(self, output_size: int) -> nn.Module:
|
||||
"""Build the regression head for homography estimation."""
|
||||
# Input size after fusion and global pooling
|
||||
input_features = self.hidden_channels * 2
|
||||
|
||||
layers = [
|
||||
nn.Flatten(),
|
||||
nn.Linear(input_features, 512),
|
||||
nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(self.dropout_rate),
|
||||
nn.Linear(512, 256),
|
||||
nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(self.dropout_rate),
|
||||
nn.Linear(256, 128),
|
||||
nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(self.dropout_rate),
|
||||
nn.Linear(128, output_size),
|
||||
]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""Initialize model weights."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
google_img: torch.Tensor,
|
||||
yandex_img: torch.Tensor,
|
||||
return_matrix: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the model.
|
||||
|
||||
Args:
|
||||
google_img: Google map image tensor of shape (B, C, H, W)
|
||||
yandex_img: Yandex map image tensor of shape (B, C, H, W)
|
||||
return_matrix: If True, return 3x3 matrix; if False, return flattened vector
|
||||
|
||||
Returns:
|
||||
Homography matrix tensor of shape (B, 3, 3) or flattened vector of shape (B, 9)
|
||||
"""
|
||||
# Extract features from both images
|
||||
google_features = self.google_encoder(google_img)
|
||||
yandex_features = self.yandex_encoder(yandex_img)
|
||||
|
||||
# Concatenate features along channel dimension
|
||||
combined_features = torch.cat([google_features, yandex_features], dim=1)
|
||||
|
||||
# Fuse features
|
||||
fused_features = self.fusion_layers(combined_features)
|
||||
|
||||
# Regression to get homography parameters
|
||||
homography_flat = self.regression_head(fused_features)
|
||||
|
||||
if return_matrix:
|
||||
# Reshape to 3x3 matrix
|
||||
batch_size = homography_flat.shape[0]
|
||||
homography_matrix = homography_flat.view(batch_size, 3, 3)
|
||||
|
||||
# Ensure the last element is 1 (homogeneous coordinate normalization)
|
||||
# Add small epsilon to prevent division by zero
|
||||
epsilon = 1e-8
|
||||
homography_matrix = homography_matrix / (
|
||||
homography_matrix[:, 2, 2].view(-1, 1, 1) + epsilon
|
||||
)
|
||||
|
||||
return homography_matrix
|
||||
else:
|
||||
return homography_flat
|
||||
|
||||
def predict_homography(
|
||||
self,
|
||||
google_img: torch.Tensor,
|
||||
yandex_img: torch.Tensor,
|
||||
normalize: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Predict homography matrix with optional normalization.
|
||||
|
||||
Args:
|
||||
google_img: Google map image tensor
|
||||
yandex_img: Yandex map image tensor
|
||||
normalize: Whether to normalize the homography matrix
|
||||
|
||||
Returns:
|
||||
Predicted homography matrix
|
||||
"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
homography = self.forward(google_img, yandex_img, return_matrix=True)
|
||||
|
||||
if normalize:
|
||||
# Normalize so that last element is 1
|
||||
homography = homography / homography[:, 2, 2].view(-1, 1, 1)
|
||||
|
||||
return homography
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
"""Residual block with optional downsampling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int = 1,
|
||||
dropout_rate: float = 0.3,
|
||||
use_batch_norm: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.dropout1 = (
|
||||
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.dropout2 = (
|
||||
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
# Shortcut connection
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_channels != out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
|
||||
),
|
||||
nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = self.shortcut(x)
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
out = self.dropout1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu2(out)
|
||||
out = self.dropout2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class HomographyLoss(nn.Module):
|
||||
"""
|
||||
Custom loss function for homography estimation.
|
||||
|
||||
Combines multiple loss terms:
|
||||
1. Matrix element-wise L2 loss
|
||||
2. Geometric consistency loss (warping error)
|
||||
3. Determinant regularization (to prevent degenerate matrices)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
matrix_weight: float = 1.0,
|
||||
geometric_weight: float = 0.5,
|
||||
reg_weight: float = 0.1,
|
||||
grid_size: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
self.matrix_weight = matrix_weight
|
||||
self.geometric_weight = geometric_weight
|
||||
self.reg_weight = reg_weight
|
||||
self.grid_size = grid_size
|
||||
|
||||
# Create grid of points for geometric loss
|
||||
self.register_buffer(
|
||||
"grid_points",
|
||||
self._create_grid_points(grid_size),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def _create_grid_points(self, grid_size: int) -> torch.Tensor:
|
||||
"""Create a grid of points for geometric consistency loss."""
|
||||
x = torch.linspace(-1, 1, grid_size)
|
||||
y = torch.linspace(-1, 1, grid_size)
|
||||
grid_y, grid_x = torch.meshgrid(y, x, indexing="ij")
|
||||
grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)
|
||||
# Add homogeneous coordinate
|
||||
ones = torch.ones(grid_points.shape[0], 1)
|
||||
grid_points = torch.cat([grid_points, ones], dim=1)
|
||||
return grid_points.T # Shape: (3, grid_size*grid_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_homography: torch.Tensor,
|
||||
target_homography: torch.Tensor,
|
||||
google_img: Optional[torch.Tensor] = None,
|
||||
yandex_img: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute homography loss.
|
||||
|
||||
Args:
|
||||
pred_homography: Predicted homography matrices (B, 3, 3)
|
||||
target_homography: Target homography matrices (B, 3, 3)
|
||||
google_img: Google images (optional, for geometric loss)
|
||||
yandex_img: Yandex images (optional, for geometric loss)
|
||||
|
||||
Returns:
|
||||
Combined loss value
|
||||
"""
|
||||
batch_size = pred_homography.shape[0]
|
||||
|
||||
# 1. Matrix element-wise L2 loss
|
||||
matrix_loss = F.mse_loss(pred_homography, target_homography)
|
||||
|
||||
# 2. Geometric consistency loss (if images provided)
|
||||
geometric_loss = torch.tensor(0.0, device=pred_homography.device)
|
||||
if google_img is not None and yandex_img is not None:
|
||||
# Warp grid points with predicted homography
|
||||
grid_points = self.grid_points.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
warped_points = torch.bmm(pred_homography, grid_points)
|
||||
|
||||
# Normalize homogeneous coordinates
|
||||
warped_points = warped_points / (warped_points[:, 2:3, :] + 1e-8)
|
||||
|
||||
# Warp with target homography for comparison
|
||||
target_warped_points = torch.bmm(target_homography, grid_points)
|
||||
target_warped_points = target_warped_points / (
|
||||
target_warped_points[:, 2:3, :] + 1e-8
|
||||
)
|
||||
|
||||
# Compute point-wise distance
|
||||
geometric_loss = F.mse_loss(
|
||||
warped_points[:, :2, :], target_warped_points[:, :2, :]
|
||||
)
|
||||
|
||||
# 3. Regularization loss (prevent degenerate matrices)
|
||||
# Encourage determinant to be close to 1
|
||||
pred_det = torch.det(pred_homography)
|
||||
reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))
|
||||
|
||||
# Combine losses
|
||||
total_loss = (
|
||||
self.matrix_weight * matrix_loss
|
||||
+ self.geometric_weight * geometric_loss
|
||||
+ self.reg_weight * reg_loss
|
||||
)
|
||||
|
||||
return total_loss
|
||||
|
||||
def compute_metrics(
|
||||
self,
|
||||
pred_homography: torch.Tensor,
|
||||
target_homography: torch.Tensor,
|
||||
) -> dict:
|
||||
"""
|
||||
Compute evaluation metrics for homography estimation.
|
||||
|
||||
Args:
|
||||
pred_homography: Predicted homography matrices
|
||||
target_homography: Target homography matrices
|
||||
|
||||
Returns:
|
||||
Dictionary of metrics
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Normalize matrices
|
||||
pred_norm = pred_homography / pred_homography[:, 2, 2].view(-1, 1, 1)
|
||||
target_norm = target_homography / target_homography[:, 2, 2].view(-1, 1, 1)
|
||||
|
||||
# Matrix L2 error
|
||||
matrix_error = F.mse_loss(pred_norm, target_norm, reduction="none").mean(
|
||||
dim=(1, 2)
|
||||
)
|
||||
|
||||
# Corner error (warp 4 corners of the image)
|
||||
corners = torch.tensor(
|
||||
[[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],
|
||||
dtype=torch.float32,
|
||||
device=pred_homography.device,
|
||||
).T # Shape: (3, 4)
|
||||
|
||||
corners = corners.unsqueeze(0).expand(pred_homography.shape[0], -1, -1)
|
||||
|
||||
pred_corners = torch.bmm(pred_norm, corners)
|
||||
pred_corners = pred_corners / (pred_corners[:, 2:3, :] + 1e-8)
|
||||
|
||||
target_corners = torch.bmm(target_norm, corners)
|
||||
target_corners = target_corners / (target_corners[:, 2:3, :] + 1e-8)
|
||||
|
||||
corner_error = torch.mean(
|
||||
torch.norm(pred_corners[:, :2, :] - target_corners[:, :2, :], dim=1),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Average corner error in pixels (assuming image coordinates in [-1, 1])
|
||||
# Convert to pixel error if image size is known
|
||||
avg_corner_error = corner_error.mean().item()
|
||||
|
||||
return {
|
||||
"matrix_mse": matrix_error.mean().item(),
|
||||
"corner_error": avg_corner_error,
|
||||
"corner_error_px": avg_corner_error * 128, # Assuming 256x256 images
|
||||
}
|
||||
|
||||
|
||||
def create_homography_model(
|
||||
model_type: str = "cnn",
|
||||
input_size: Tuple[int, int] = (256, 256),
|
||||
**kwargs,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Factory function to create homography estimation model.
|
||||
|
||||
Args:
|
||||
model_type: Type of model to create ('cnn' or 'resnet')
|
||||
input_size: Input image size (height, width)
|
||||
**kwargs: Additional arguments passed to model constructor
|
||||
|
||||
Returns:
|
||||
Homography estimation model
|
||||
"""
|
||||
if model_type == "cnn":
|
||||
return HomographyCNN(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Create model
|
||||
model = HomographyCNN(
|
||||
input_channels=3,
|
||||
hidden_channels=64,
|
||||
num_blocks=4,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
).to(device)
|
||||
|
||||
print(
|
||||
f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||
)
|
||||
|
||||
# Create dummy input
|
||||
batch_size = 4
|
||||
height, width = 256, 256
|
||||
|
||||
google_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||
yandex_img = torch.randn(batch_size, 3, height, width).to(device)
|
||||
|
||||
# Test forward pass
|
||||
print("\nTesting forward pass...")
|
||||
output = model(google_img, yandex_img, return_matrix=True)
|
||||
print(f"Output shape: {output.shape}") # Should be (4, 3, 3)
|
||||
print(f"Sample output:\n{output[0]}")
|
||||
|
||||
# Test prediction
|
||||
print("\nTesting prediction...")
|
||||
pred = model.predict_homography(google_img, yandex_img)
|
||||
print(f"Prediction shape: {pred.shape}")
|
||||
print(f"Last element (should be ~1): {pred[0, 2, 2]:.6f}")
|
||||
|
||||
# Test loss function
|
||||
print("\nTesting loss function...")
|
||||
target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
|
||||
loss_fn = HomographyLoss(
|
||||
matrix_weight=1.0,
|
||||
geometric_weight=0.5,
|
||||
reg_weight=0.1,
|
||||
grid_size=8,
|
||||
).to(device)
|
||||
|
||||
loss = loss_fn(output, target_homography, google_img, yandex_img)
|
||||
print(f"Loss value: {loss.item():.6f}")
|
||||
|
||||
# Test metrics
|
||||
print("\nTesting metrics...")
|
||||
metrics = loss_fn.compute_metrics(output, target_homography)
|
||||
for key, value in metrics.items():
|
||||
print(f"{key}: {value:.6f}")
|
||||
|
||||
# Test model factory
|
||||
print("\nTesting model factory...")
|
||||
model2 = create_homography_model(
|
||||
model_type="cnn",
|
||||
input_size=(256, 256),
|
||||
input_channels=3,
|
||||
hidden_channels=32,
|
||||
num_blocks=3,
|
||||
).to(device)
|
||||
|
||||
print(
|
||||
f"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters"
|
||||
)
|
||||
|
||||
print("\nAll tests completed successfully!")
|
||||
553
models/SiaN/infer_homography.py
Normal file
553
models/SiaN/infer_homography.py
Normal file
@@ -0,0 +1,553 @@
|
||||
"""
|
||||
Inference script for homography estimation between Google and Yandex map images.
|
||||
|
||||
This script loads a trained homography estimation model and performs inference
|
||||
on new image pairs or the test dataset.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from homography import HomographyDataset
|
||||
from homography_cnn import HomographyCNN, create_homography_model
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class HomographyInference:
|
||||
"""Class for performing inference with homography estimation model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
config_path: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the inference class.
|
||||
|
||||
Args:
|
||||
model_path: Path to trained model checkpoint
|
||||
config_path: Path to model configuration file (optional)
|
||||
device: Device to run inference on ('cuda' or 'cpu')
|
||||
"""
|
||||
# Set device
|
||||
if device is None:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
print(f"Using device: {self.device}")
|
||||
|
||||
# Load configuration
|
||||
if config_path is None:
|
||||
# Try to find config in the same directory as model
|
||||
model_dir = Path(model_path).parent
|
||||
config_path = model_dir / "config.json"
|
||||
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "r") as f:
|
||||
self.config = json.load(f)
|
||||
print(f"Loaded configuration from {config_path}")
|
||||
else:
|
||||
# Use default configuration
|
||||
self.config = {
|
||||
"image_size": [256, 256],
|
||||
"hidden_channels": 64,
|
||||
"num_blocks": 4,
|
||||
"dropout_rate": 0.3,
|
||||
"use_batch_norm": True,
|
||||
}
|
||||
print("Using default configuration")
|
||||
|
||||
# Create model
|
||||
self.model = self._create_model()
|
||||
self._load_model(model_path)
|
||||
|
||||
# Set up transforms
|
||||
self.transform = self._create_transforms()
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.model.eval()
|
||||
|
||||
def _create_model(self) -> HomographyCNN:
|
||||
"""Create model based on configuration."""
|
||||
image_size = self.config.get("image_size", [256, 256])
|
||||
|
||||
model = create_homography_model(
|
||||
model_type="cnn",
|
||||
input_size=tuple(image_size),
|
||||
input_channels=3,
|
||||
hidden_channels=self.config.get("hidden_channels", 64),
|
||||
num_blocks=self.config.get("num_blocks", 4),
|
||||
dropout_rate=self.config.get("dropout_rate", 0.3),
|
||||
use_batch_norm=self.config.get("use_batch_norm", True),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def _load_model(self, model_path: str):
|
||||
"""Load model weights from checkpoint."""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
# Load checkpoint
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
|
||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||
# Trainer checkpoint format
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
else:
|
||||
# Raw model weights format
|
||||
self.model.load_state_dict(checkpoint)
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
print(f"Loaded model from {model_path}")
|
||||
|
||||
def _create_transforms(self):
|
||||
"""Create image transforms for inference."""
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.Resize(tuple(self.config.get("image_size", [256, 256]))),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_images(
|
||||
self, google_img: Image.Image, yandex_img: Image.Image
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preprocess images for inference.
|
||||
|
||||
Args:
|
||||
google_img: Google map image (PIL Image)
|
||||
yandex_img: Yandex map image (PIL Image)
|
||||
|
||||
Returns:
|
||||
Tuple of preprocessed image tensors
|
||||
"""
|
||||
# Convert to RGB if needed
|
||||
if google_img.mode != "RGB":
|
||||
google_img = google_img.convert("RGB")
|
||||
if yandex_img.mode != "RGB":
|
||||
yandex_img = yandex_img.convert("RGB")
|
||||
|
||||
# Apply transforms
|
||||
google_tensor = self.transform(google_img).unsqueeze(0) # Add batch dimension
|
||||
yandex_tensor = self.transform(yandex_img).unsqueeze(0)
|
||||
|
||||
return google_tensor, yandex_tensor
|
||||
|
||||
def predict(
|
||||
self,
|
||||
google_img: Image.Image,
|
||||
yandex_img: Image.Image,
|
||||
return_matrix: bool = True,
|
||||
normalize: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Predict homography matrix for image pair.
|
||||
|
||||
Args:
|
||||
google_img: Google map image (PIL Image)
|
||||
yandex_img: Yandex map image (PIL Image)
|
||||
return_matrix: If True, return 3x3 matrix; if False, return flattened vector
|
||||
normalize: Whether to normalize the homography matrix
|
||||
|
||||
Returns:
|
||||
Predicted homography matrix or vector
|
||||
"""
|
||||
# Preprocess images
|
||||
google_tensor, yandex_tensor = self.preprocess_images(google_img, yandex_img)
|
||||
|
||||
# Move to device
|
||||
google_tensor = google_tensor.to(self.device)
|
||||
yandex_tensor = yandex_tensor.to(self.device)
|
||||
|
||||
# Perform inference
|
||||
with torch.no_grad():
|
||||
homography = self.model(
|
||||
google_tensor, yandex_tensor, return_matrix=return_matrix
|
||||
)
|
||||
|
||||
if return_matrix and normalize:
|
||||
# Normalize so that last element is 1
|
||||
homography = homography / homography[:, 2, 2].view(-1, 1, 1)
|
||||
|
||||
return homography.squeeze(0) # Remove batch dimension
|
||||
|
||||
def predict_from_paths(
|
||||
self,
|
||||
google_path: str,
|
||||
yandex_path: str,
|
||||
return_matrix: bool = True,
|
||||
normalize: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Predict homography matrix from image file paths.
|
||||
|
||||
Args:
|
||||
google_path: Path to Google map image
|
||||
yandex_path: Path to Yandex map image
|
||||
return_matrix: If True, return 3x3 matrix; if False, return flattened vector
|
||||
normalize: Whether to normalize the homography matrix
|
||||
|
||||
Returns:
|
||||
Predicted homography matrix or vector
|
||||
"""
|
||||
# Load images
|
||||
google_img = Image.open(google_path)
|
||||
yandex_img = Image.open(yandex_path)
|
||||
|
||||
return self.predict(google_img, yandex_img, return_matrix, normalize)
|
||||
|
||||
def warp_image(
|
||||
self,
|
||||
img: Image.Image,
|
||||
homography: np.ndarray,
|
||||
output_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Warp image using homography matrix.
|
||||
|
||||
Args:
|
||||
img: Input image (PIL Image)
|
||||
homography: 3x3 homography matrix (numpy array)
|
||||
output_size: Output image size (width, height). If None, uses input size.
|
||||
|
||||
Returns:
|
||||
Warped image (PIL Image)
|
||||
"""
|
||||
# Convert to numpy array
|
||||
img_np = np.array(img)
|
||||
|
||||
# Get output size
|
||||
if output_size is None:
|
||||
output_size = (img_np.shape[1], img_np.shape[0])
|
||||
|
||||
# Apply homography transformation
|
||||
warped_np = cv2.warpPerspective(
|
||||
img_np,
|
||||
homography,
|
||||
output_size,
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderMode=cv2.BORDER_REFLECT,
|
||||
)
|
||||
|
||||
# Convert back to PIL Image
|
||||
return Image.fromarray(warped_np)
|
||||
|
||||
def visualize_alignment(
|
||||
self,
|
||||
google_img: Image.Image,
|
||||
yandex_img: Image.Image,
|
||||
homography: np.ndarray,
|
||||
save_path: Optional[str] = None,
|
||||
show: bool = True,
|
||||
):
|
||||
"""
|
||||
Visualize alignment between images using homography.
|
||||
|
||||
Args:
|
||||
google_img: Google map image
|
||||
yandex_img: Yandex map image
|
||||
homography: Homography matrix
|
||||
save_path: Path to save visualization (optional)
|
||||
show: Whether to display the visualization
|
||||
"""
|
||||
# Warp yandex image to align with google
|
||||
yandex_warped = self.warp_image(yandex_img, homography)
|
||||
|
||||
# Convert images to numpy arrays for visualization
|
||||
google_np = np.array(google_img)
|
||||
yandex_np = np.array(yandex_img)
|
||||
yandex_warped_np = np.array(yandex_warped)
|
||||
|
||||
# Create visualization
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
|
||||
# Original images
|
||||
axes[0, 0].imshow(google_np)
|
||||
axes[0, 0].set_title("Google Map (Original)")
|
||||
axes[0, 0].axis("off")
|
||||
|
||||
axes[0, 1].imshow(yandex_np)
|
||||
axes[0, 1].set_title("Yandex Map (Original)")
|
||||
axes[0, 1].axis("off")
|
||||
|
||||
# Warped image
|
||||
axes[1, 0].imshow(yandex_warped_np)
|
||||
axes[1, 0].set_title("Yandex Map (Warped)")
|
||||
axes[1, 0].axis("off")
|
||||
|
||||
# Overlay (50% transparency)
|
||||
overlay = cv2.addWeighted(google_np, 0.5, yandex_warped_np, 0.5, 0)
|
||||
axes[1, 1].imshow(overlay)
|
||||
axes[1, 1].set_title("Overlay (Google + Warped Yandex)")
|
||||
axes[1, 1].axis("off")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
||||
print(f"Visualization saved to {save_path}")
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
plt.close()
|
||||
|
||||
def evaluate_on_dataset(
|
||||
self,
|
||||
dataset_dir: str,
|
||||
num_samples: Optional[int] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate model on a dataset.
|
||||
|
||||
Args:
|
||||
dataset_dir: Directory containing image pairs
|
||||
num_samples: Number of samples to evaluate (None for all)
|
||||
save_dir: Directory to save visualizations (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary of evaluation metrics
|
||||
"""
|
||||
# Create dataset
|
||||
dataset = HomographyDataset(
|
||||
root_dir=dataset_dir,
|
||||
transform=None, # We'll handle transforms manually
|
||||
augment=False,
|
||||
image_size=tuple(self.config.get("image_size", [256, 256])),
|
||||
cache_homographies=False,
|
||||
)
|
||||
|
||||
if num_samples is not None:
|
||||
indices = list(range(min(num_samples, len(dataset))))
|
||||
else:
|
||||
indices = list(range(len(dataset)))
|
||||
|
||||
errors = []
|
||||
corner_errors = []
|
||||
|
||||
print(f"Evaluating on {len(indices)} samples...")
|
||||
|
||||
for idx in indices:
|
||||
# Get sample without augmentation
|
||||
sample = dataset.get_sample_without_augmentation(idx)
|
||||
|
||||
google_img = sample["google_img"]
|
||||
yandex_img = sample["yandex_img"]
|
||||
target_homography = sample["homography"]
|
||||
|
||||
# Predict homography
|
||||
pred_homography = self.predict(
|
||||
google_img, yandex_img, return_matrix=True, normalize=True
|
||||
)
|
||||
|
||||
# Convert to numpy
|
||||
pred_homography_np = pred_homography.cpu().numpy()
|
||||
target_homography_np = target_homography
|
||||
|
||||
# Compute matrix error
|
||||
matrix_error = np.mean((pred_homography_np - target_homography_np) ** 2)
|
||||
errors.append(matrix_error)
|
||||
|
||||
# Compute corner error
|
||||
corners = np.array(
|
||||
[
|
||||
[-1, -1, 1],
|
||||
[1, -1, 1],
|
||||
[1, 1, 1],
|
||||
[-1, 1, 1],
|
||||
],
|
||||
dtype=np.float32,
|
||||
).T
|
||||
|
||||
pred_corners = pred_homography_np @ corners
|
||||
pred_corners = pred_corners / (pred_corners[2:3, :] + 1e-8)
|
||||
|
||||
target_corners = target_homography_np @ corners
|
||||
target_corners = target_corners / (target_corners[2:3, :] + 1e-8)
|
||||
|
||||
corner_error = np.mean(
|
||||
np.linalg.norm(pred_corners[:2, :] - target_corners[:2, :], axis=0)
|
||||
)
|
||||
corner_errors.append(corner_error)
|
||||
|
||||
# Save visualization if requested
|
||||
if save_dir:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
vis_path = os.path.join(save_dir, f"sample_{idx:04d}.png")
|
||||
self.visualize_alignment(
|
||||
google_img,
|
||||
yandex_img,
|
||||
pred_homography_np,
|
||||
save_path=vis_path,
|
||||
show=False,
|
||||
)
|
||||
|
||||
# Compute metrics
|
||||
metrics = {
|
||||
"mean_matrix_error": float(np.mean(errors)),
|
||||
"std_matrix_error": float(np.std(errors)),
|
||||
"mean_corner_error": float(np.mean(corner_errors)),
|
||||
"std_corner_error": float(np.std(corner_errors)),
|
||||
"median_corner_error": float(np.median(corner_errors)),
|
||||
"max_corner_error": float(np.max(corner_errors)),
|
||||
"min_corner_error": float(np.min(corner_errors)),
|
||||
}
|
||||
|
||||
print("\nEvaluation Results:")
|
||||
for key, value in metrics.items():
|
||||
print(f" {key}: {value:.6f}")
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def main():
|
||||
"""Main inference function."""
|
||||
parser = argparse.ArgumentParser(description="Inference for homography estimation")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to trained model checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
help="Path to model configuration file (optional)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
choices=["cuda", "cpu"],
|
||||
help="Device to run inference on",
|
||||
)
|
||||
|
||||
# Inference mode
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="single",
|
||||
choices=["single", "dataset", "batch"],
|
||||
help="Inference mode",
|
||||
)
|
||||
|
||||
# Single image mode
|
||||
parser.add_argument(
|
||||
"--google_path",
|
||||
type=str,
|
||||
help="Path to Google map image (single mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yandex_path",
|
||||
type=str,
|
||||
help="Path to Yandex map image (single mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_vis",
|
||||
type=str,
|
||||
help="Path to save visualization (single mode)",
|
||||
)
|
||||
|
||||
# Dataset mode
|
||||
parser.add_argument(
|
||||
"--dataset_dir",
|
||||
type=str,
|
||||
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
help="Directory containing image pairs (dataset mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
type=int,
|
||||
help="Number of samples to evaluate (dataset mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_vis_dir",
|
||||
type=str,
|
||||
help="Directory to save visualizations (dataset mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_results",
|
||||
type=str,
|
||||
help="Path to save evaluation results (dataset mode)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create inference object
|
||||
inference = HomographyInference(
|
||||
model_path=args.model_path,
|
||||
config_path=args.config_path,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
if args.mode == "single":
|
||||
# Single image pair inference
|
||||
if not args.google_path or not args.yandex_path:
|
||||
raise ValueError(
|
||||
"Both --google_path and --yandex_path are required for single mode"
|
||||
)
|
||||
|
||||
print(f"Processing single image pair:")
|
||||
print(f" Google: {args.google_path}")
|
||||
print(f" Yandex: {args.yandex_path}")
|
||||
|
||||
# Predict homography
|
||||
homography = inference.predict_from_paths(args.google_path, args.yandex_path)
|
||||
|
||||
print(f"\nPredicted homography matrix:")
|
||||
print(homography.cpu().numpy())
|
||||
|
||||
# Visualize alignment
|
||||
if args.output_vis:
|
||||
google_img = Image.open(args.google_path)
|
||||
yandex_img = Image.open(args.yandex_path)
|
||||
inference.visualize_alignment(
|
||||
google_img,
|
||||
yandex_img,
|
||||
homography.cpu().numpy(),
|
||||
save_path=args.output_vis,
|
||||
show=True,
|
||||
)
|
||||
|
||||
elif args.mode == "dataset":
|
||||
# Evaluate on dataset
|
||||
metrics = inference.evaluate_on_dataset(
|
||||
dataset_dir=args.dataset_dir,
|
||||
num_samples=args.num_samples,
|
||||
save_dir=args.save_vis_dir,
|
||||
)
|
||||
|
||||
# Save results if requested
|
||||
if args.save_results:
|
||||
with open(args.save_results, "w") as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
print(f"\nResults saved to {args.save_results}")
|
||||
|
||||
elif args.mode == "batch":
|
||||
# Batch processing (placeholder for future implementation)
|
||||
print("Batch mode not yet implemented")
|
||||
# Could implement processing multiple image pairs from a directory
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown mode: {args.mode}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
611
models/SiaN/train_homography.py
Normal file
611
models/SiaN/train_homography.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
Training script for homography estimation between Google and Yandex map images.
|
||||
|
||||
This script trains a CNN model to estimate homography matrices that align
|
||||
Google map images with Yandex map images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from homography import HomographyDataset, create_data_loaders
|
||||
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class HomographyTrainer:
|
||||
"""Trainer class for homography estimation model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
train_loader: DataLoader,
|
||||
val_loader: DataLoader,
|
||||
device: torch.device,
|
||||
config: Dict,
|
||||
):
|
||||
"""
|
||||
Initialize the trainer.
|
||||
|
||||
Args:
|
||||
model: Homography estimation model
|
||||
train_loader: Training data loader
|
||||
val_loader: Validation data loader
|
||||
device: Device to run training on
|
||||
config: Training configuration dictionary
|
||||
"""
|
||||
self.model = model.to(device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
# Loss function
|
||||
self.criterion = HomographyLoss(
|
||||
matrix_weight=config.get("matrix_weight", 1.0),
|
||||
geometric_weight=config.get("geometric_weight", 0.5),
|
||||
reg_weight=config.get("reg_weight", 0.1),
|
||||
grid_size=config.get("grid_size", 8),
|
||||
).to(device)
|
||||
|
||||
# Optimizer
|
||||
optimizer_name = config.get("optimizer", "adam").lower()
|
||||
lr = config.get("learning_rate", 1e-3)
|
||||
weight_decay = config.get("weight_decay", 1e-4)
|
||||
|
||||
if optimizer_name == "adam":
|
||||
self.optimizer = optim.Adam(
|
||||
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||
)
|
||||
elif optimizer_name == "adamw":
|
||||
self.optimizer = optim.AdamW(
|
||||
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||
)
|
||||
elif optimizer_name == "sgd":
|
||||
self.optimizer = optim.SGD(
|
||||
self.model.parameters(),
|
||||
lr=lr,
|
||||
momentum=config.get("momentum", 0.9),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer: {optimizer_name}")
|
||||
|
||||
# Learning rate scheduler
|
||||
scheduler_name = config.get("scheduler", "plateau").lower()
|
||||
if scheduler_name == "plateau":
|
||||
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer,
|
||||
mode="min",
|
||||
factor=config.get("scheduler_factor", 0.5),
|
||||
patience=config.get("scheduler_patience", 5),
|
||||
verbose=True,
|
||||
)
|
||||
elif scheduler_name == "cosine":
|
||||
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
self.optimizer,
|
||||
T_max=config.get("epochs", 100),
|
||||
eta_min=config.get("min_lr", 1e-6),
|
||||
)
|
||||
elif scheduler_name == "step":
|
||||
self.scheduler = optim.lr_scheduler.StepLR(
|
||||
self.optimizer,
|
||||
step_size=config.get("step_size", 30),
|
||||
gamma=config.get("gamma", 0.1),
|
||||
)
|
||||
else:
|
||||
self.scheduler = None
|
||||
|
||||
# Training state
|
||||
self.current_epoch = 0
|
||||
self.best_val_loss = float("inf")
|
||||
self.train_losses: List[float] = []
|
||||
self.val_losses: List[float] = []
|
||||
self.val_metrics: List[Dict] = []
|
||||
|
||||
# Create output directory
|
||||
self.output_dir = Path(config.get("output_dir", "runs/homography"))
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TensorBoard writer
|
||||
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||
|
||||
# Save configuration
|
||||
config_path = self.output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Training configuration saved to {config_path}")
|
||||
print(
|
||||
f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
|
||||
)
|
||||
|
||||
def train_epoch(self) -> float:
|
||||
"""
|
||||
Train for one epoch.
|
||||
|
||||
Returns:
|
||||
Average training loss for the epoch
|
||||
"""
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = len(self.train_loader)
|
||||
|
||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||
for batch_idx, batch in enumerate(progress_bar):
|
||||
# Move data to device
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
target_homography = batch["homography"].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
self.optimizer.zero_grad()
|
||||
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||
|
||||
# Compute loss
|
||||
loss = self.criterion(
|
||||
pred_homography,
|
||||
target_homography,
|
||||
google_img,
|
||||
yandex_img,
|
||||
)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
|
||||
# Gradient clipping
|
||||
if self.config.get("grad_clip", 1.0) > 0:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.config.get("grad_clip", 1.0),
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
|
||||
# Update statistics
|
||||
total_loss += loss.item()
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({"loss": loss.item()})
|
||||
|
||||
# Log batch loss to TensorBoard
|
||||
global_step = self.current_epoch * num_batches + batch_idx
|
||||
self.writer.add_scalar("train/batch_loss", loss.item(), global_step)
|
||||
|
||||
avg_loss = total_loss / num_batches
|
||||
self.train_losses.append(avg_loss)
|
||||
|
||||
return avg_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def validate(self) -> Tuple[float, Dict]:
|
||||
"""
|
||||
Validate the model.
|
||||
|
||||
Returns:
|
||||
Tuple of (average validation loss, validation metrics)
|
||||
"""
|
||||
self.model.eval()
|
||||
total_loss = 0.0
|
||||
all_metrics = []
|
||||
|
||||
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||
for batch in progress_bar:
|
||||
# Move data to device
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
target_homography = batch["homography"].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||
|
||||
# Compute loss
|
||||
loss = self.criterion(
|
||||
pred_homography,
|
||||
target_homography,
|
||||
google_img,
|
||||
yandex_img,
|
||||
)
|
||||
|
||||
# Compute metrics
|
||||
metrics = self.criterion.compute_metrics(pred_homography, target_homography)
|
||||
|
||||
# Update statistics
|
||||
total_loss += loss.item()
|
||||
all_metrics.append(metrics)
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({"loss": loss.item()})
|
||||
|
||||
avg_loss = total_loss / len(self.val_loader)
|
||||
self.val_losses.append(avg_loss)
|
||||
|
||||
# Aggregate metrics
|
||||
avg_metrics = {}
|
||||
for key in all_metrics[0].keys():
|
||||
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||
|
||||
self.val_metrics.append(avg_metrics)
|
||||
|
||||
return avg_loss, avg_metrics
|
||||
|
||||
def save_checkpoint(self, is_best: bool = False):
|
||||
"""Save model checkpoint."""
|
||||
checkpoint = {
|
||||
"epoch": self.current_epoch,
|
||||
"model_state_dict": self.model.state_dict(),
|
||||
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||
"train_losses": self.train_losses,
|
||||
"val_losses": self.val_losses,
|
||||
"val_metrics": self.val_metrics,
|
||||
"best_val_loss": self.best_val_loss,
|
||||
"config": self.config,
|
||||
}
|
||||
|
||||
if self.scheduler is not None:
|
||||
checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
|
||||
|
||||
# Save latest checkpoint
|
||||
checkpoint_path = self.output_dir / "checkpoint_latest.pth"
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
|
||||
# Save best checkpoint
|
||||
if is_best:
|
||||
best_path = self.output_dir / "checkpoint_best.pth"
|
||||
torch.save(checkpoint, best_path)
|
||||
print(f"Best model saved to {best_path}")
|
||||
|
||||
def load_checkpoint(self, checkpoint_path: str):
|
||||
"""Load model checkpoint."""
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
|
||||
if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
|
||||
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
|
||||
self.current_epoch = checkpoint["epoch"]
|
||||
self.train_losses = checkpoint["train_losses"]
|
||||
self.val_losses = checkpoint["val_losses"]
|
||||
self.val_metrics = checkpoint["val_metrics"]
|
||||
self.best_val_loss = checkpoint["best_val_loss"]
|
||||
|
||||
print(f"Loaded checkpoint from epoch {self.current_epoch}")
|
||||
|
||||
def train(self, num_epochs: int):
|
||||
"""
|
||||
Train the model for specified number of epochs.
|
||||
|
||||
Args:
|
||||
num_epochs: Number of epochs to train
|
||||
"""
|
||||
print(f"Starting training for {num_epochs} epochs...")
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
self.current_epoch = epoch
|
||||
|
||||
# Train for one epoch
|
||||
train_loss = self.train_epoch()
|
||||
|
||||
# Validate
|
||||
val_loss, val_metrics = self.validate()
|
||||
|
||||
# Update learning rate scheduler
|
||||
if self.scheduler is not None:
|
||||
if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
|
||||
self.scheduler.step(val_loss)
|
||||
else:
|
||||
self.scheduler.step()
|
||||
|
||||
# Log to TensorBoard
|
||||
self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
|
||||
self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
|
||||
for metric_name, metric_value in val_metrics.items():
|
||||
self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)
|
||||
|
||||
# Print epoch summary
|
||||
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||
print(f" Train Loss: {train_loss:.6f}")
|
||||
print(f" Val Loss: {val_loss:.6f}")
|
||||
print(" Val Metrics:")
|
||||
for metric_name, metric_value in val_metrics.items():
|
||||
print(f" {metric_name}: {metric_value:.6f}")
|
||||
|
||||
# Save checkpoint
|
||||
is_best = val_loss < self.best_val_loss
|
||||
if is_best:
|
||||
self.best_val_loss = val_loss
|
||||
|
||||
self.save_checkpoint(is_best=is_best)
|
||||
|
||||
# Early stopping
|
||||
if self.config.get("early_stopping_patience", 0) > 0:
|
||||
if (
|
||||
epoch - np.argmin(self.val_losses)
|
||||
>= self.config["early_stopping_patience"]
|
||||
):
|
||||
print(f"Early stopping at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
# Training completed
|
||||
training_time = time.time() - start_time
|
||||
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||
print(f"Best validation loss: {self.best_val_loss:.6f}")
|
||||
|
||||
# Save final model
|
||||
final_model_path = self.output_dir / "model_final.pth"
|
||||
torch.save(self.model.state_dict(), final_model_path)
|
||||
print(f"Final model saved to {final_model_path}")
|
||||
|
||||
# Close TensorBoard writer
|
||||
self.writer.close()
|
||||
|
||||
def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
|
||||
"""
|
||||
Evaluate the model on test data.
|
||||
|
||||
Args:
|
||||
test_loader: Test data loader (uses validation loader if None)
|
||||
|
||||
Returns:
|
||||
Dictionary of evaluation metrics
|
||||
"""
|
||||
if test_loader is None:
|
||||
test_loader = self.val_loader
|
||||
|
||||
self.model.eval()
|
||||
all_metrics = []
|
||||
|
||||
print("Evaluating model...")
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(test_loader):
|
||||
# Move data to device
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
target_homography = batch["homography"].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||
|
||||
# Compute metrics
|
||||
metrics = self.criterion.compute_metrics(
|
||||
pred_homography, target_homography
|
||||
)
|
||||
all_metrics.append(metrics)
|
||||
|
||||
# Aggregate metrics
|
||||
avg_metrics = {}
|
||||
for key in all_metrics[0].keys():
|
||||
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||
|
||||
# Print evaluation results
|
||||
print("\nEvaluation Results:")
|
||||
for metric_name, metric_value in avg_metrics.items():
|
||||
print(f" {metric_name}: {metric_value:.6f}")
|
||||
|
||||
# Save evaluation results
|
||||
eval_path = self.output_dir / "evaluation_results.json"
|
||||
with open(eval_path, "w") as f:
|
||||
json.dump(avg_metrics, f, indent=2)
|
||||
print(f"Evaluation results saved to {eval_path}")
|
||||
|
||||
return avg_metrics
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Train homography estimation model")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
help="Directory containing image pairs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=32, help="Batch size for training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=[256, 256],
|
||||
help="Image size (height width)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_split", type=float, default=0.8, help="Train/validation split ratio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, default=4, help="Number of data loader workers"
|
||||
)
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument(
|
||||
"--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden_channels", type=int, default=64, help="Number of hidden channels"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_blocks", type=int, default=4, help="Number of convolutional blocks"
|
||||
)
|
||||
parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
|
||||
parser.add_argument(
|
||||
"--use_batch_norm", action="store_true", help="Use batch normalization"
|
||||
)
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
||||
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
|
||||
parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
|
||||
parser.add_argument(
|
||||
"--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"]
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="plateau",
|
||||
choices=["plateau", "cosine", "step", "none"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad_clip", type=float, default=1.0, help="Gradient clipping value"
|
||||
)
|
||||
|
||||
# Loss arguments
|
||||
parser.add_argument(
|
||||
"--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--geometric_weight",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Weight for geometric loss",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
|
||||
)
|
||||
|
||||
# Other arguments
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="runs/homography",
|
||||
help="Output directory for checkpoints and logs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
help="Path to checkpoint to resume training from",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_only",
|
||||
action="store_true",
|
||||
help="Only evaluate the model (no training)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=42, help="Random seed for reproducibility"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
args = parse_args()
|
||||
|
||||
# Set random seeds for reproducibility
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
# Set device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Create data loaders
|
||||
print("Creating data loaders...")
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=args.data_dir,
|
||||
batch_size=args.batch_size,
|
||||
train_split=args.train_split,
|
||||
num_workers=args.num_workers,
|
||||
image_size=tuple(args.image_size),
|
||||
augment_train=True,
|
||||
augment_val=False,
|
||||
)
|
||||
|
||||
print(f"Train batches: {len(train_loader)}")
|
||||
print(f"Val batches: {len(val_loader)}")
|
||||
|
||||
# Create model
|
||||
print("Creating model...")
|
||||
model = create_homography_model(
|
||||
model_type=args.model_type,
|
||||
input_size=tuple(args.image_size),
|
||||
input_channels=3,
|
||||
hidden_channels=args.hidden_channels,
|
||||
num_blocks=args.num_blocks,
|
||||
dropout_rate=args.dropout_rate,
|
||||
use_batch_norm=args.use_batch_norm,
|
||||
)
|
||||
|
||||
# Create trainer configuration
|
||||
config = {
|
||||
# Model config
|
||||
"model_type": args.model_type,
|
||||
"hidden_channels": args.hidden_channels,
|
||||
"num_blocks": args.num_blocks,
|
||||
"dropout_rate": args.dropout_rate,
|
||||
"use_batch_norm": args.use_batch_norm,
|
||||
"image_size": args.image_size,
|
||||
# Training config
|
||||
"epochs": args.epochs,
|
||||
"batch_size": args.batch_size,
|
||||
"learning_rate": args.lr,
|
||||
"weight_decay": args.weight_decay,
|
||||
"optimizer": args.optimizer,
|
||||
"scheduler": args.scheduler,
|
||||
"grad_clip": args.grad_clip,
|
||||
# Loss config
|
||||
"matrix_weight": args.matrix_weight,
|
||||
"geometric_weight": args.geometric_weight,
|
||||
"reg_weight": args.reg_weight,
|
||||
"grid_size": 8,
|
||||
# Data config
|
||||
"data_dir": args.data_dir,
|
||||
"train_split": args.train_split,
|
||||
"num_workers": args.num_workers,
|
||||
# Output config
|
||||
"output_dir": args.output_dir,
|
||||
"seed": args.seed,
|
||||
}
|
||||
|
||||
# Create trainer
|
||||
trainer = HomographyTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Resume from checkpoint if specified
|
||||
if args.resume:
|
||||
print(f"Resuming from checkpoint: {args.resume}")
|
||||
trainer.load_checkpoint(args.resume)
|
||||
|
||||
# Evaluate only mode
|
||||
if args.eval_only:
|
||||
trainer.evaluate()
|
||||
return
|
||||
|
||||
# Train the model
|
||||
trainer.train(num_epochs=args.epochs)
|
||||
|
||||
# Final evaluation
|
||||
print("\nPerforming final evaluation...")
|
||||
trainer.evaluate()
|
||||
|
||||
print("\nTraining completed successfully!")
|
||||
print(f"Results saved to: {args.output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
611
models/SiaN/train_homography_.py
Normal file
611
models/SiaN/train_homography_.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
Training script for homography estimation between Google and Yandex map images.
|
||||
|
||||
This script trains a CNN model to estimate homography matrices that align
|
||||
Google map images with Yandex map images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from homography import HomographyDataset, create_data_loaders
|
||||
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class HomographyTrainer:
|
||||
"""Trainer class for homography estimation model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
train_loader: DataLoader,
|
||||
val_loader: DataLoader,
|
||||
device: torch.device,
|
||||
config: Dict,
|
||||
):
|
||||
"""
|
||||
Initialize the trainer.
|
||||
|
||||
Args:
|
||||
model: Homography estimation model
|
||||
train_loader: Training data loader
|
||||
val_loader: Validation data loader
|
||||
device: Device to run training on
|
||||
config: Training configuration dictionary
|
||||
"""
|
||||
self.model = model.to(device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
# Loss function
|
||||
self.criterion = HomographyLoss(
|
||||
matrix_weight=config.get("matrix_weight", 1.0),
|
||||
geometric_weight=config.get("geometric_weight", 0.5),
|
||||
reg_weight=config.get("reg_weight", 0.1),
|
||||
grid_size=config.get("grid_size", 8),
|
||||
).to(device)
|
||||
|
||||
# Optimizer
|
||||
optimizer_name = config.get("optimizer", "adam").lower()
|
||||
lr = config.get("learning_rate", 1e-3)
|
||||
weight_decay = config.get("weight_decay", 1e-4)
|
||||
|
||||
if optimizer_name == "adam":
|
||||
self.optimizer = optim.Adam(
|
||||
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||
)
|
||||
elif optimizer_name == "adamw":
|
||||
self.optimizer = optim.AdamW(
|
||||
self.model.parameters(), lr=lr, weight_decay=weight_decay
|
||||
)
|
||||
elif optimizer_name == "sgd":
|
||||
self.optimizer = optim.SGD(
|
||||
self.model.parameters(),
|
||||
lr=lr,
|
||||
momentum=config.get("momentum", 0.9),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer: {optimizer_name}")
|
||||
|
||||
# Learning rate scheduler
|
||||
scheduler_name = config.get("scheduler", "plateau").lower()
|
||||
if scheduler_name == "plateau":
|
||||
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer,
|
||||
mode="min",
|
||||
factor=config.get("scheduler_factor", 0.5),
|
||||
patience=config.get("scheduler_patience", 5),
|
||||
verbose=True,
|
||||
)
|
||||
elif scheduler_name == "cosine":
|
||||
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
self.optimizer,
|
||||
T_max=config.get("epochs", 100),
|
||||
eta_min=config.get("min_lr", 1e-6),
|
||||
)
|
||||
elif scheduler_name == "step":
|
||||
self.scheduler = optim.lr_scheduler.StepLR(
|
||||
self.optimizer,
|
||||
step_size=config.get("step_size", 30),
|
||||
gamma=config.get("gamma", 0.1),
|
||||
)
|
||||
else:
|
||||
self.scheduler = None
|
||||
|
||||
# Training state
|
||||
self.current_epoch = 0
|
||||
self.best_val_loss = float("inf")
|
||||
self.train_losses: List[float] = []
|
||||
self.val_losses: List[float] = []
|
||||
self.val_metrics: List[Dict] = []
|
||||
|
||||
# Create output directory
|
||||
self.output_dir = Path(config.get("output_dir", "runs/homography"))
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TensorBoard writer
|
||||
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||
|
||||
# Save configuration
|
||||
config_path = self.output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Training configuration saved to {config_path}")
|
||||
print(
|
||||
f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
|
||||
)
|
||||
|
||||
def train_epoch(self) -> float:
|
||||
"""
|
||||
Train for one epoch.
|
||||
|
||||
Returns:
|
||||
Average training loss for the epoch
|
||||
"""
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = len(self.train_loader)
|
||||
|
||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
|
||||
for batch_idx, batch in enumerate(progress_bar):
|
||||
# Move data to device
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
target_homography = batch["homography"].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
self.optimizer.zero_grad()
|
||||
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||
|
||||
# Compute loss
|
||||
loss = self.criterion(
|
||||
pred_homography,
|
||||
target_homography,
|
||||
google_img,
|
||||
yandex_img,
|
||||
)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
|
||||
# Gradient clipping
|
||||
if self.config.get("grad_clip", 1.0) > 0:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.config.get("grad_clip", 1.0),
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
|
||||
# Update statistics
|
||||
total_loss += loss.item()
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({"loss": loss.item()})
|
||||
|
||||
# Log batch loss to TensorBoard
|
||||
global_step = self.current_epoch * num_batches + batch_idx
|
||||
self.writer.add_scalar("train/batch_loss", loss.item(), global_step)
|
||||
|
||||
avg_loss = total_loss / num_batches
|
||||
self.train_losses.append(avg_loss)
|
||||
|
||||
return avg_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def validate(self) -> Tuple[float, Dict]:
|
||||
"""
|
||||
Validate the model.
|
||||
|
||||
Returns:
|
||||
Tuple of (average validation loss, validation metrics)
|
||||
"""
|
||||
self.model.eval()
|
||||
total_loss = 0.0
|
||||
all_metrics = []
|
||||
|
||||
progress_bar = tqdm(self.val_loader, desc="Validation")
|
||||
for batch in progress_bar:
|
||||
# Move data to device
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
target_homography = batch["homography"].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||
|
||||
# Compute loss
|
||||
loss = self.criterion(
|
||||
pred_homography,
|
||||
target_homography,
|
||||
google_img,
|
||||
yandex_img,
|
||||
)
|
||||
|
||||
# Compute metrics
|
||||
metrics = self.criterion.compute_metrics(pred_homography, target_homography)
|
||||
|
||||
# Update statistics
|
||||
total_loss += loss.item()
|
||||
all_metrics.append(metrics)
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({"loss": loss.item()})
|
||||
|
||||
avg_loss = total_loss / len(self.val_loader)
|
||||
self.val_losses.append(avg_loss)
|
||||
|
||||
# Aggregate metrics
|
||||
avg_metrics = {}
|
||||
for key in all_metrics[0].keys():
|
||||
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||
|
||||
self.val_metrics.append(avg_metrics)
|
||||
|
||||
return avg_loss, avg_metrics
|
||||
|
||||
def save_checkpoint(self, is_best: bool = False):
|
||||
"""Save model checkpoint."""
|
||||
checkpoint = {
|
||||
"epoch": self.current_epoch,
|
||||
"model_state_dict": self.model.state_dict(),
|
||||
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||
"train_losses": self.train_losses,
|
||||
"val_losses": self.val_losses,
|
||||
"val_metrics": self.val_metrics,
|
||||
"best_val_loss": self.best_val_loss,
|
||||
"config": self.config,
|
||||
}
|
||||
|
||||
if self.scheduler is not None:
|
||||
checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
|
||||
|
||||
# Save latest checkpoint
|
||||
checkpoint_path = self.output_dir / "checkpoint_latest.pth"
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
|
||||
# Save best checkpoint
|
||||
if is_best:
|
||||
best_path = self.output_dir / "checkpoint_best.pth"
|
||||
torch.save(checkpoint, best_path)
|
||||
print(f"Best model saved to {best_path}")
|
||||
|
||||
def load_checkpoint(self, checkpoint_path: str):
|
||||
"""Load model checkpoint."""
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
|
||||
if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
|
||||
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
|
||||
self.current_epoch = checkpoint["epoch"]
|
||||
self.train_losses = checkpoint["train_losses"]
|
||||
self.val_losses = checkpoint["val_losses"]
|
||||
self.val_metrics = checkpoint["val_metrics"]
|
||||
self.best_val_loss = checkpoint["best_val_loss"]
|
||||
|
||||
print(f"Loaded checkpoint from epoch {self.current_epoch}")
|
||||
|
||||
def train(self, num_epochs: int):
|
||||
"""
|
||||
Train the model for specified number of epochs.
|
||||
|
||||
Args:
|
||||
num_epochs: Number of epochs to train
|
||||
"""
|
||||
print(f"Starting training for {num_epochs} epochs...")
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
self.current_epoch = epoch
|
||||
|
||||
# Train for one epoch
|
||||
train_loss = self.train_epoch()
|
||||
|
||||
# Validate
|
||||
val_loss, val_metrics = self.validate()
|
||||
|
||||
# Update learning rate scheduler
|
||||
if self.scheduler is not None:
|
||||
if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
|
||||
self.scheduler.step(val_loss)
|
||||
else:
|
||||
self.scheduler.step()
|
||||
|
||||
# Log to TensorBoard
|
||||
self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
|
||||
self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
|
||||
for metric_name, metric_value in val_metrics.items():
|
||||
self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)
|
||||
|
||||
# Print epoch summary
|
||||
print(f"\nEpoch {epoch + 1}/{num_epochs}:")
|
||||
print(f" Train Loss: {train_loss:.6f}")
|
||||
print(f" Val Loss: {val_loss:.6f}")
|
||||
print(" Val Metrics:")
|
||||
for metric_name, metric_value in val_metrics.items():
|
||||
print(f" {metric_name}: {metric_value:.6f}")
|
||||
|
||||
# Save checkpoint
|
||||
is_best = val_loss < self.best_val_loss
|
||||
if is_best:
|
||||
self.best_val_loss = val_loss
|
||||
|
||||
self.save_checkpoint(is_best=is_best)
|
||||
|
||||
# Early stopping
|
||||
if self.config.get("early_stopping_patience", 0) > 0:
|
||||
if (
|
||||
epoch - np.argmin(self.val_losses)
|
||||
>= self.config["early_stopping_patience"]
|
||||
):
|
||||
print(f"Early stopping at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
# Training completed
|
||||
training_time = time.time() - start_time
|
||||
print(f"\nTraining completed in {training_time:.2f} seconds")
|
||||
print(f"Best validation loss: {self.best_val_loss:.6f}")
|
||||
|
||||
# Save final model
|
||||
final_model_path = self.output_dir / "model_final.pth"
|
||||
torch.save(self.model.state_dict(), final_model_path)
|
||||
print(f"Final model saved to {final_model_path}")
|
||||
|
||||
# Close TensorBoard writer
|
||||
self.writer.close()
|
||||
|
||||
def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
|
||||
"""
|
||||
Evaluate the model on test data.
|
||||
|
||||
Args:
|
||||
test_loader: Test data loader (uses validation loader if None)
|
||||
|
||||
Returns:
|
||||
Dictionary of evaluation metrics
|
||||
"""
|
||||
if test_loader is None:
|
||||
test_loader = self.val_loader
|
||||
|
||||
self.model.eval()
|
||||
all_metrics = []
|
||||
|
||||
print("Evaluating model...")
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(test_loader):
|
||||
# Move data to device
|
||||
google_img = batch["google_img"].to(self.device)
|
||||
yandex_img = batch["yandex_img"].to(self.device)
|
||||
target_homography = batch["homography"].to(self.device)
|
||||
|
||||
# Forward pass
|
||||
pred_homography = self.model(google_img, yandex_img, return_matrix=True)
|
||||
|
||||
# Compute metrics
|
||||
metrics = self.criterion.compute_metrics(
|
||||
pred_homography, target_homography
|
||||
)
|
||||
all_metrics.append(metrics)
|
||||
|
||||
# Aggregate metrics
|
||||
avg_metrics = {}
|
||||
for key in all_metrics[0].keys():
|
||||
avg_metrics[key] = np.mean([m[key] for m in all_metrics])
|
||||
|
||||
# Print evaluation results
|
||||
print("\nEvaluation Results:")
|
||||
for metric_name, metric_value in avg_metrics.items():
|
||||
print(f" {metric_name}: {metric_value:.6f}")
|
||||
|
||||
# Save evaluation results
|
||||
eval_path = self.output_dir / "evaluation_results.json"
|
||||
with open(eval_path, "w") as f:
|
||||
json.dump(avg_metrics, f, indent=2)
|
||||
print(f"Evaluation results saved to {eval_path}")
|
||||
|
||||
return avg_metrics
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Train homography estimation model")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
help="Directory containing image pairs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=32, help="Batch size for training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=[256, 256],
|
||||
help="Image size (height width)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_split", type=float, default=0.8, help="Train/validation split ratio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, default=4, help="Number of data loader workers"
|
||||
)
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument(
|
||||
"--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden_channels", type=int, default=64, help="Number of hidden channels"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_blocks", type=int, default=4, help="Number of convolutional blocks"
|
||||
)
|
||||
parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
|
||||
parser.add_argument(
|
||||
"--use_batch_norm", action="store_true", help="Use batch normalization"
|
||||
)
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
||||
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
|
||||
parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
|
||||
parser.add_argument(
|
||||
"--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"]
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="plateau",
|
||||
choices=["plateau", "cosine", "step", "none"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad_clip", type=float, default=1.0, help="Gradient clipping value"
|
||||
)
|
||||
|
||||
# Loss arguments
|
||||
parser.add_argument(
|
||||
"--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--geometric_weight",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Weight for geometric loss",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
|
||||
)
|
||||
|
||||
# Other arguments
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="runs/homography",
|
||||
help="Output directory for checkpoints and logs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
help="Path to checkpoint to resume training from",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_only",
|
||||
action="store_true",
|
||||
help="Only evaluate the model (no training)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=42, help="Random seed for reproducibility"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
args = parse_args()
|
||||
|
||||
# Set random seeds for reproducibility
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
# Set device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Create data loaders
|
||||
print("Creating data loaders...")
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=args.data_dir,
|
||||
batch_size=args.batch_size,
|
||||
train_split=args.train_split,
|
||||
num_workers=args.num_workers,
|
||||
image_size=tuple(args.image_size),
|
||||
augment_train=True,
|
||||
augment_val=False,
|
||||
)
|
||||
|
||||
print(f"Train batches: {len(train_loader)}")
|
||||
print(f"Val batches: {len(val_loader)}")
|
||||
|
||||
# Create model
|
||||
print("Creating model...")
|
||||
model = create_homography_model(
|
||||
model_type=args.model_type,
|
||||
input_size=tuple(args.image_size),
|
||||
input_channels=3,
|
||||
hidden_channels=args.hidden_channels,
|
||||
num_blocks=args.num_blocks,
|
||||
dropout_rate=args.dropout_rate,
|
||||
use_batch_norm=args.use_batch_norm,
|
||||
)
|
||||
|
||||
# Create trainer configuration
|
||||
config = {
|
||||
# Model config
|
||||
"model_type": args.model_type,
|
||||
"hidden_channels": args.hidden_channels,
|
||||
"num_blocks": args.num_blocks,
|
||||
"dropout_rate": args.dropout_rate,
|
||||
"use_batch_norm": args.use_batch_norm,
|
||||
"image_size": args.image_size,
|
||||
# Training config
|
||||
"epochs": args.epochs,
|
||||
"batch_size": args.batch_size,
|
||||
"learning_rate": args.lr,
|
||||
"weight_decay": args.weight_decay,
|
||||
"optimizer": args.optimizer,
|
||||
"scheduler": args.scheduler,
|
||||
"grad_clip": args.grad_clip,
|
||||
# Loss config
|
||||
"matrix_weight": args.matrix_weight,
|
||||
"geometric_weight": args.geometric_weight,
|
||||
"reg_weight": args.reg_weight,
|
||||
"grid_size": 8,
|
||||
# Data config
|
||||
"data_dir": args.data_dir,
|
||||
"train_split": args.train_split,
|
||||
"num_workers": args.num_workers,
|
||||
# Output config
|
||||
"output_dir": args.output_dir,
|
||||
"seed": args.seed,
|
||||
}
|
||||
|
||||
# Create trainer
|
||||
trainer = HomographyTrainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Resume from checkpoint if specified
|
||||
if args.resume:
|
||||
print(f"Resuming from checkpoint: {args.resume}")
|
||||
trainer.load_checkpoint(args.resume)
|
||||
|
||||
# Evaluate only mode
|
||||
if args.eval_only:
|
||||
trainer.evaluate()
|
||||
return
|
||||
|
||||
# Train the model
|
||||
trainer.train(num_epochs=args.epochs)
|
||||
|
||||
# Final evaluation
|
||||
print("\nPerforming final evaluation...")
|
||||
trainer.evaluate()
|
||||
|
||||
print("\nTraining completed successfully!")
|
||||
print(f"Results saved to: {args.output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
54
position.py
54
position.py
@@ -42,22 +42,45 @@ class Position:
|
||||
f"roll={math.degrees(self.roll):.1f}°)"
|
||||
)
|
||||
|
||||
def get_homography_matrix(self, K: np.ndarray = constants.K, sliding: bool = True) -> np.ndarray:
|
||||
def __imul__(self, scalar: float):
|
||||
self.x *= scalar
|
||||
self.y *= scalar
|
||||
self.z *= scalar
|
||||
return self
|
||||
|
||||
def __mul__(self, scalar: float) -> 'Position':
|
||||
pos = self.copy()
|
||||
pos *= scalar
|
||||
return pos
|
||||
|
||||
def __itruediv__(self, scalar: float):
|
||||
self.x /= scalar
|
||||
self.y /= scalar
|
||||
self.z /= scalar
|
||||
return self
|
||||
|
||||
def __truediv__(self, scalar: float) -> 'Position':
|
||||
pos = self.copy()
|
||||
pos /= scalar
|
||||
return pos
|
||||
|
||||
def get_homography_matrix(self, K_in: np.ndarray = constants.K, K_out: np.ndarray | None = None, sliding: bool = True) -> np.ndarray:
|
||||
""" Возвращает матрицу гомографии """
|
||||
R = self.get_rotation_matrix()
|
||||
T = self.get_translation_matrix()
|
||||
T = self.get_translation_matrix(K_in)
|
||||
if not sliding:
|
||||
T[0, 2] = T[1, 2] = 0
|
||||
return K @ R @ T @ np.linalg.inv(K)
|
||||
if K_out is None: K_out = K_in
|
||||
return K_out @ R @ T @ np.linalg.inv(K_in)
|
||||
|
||||
def copy(self) -> 'Position':
|
||||
"""Создает полную копию объекта"""
|
||||
return Position(self.x, self.y, self.z, self.yaw, self.pitch, self.roll)
|
||||
|
||||
def get_translation_matrix(self) -> np.ndarray:
|
||||
def get_translation_matrix(self, K: np.ndarray = constants.K) -> np.ndarray:
|
||||
return np.array([
|
||||
[1, 0, self.x / constants._K_FOCUS_DISTANCE],
|
||||
[0, 1, self.y / constants._K_FOCUS_DISTANCE],
|
||||
[1, 0, self.x / K[0][0]],
|
||||
[0, 1, self.y / K[0][0]],
|
||||
[0, 0, self.z]
|
||||
])
|
||||
|
||||
@@ -101,6 +124,13 @@ class Position:
|
||||
R = np.array(R)
|
||||
t = np.array(t)
|
||||
|
||||
# print(cv2.decomposeHomographyMat(inv(H), K))
|
||||
# _, _, z, _ = cv2.decomposeHomographyMat(inv(H), K)
|
||||
# print(z)
|
||||
# z = z.copy()
|
||||
# z *= constants._K_FOCUS_DISTANCE
|
||||
# print(z)
|
||||
|
||||
T = inv(R) @ inv(K) @ H @ K
|
||||
ind = np.array([A[2][0] ** 2 + A[2][1] ** 2 for A in T])
|
||||
top_k = max(1, len(T) // 2)
|
||||
@@ -116,14 +146,16 @@ class Position:
|
||||
|
||||
R = R[best_id]
|
||||
rot = Rotation.from_matrix(R).as_euler('XYZ').flatten()
|
||||
self.roll = rot[0]
|
||||
self.pitch = rot[1]
|
||||
self.roll = min(np.radians(5), max(np.radians(-5), rot[0]))
|
||||
self.pitch = min(np.radians(5), max(np.radians(-5), rot[1]))
|
||||
self.yaw = rot[2]
|
||||
|
||||
t = t[best_id].flatten()
|
||||
self.x += -T[0] * constants._K_FOCUS_DISTANCE * self.z
|
||||
self.y += T[1] * constants._K_FOCUS_DISTANCE * self.z
|
||||
self.z = 1 + T[2]
|
||||
self.x -= T[0] * constants._K_FOCUS_DISTANCE
|
||||
self.y += T[1] * constants._K_FOCUS_DISTANCE
|
||||
self.z = max(0.7, min(1.3, 1 + T[2]))
|
||||
T[0] *= constants._K_FOCUS_DISTANCE
|
||||
T[1] *= constants._K_FOCUS_DISTANCE
|
||||
|
||||
def apply(self, homography_matrix: np.ndarray, K = constants.K) -> 'Position':
|
||||
"""Применяет матрицу трансформации для вычисления новой позиции и ориентации."""
|
||||
|
||||
40
simulator.py
40
simulator.py
@@ -8,12 +8,13 @@ import numpy as np
|
||||
from position import Position
|
||||
from vision_chunk import VisionChunk
|
||||
from yandex_map import YandexMap
|
||||
from google_map import GoogleMap
|
||||
import constants
|
||||
import utility
|
||||
|
||||
class Simulator:
|
||||
def __init__(self, yandex_map: YandexMap = None):
|
||||
self.yandex_map = yandex_map
|
||||
def __init__(self, online_map: YandexMap | GoogleMap = None):
|
||||
self.online_map = online_map
|
||||
# Используем новый конструктор с yaw, pitch, roll
|
||||
self.pos = Position(x=0, y=0, z=1, yaw=0, pitch=0, roll=0)
|
||||
|
||||
@@ -35,24 +36,26 @@ class Simulator:
|
||||
Возвращает квадратное изображение 700x700.
|
||||
"""
|
||||
img_array = np.array(image)
|
||||
print(img_array.shape)
|
||||
h, w, _ = img_array.shape
|
||||
|
||||
# Применяем трансформацию
|
||||
pos = self.pos.copy()
|
||||
pos.x = 0
|
||||
pos.y = 0
|
||||
K = utility.calc_camera_matrix(w, h)
|
||||
K = constants.K
|
||||
img_array = img_array[:constants.CHUNK_WIDTH, :constants.CHUNK_WIDTH]
|
||||
transformed = cv2.warpPerspective(img_array, pos.get_homography_matrix(K), (constants.CHUNK_WIDTH, constants.CHUNK_WIDTH))
|
||||
|
||||
K_in = utility.calc_camera_matrix(w, h)
|
||||
K_out = constants.K
|
||||
H = pos.get_homography_matrix(K_in, K_out)
|
||||
|
||||
shape = (constants.CHUNK_WIDTH, constants.CHUNK_WIDTH)
|
||||
transformed = cv2.warpPerspective(img_array, H, shape)
|
||||
|
||||
return Image.fromarray(transformed)
|
||||
|
||||
def update_trajectory(self, dx: float, dy: float):
|
||||
"""Обновляет координаты дрона"""
|
||||
self.pos.x += dx * self.pos.z
|
||||
self.pos.y += dy * self.pos.z
|
||||
self.pos.x += dx
|
||||
self.pos.y += dy
|
||||
|
||||
def handle(self, dangle: float, velocity: float = 50) -> None:
|
||||
"""
|
||||
@@ -60,27 +63,17 @@ class Simulator:
|
||||
dangle - изменение угла курса (радианы)
|
||||
velocity - скорость движения
|
||||
"""
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.action_chains import ActionChains
|
||||
|
||||
html = self.yandex_map.driver.find_element(By.TAG_NAME, 'html')
|
||||
action = ActionChains(self.yandex_map.driver)
|
||||
action.move_to_element_with_offset(html, 200, 200)
|
||||
action.click_and_hold()
|
||||
|
||||
# Обновляем yaw в объекте Position
|
||||
self.pos.yaw += dangle
|
||||
velocity = max(velocity, 10)
|
||||
|
||||
# Вычисляем смещение на основе текущего yaw
|
||||
dx = math.cos(math.pi / 2 + self.pos.yaw) * velocity / self.pos.z
|
||||
dy = math.sin(math.pi / 2 + self.pos.yaw) * velocity / self.pos.z
|
||||
dx = int(math.cos(math.pi / 2 + self.pos.yaw) * velocity)
|
||||
dy = int(math.sin(math.pi / 2 + self.pos.yaw) * velocity)
|
||||
|
||||
self.update_trajectory(dx, dy)
|
||||
|
||||
action.move_by_offset(-dx, dy)
|
||||
action.release()
|
||||
action.perform()
|
||||
self.online_map.move(dx, dy)
|
||||
|
||||
def set_zoom(self, zoom_level: float):
|
||||
"""Программное изменение масштаба"""
|
||||
@@ -88,8 +81,7 @@ class Simulator:
|
||||
|
||||
def get_chunk(self) -> VisionChunk:
|
||||
"""Получить текущий снимок с камеры дрона"""
|
||||
png = self.yandex_map.driver.get_screenshot_as_png()
|
||||
im = Image.open(BytesIO(png))
|
||||
im = self.online_map.make_screenshot()
|
||||
|
||||
# Применяем перспективную трансформацию
|
||||
transformed_im = self._apply_perspective_transform(im)
|
||||
|
||||
189
test_chunks.ipynb
Normal file
189
test_chunks.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
6
timer.py
6
timer.py
@@ -29,3 +29,9 @@ class Timer:
|
||||
self.elapsed = 0.
|
||||
self.enabled = False
|
||||
self.last_enabled = 0.
|
||||
|
||||
def loop(self) -> float:
|
||||
v = self.get_diff()
|
||||
self.stop()
|
||||
self.start()
|
||||
return v
|
||||
24
todo.md
24
todo.md
@@ -2,16 +2,16 @@
|
||||
[!] Проверка корректности выявления ориентира на кадре
|
||||
[!] Исправление коррекции координат на основе сопоставления с ориентиром
|
||||
|
||||
[-] FPS счетчик
|
||||
| [-] Оптимизация детекции точек
|
||||
[-] Оформление статистики при тестовых запусках
|
||||
[-] Проведение тестовых запусков
|
||||
[-] Оформление отчета
|
||||
[-] Эксперименты с разными детекторами (SIFT, KAZE)
|
||||
[+] FPS счетчик
|
||||
| [+] Оптимизация детекции точек
|
||||
[+] Оформление статистики при тестовых запусках
|
||||
[+] Проведение тестовых запусков
|
||||
[+] Оформление отчета
|
||||
[+] Эксперименты с разными детекторами (SIFT, KAZE)
|
||||
|
||||
[?] Изменение масштаба во время полёта, обработка этой трансформации
|
||||
[?] Поворот ориентиров
|
||||
[?] Ограничение выбора точек при построении маршрута, чтобы ориентиры полностью попадали в кадр
|
||||
[+] Изменение масштаба во время полёта, обработка этой трансформации
|
||||
[+] Поворот ориентиров
|
||||
[+] Ограничение выбора точек при построении маршрута, чтобы ориентиры полностью попадали в кадр
|
||||
|
||||
[+] График межкадрового смещения
|
||||
| [+] График межкадровых смещениях по версии матрицы гомографии
|
||||
@@ -19,6 +19,6 @@
|
||||
| [+] Исследовать причину погрешности при развороте
|
||||
[+] Устранение большой погрешности при повороте
|
||||
|
||||
[ ] Переделать ключевые точки -> Optical Flow
|
||||
[ ] Добавить перспективу
|
||||
[ ] Эталоны на Google Maps, полёт тот же
|
||||
[+] Переделать ключевые точки -> Optical Flow
|
||||
[+] Добавить перспективу
|
||||
[+] Эталоны на Google Maps, полёт тот же
|
||||
|
||||
143
utility.py
143
utility.py
@@ -1,7 +1,11 @@
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
from urllib.parse import parse_qs, urlparse, unquote
|
||||
|
||||
import constants
|
||||
import cv2
|
||||
import numpy as np
|
||||
import constants
|
||||
import re
|
||||
|
||||
def cv2_to_pil(cv_image: np.ndarray) -> Image.Image:
|
||||
"""
|
||||
@@ -19,11 +23,10 @@ def get_normals(H: np.ndarray, R: np.ndarray, T: np.ndarray) -> np.ndarray:
|
||||
n = cv2.decomposeHomographyMat(H, constants.K, R, T)
|
||||
return n
|
||||
|
||||
def estimate_transformation_matrix(src_pts: np.ndarray, dst_pts: np.ndarray) -> tuple[np.ndarray, float | None]:
|
||||
def estimate_transformation_matrix(src_pts: np.ndarray, dst_pts: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Оценивает матрицу трансформации на основе сопоставленных точек"""
|
||||
# Используем RANSAC для оценки матрицы гомографии
|
||||
H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 3.0, maxIters=1000)
|
||||
return H
|
||||
return cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 3.0, maxIters=1000)
|
||||
|
||||
def calc_camera_matrix(w: float, h: float):
|
||||
f = constants._K_FOCUS_DISTANCE
|
||||
@@ -32,3 +35,135 @@ def calc_camera_matrix(w: float, h: float):
|
||||
[0, f, h / 2],
|
||||
[0, 0, 1]
|
||||
])
|
||||
|
||||
|
||||
def generate_folder_name():
|
||||
"""
|
||||
Генерирует название для папки с текущей датой и временем до секунд.
|
||||
Формат: YYYY-MM-DD_HH-MM-SS
|
||||
"""
|
||||
now = datetime.now()
|
||||
folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return folder_name
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def parse_yandex_maps_url(url):
|
||||
"""
|
||||
Парсит URL Яндекс.Карт и извлекает lat, lon и zoom
|
||||
Формат: ?ll=lon,lat&z=zoom
|
||||
"""
|
||||
# Декодируем URL (на случай %2C вместо запятых)
|
||||
url = unquote(url)
|
||||
|
||||
# Парсим URL
|
||||
parsed_url = urlparse(url)
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
if 'll' in params and 'z' in params:
|
||||
# ll содержит "lon,lat"
|
||||
ll_value = params['ll'][0]
|
||||
lat, lon = map(float, ll_value.split(','))
|
||||
zoom = int(params['z'][0])
|
||||
|
||||
return {
|
||||
'lat': lat,
|
||||
'lon': lon,
|
||||
'zoom': zoom
|
||||
}
|
||||
return None
|
||||
|
||||
def parse_google_maps_url(url):
|
||||
"""
|
||||
Парсит URL Google Maps и извлекает lat, lon и zoom
|
||||
Формат: /@lat,lon,zoom[m|z]
|
||||
"""
|
||||
pattern = r'/@([-\d.]+),([-\d.]+),(\d+)([mz])'
|
||||
match = re.search(pattern, url)
|
||||
|
||||
if match:
|
||||
lon = float(match.group(1))
|
||||
lat = float(match.group(2))
|
||||
zoom_value = int(match.group(3))
|
||||
zoom_unit = match.group(4)
|
||||
|
||||
return {
|
||||
'lat': lat,
|
||||
'lon': lon,
|
||||
'zoom': zoom_value,
|
||||
'zoom_unit': zoom_unit
|
||||
}
|
||||
return None
|
||||
|
||||
def google_map_js_move_script(dx, dy) -> str:
|
||||
return """
|
||||
async function sleep(ms) {
|
||||
return new Promise((resolve, reject) => {
|
||||
setTimeout(() => resolve(), ms);
|
||||
});
|
||||
}
|
||||
|
||||
async function simulateDrag(element, offsetX, offsetY) {
|
||||
const rect = element.getBoundingClientRect();
|
||||
const startX = rect.left + rect.width / 2;
|
||||
const startY = rect.top + rect.height / 2;
|
||||
const step = 10
|
||||
const endX = startX + offsetX + step;
|
||||
const endY = startY + offsetY + step;
|
||||
|
||||
element.dispatchEvent(new MouseEvent('mousedown', {
|
||||
view: window,
|
||||
bubbles: true,
|
||||
cancelable: true,
|
||||
clientX: startX,
|
||||
clientY: startY,
|
||||
button: 0
|
||||
}));
|
||||
|
||||
let currentX = startX;
|
||||
let currentY = startY;
|
||||
|
||||
function move(stepX, stepY) {
|
||||
currentX += stepX;
|
||||
currentY += stepY;
|
||||
|
||||
document.dispatchEvent(new MouseEvent('mousemove', {
|
||||
view: window,
|
||||
bubbles: true,
|
||||
cancelable: false,
|
||||
clientX: currentX,
|
||||
clientY: currentY
|
||||
}));
|
||||
}
|
||||
|
||||
await sleep(50);
|
||||
move(step, step)
|
||||
|
||||
while (currentX != endX || currentY != endY) {
|
||||
await sleep(50);
|
||||
const stepX = Math.min(step, Math.max(-step, endX - currentX));
|
||||
const stepY = Math.min(step, Math.max(-step, endY - currentY));
|
||||
move(stepX, stepY);
|
||||
}
|
||||
|
||||
await sleep(50)
|
||||
document.dispatchEvent(new MouseEvent('mouseup', {
|
||||
view: window,
|
||||
bubbles: true,
|
||||
cancelable: false,
|
||||
clientX: endX,
|
||||
clientY: endY,
|
||||
button: 0
|
||||
}));
|
||||
}
|
||||
|
||||
{
|
||||
const canvas = document.querySelector('canvas.H1VXrf.JRr1M.DnOnV');
|
||||
""" + f"simulateDrag(canvas, {int(-dx)}, {int(dy)});" + """
|
||||
}
|
||||
"""
|
||||
@@ -1,10 +1,13 @@
|
||||
import constants
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from position import Position
|
||||
from timer import Timer
|
||||
from typing import Literal, Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
FeatureMethod = Literal["orb", "sift", "akaze", "brisk"]
|
||||
DEFAULT_METHOD = "orb"
|
||||
@@ -14,6 +17,7 @@ class VisionChunk:
|
||||
image: Image.Image
|
||||
feature_method: FeatureMethod = DEFAULT_METHOD
|
||||
|
||||
pos: Optional[Position] = field(default=None, init=False)
|
||||
keypoints: Optional[list] = field(default=None, init=False)
|
||||
descriptors: Optional[np.ndarray] = field(default=None, init=False)
|
||||
_detector: Optional[cv2.Feature2D] = field(default=None, init=False, repr=False)
|
||||
@@ -27,12 +31,12 @@ class VisionChunk:
|
||||
self._detector = cv2.ORB_create(
|
||||
nfeatures=1000,
|
||||
scaleFactor=1.2,
|
||||
nlevels=32,
|
||||
nlevels=16,
|
||||
edgeThreshold=31,
|
||||
firstLevel=0,
|
||||
WTA_K=2,
|
||||
patchSize=31,
|
||||
fastThreshold=20,
|
||||
fastThreshold=10,
|
||||
)
|
||||
elif self.feature_method == "sift":
|
||||
self._detector = cv2.SIFT_create(
|
||||
@@ -70,30 +74,55 @@ class VisionChunk:
|
||||
return self._matcher
|
||||
|
||||
def _preprocess(self, img_np: np.ndarray) -> np.ndarray:
|
||||
"""CLAHE предобработка для улучшения контраста"""
|
||||
"""Предобработка для улучшения сопоставления между снимками разного времени"""
|
||||
if len(img_np.shape) == 3:
|
||||
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray = img_np
|
||||
|
||||
|
||||
# Гауссовское размытие для подавления шума и мелких различий
|
||||
# blurred = cv2.GaussianBlur(gray, (5, 5), 1.0)
|
||||
|
||||
# CLAHE для выравнивания контраста между снимками
|
||||
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
|
||||
return clahe.apply(gray)
|
||||
enhanced = clahe.apply(gray)
|
||||
|
||||
# Опционально: нормализация гистограммы для устранения различий в освещении
|
||||
normalized = cv2.normalize(enhanced, None, 0, 255, cv2.NORM_MINMAX)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def compute_keypoints(self, force: bool = False) -> Tuple[list[cv2.KeyPoint], Optional[np.ndarray]]:
|
||||
if self.keypoints is not None and self.descriptors is not None and not force:
|
||||
return self.keypoints, self.descriptors
|
||||
|
||||
timer = Timer()
|
||||
timer.start()
|
||||
detector = self._get_detector()
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-DETECTION]: get_detector: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
# PIL -> OpenCV (RGB->BGR)
|
||||
img_np = np.array(self.image)
|
||||
if img_np.ndim == 3 and img_np.shape[2] == 3:
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-DETECTION]: converting: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
# CLAHE предобработка
|
||||
preprocessed = self._preprocess(img_np)
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-DETECTION]: preprocess: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
keypoints, descriptors = detector.detectAndCompute(preprocessed, None)
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-DETECTION]: detect and compute: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
# Получаем массив response для всех точек
|
||||
responses = np.array([kp.response for kp in keypoints])
|
||||
|
||||
@@ -104,6 +133,9 @@ class VisionChunk:
|
||||
best_keypoints = [keypoints[i] for i in top_indices]
|
||||
best_descriptors = descriptors[top_indices]
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-DETECTION]: filtration: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
self.keypoints = best_keypoints
|
||||
self.descriptors = best_descriptors
|
||||
return self.keypoints, self.descriptors
|
||||
@@ -122,15 +154,29 @@ class VisionChunk:
|
||||
Возвращает: src_pts, dst_pts, good_matches, kp1, kp2 (отцентрированные координаты)
|
||||
"""
|
||||
# Вычисляем keypoints для обоих
|
||||
timer = Timer()
|
||||
timer.start()
|
||||
|
||||
kp1, des1 = self.compute_keypoints()
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-KEYPOINTS]: computing 1: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
kp2, des2 = other.compute_keypoints()
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-KEYPOINTS]: computing 2: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
if des1 is None or des2 is None or len(kp1) < 4 or len(kp2) < 4:
|
||||
return None, None, None, None, None
|
||||
|
||||
# kNN matching + Lowe ratio test
|
||||
matcher = self._get_matcher()
|
||||
matches_knn = matcher.knnMatch(des1, des2, k=2)
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-KEYPOINTS]: matching: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
good_matches: list[cv2.DMatch] = []
|
||||
|
||||
for m_n in matches_knn:
|
||||
@@ -147,15 +193,6 @@ class VisionChunk:
|
||||
if len(good_matches) < 4:
|
||||
return None, None, None, None, None
|
||||
|
||||
# Центр изображений
|
||||
img1_cv = self.to_cv2_gray()
|
||||
img2_cv = other.to_cv2_gray()
|
||||
h1, w1 = img1_cv.shape
|
||||
h2, w2 = img2_cv.shape
|
||||
cx1, cy1 = w1 // 2, h1 // 2
|
||||
cx2, cy2 = w2 // 2, h2 // 2
|
||||
|
||||
# Отцентрированные координаты (x_rel, y_rel)
|
||||
src_pts = []
|
||||
dst_pts = []
|
||||
|
||||
@@ -169,6 +206,9 @@ class VisionChunk:
|
||||
src_pts = np.float32(src_pts).reshape(-1, 1, 2)
|
||||
dst_pts = np.float32(dst_pts).reshape(-1, 1, 2)
|
||||
|
||||
if constants.DEBUG_FPS:
|
||||
print(f"[VC-KEYPOINTS]: filtration: {timer.loop() * 1000:.2f} ms")
|
||||
|
||||
return src_pts, dst_pts, good_matches, kp1, kp2
|
||||
|
||||
def to_cv2_gray(self) -> np.ndarray:
|
||||
|
||||
@@ -3,14 +3,16 @@
|
||||
Модуль для управления общим окном визуализации
|
||||
"""
|
||||
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
from scipy.interpolate import make_interp_spline
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import matplotlib.axes
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import matplotlib
|
||||
|
||||
# Настройки matplotlib
|
||||
matplotlib.use('TkAgg')
|
||||
@@ -93,14 +95,14 @@ class VisualizationManager:
|
||||
self.ax_matches.axis('off')
|
||||
|
||||
# Сопоставление точек (средний средний угол)
|
||||
self.ax_chunk_matches = self.fig.add_subplot(gs[1, 2])
|
||||
self.ax_chunk_matches = self.fig.add_subplot(gs[1, 1:3])
|
||||
self.ax_chunk_matches.set_title('Chunk Matching')
|
||||
self.ax_chunk_matches.axis('off')
|
||||
|
||||
# Визуализация движения ключевых точек (левый нижний угол)
|
||||
self.ax_motion_vectors = self.fig.add_subplot(gs[1, 1])
|
||||
self.ax_motion_vectors.set_title('Motion Vectors - Движение ключевых точек')
|
||||
self.ax_motion_vectors.axis('off')
|
||||
# self.ax_motion_vectors = self.fig.add_subplot(gs[1, 1])
|
||||
# self.ax_motion_vectors.set_title('Motion Vectors - Движение ключевых точек')
|
||||
# self.ax_motion_vectors.axis('off')
|
||||
|
||||
# Визуализация движения ключевых точек на основе матрицы гомографии
|
||||
self.ax_motion_gomography = self.fig.add_subplot(gs[0, 1])
|
||||
@@ -157,7 +159,7 @@ class VisualizationManager:
|
||||
# Рисуем текущую целевую точку
|
||||
if self.target_idx < len(self.target_pts):
|
||||
pt = self.target_pts[self.target_idx]
|
||||
self.ax_global_map.plot(pt[0], pt[1], 'yo', markersize=8, label='Цель (0, 0)')
|
||||
self.ax_global_map.plot(pt[0], pt[1], 'yo', markersize=8, label='Цель')
|
||||
|
||||
self.ax_global_map.legend()
|
||||
|
||||
@@ -208,7 +210,37 @@ class VisualizationManager:
|
||||
self.ax_error_plot.grid(True, alpha=0.3)
|
||||
|
||||
if len(self.error_times) > 1:
|
||||
self.ax_error_plot.plot(self.error_times, self.position_errors, 'b-', linewidth=2)
|
||||
# Оригинальный график (более прозрачный)
|
||||
self.ax_error_plot.plot(self.error_times, self.position_errors, 'b-',
|
||||
linewidth=1, alpha=0.4, label='Погрешность данных')
|
||||
|
||||
if len(self.error_times) > 5:
|
||||
# Сглаженный график
|
||||
smoothed_times = np.linspace(self.error_times[0], self.error_times[-1], 300)
|
||||
spl = make_interp_spline(self.error_times, self.position_errors, k=3)
|
||||
smoothed_errors = spl(smoothed_times)
|
||||
|
||||
|
||||
self.ax_error_plot.plot(smoothed_times, smoothed_errors, 'orange',
|
||||
linewidth=2, label='Сглаженный тренд')
|
||||
|
||||
# if len(self.position_errors) > 5: # Достаточно данных для сглаживания
|
||||
# window_size = min(11, len(self.position_errors) // 3) # Адаптивный размер окна
|
||||
# if window_size % 2 == 0: # Должен быть нечетным
|
||||
# window_size += 1
|
||||
|
||||
# # Метод скользящего среднего
|
||||
# smoothed_errors = np.convolve(
|
||||
# self.position_errors,
|
||||
# np.ones(window_size) / window_size,
|
||||
# mode='valid'
|
||||
# )
|
||||
|
||||
# # Корректируем временную ось для сглаженных данных
|
||||
# offset = (window_size - 1) // 2
|
||||
# smoothed_times = self.error_times[offset:offset + len(smoothed_errors)]
|
||||
|
||||
self.ax_error_plot.legend(loc='upper right')
|
||||
|
||||
# Автоматически масштабируем оси
|
||||
if len(self.position_errors) > 0:
|
||||
@@ -219,6 +251,7 @@ class VisualizationManager:
|
||||
else:
|
||||
self.ax_error_plot.set_ylim(0, 1)
|
||||
|
||||
|
||||
def update_matches(self, img1: np.ndarray, img2: np.ndarray,
|
||||
kp1, kp2, matches, transformation_info=None):
|
||||
"""Обновляет визуализацию сопоставления точек"""
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import math
|
||||
from io import BytesIO
|
||||
from time import sleep
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from selenium.webdriver.common.actions.wheel_input import ScrollOrigin
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.action_chains import ActionChains
|
||||
from time import sleep
|
||||
|
||||
import constants
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import utility
|
||||
|
||||
def generateURL(lat: float, lon: float, zoom: int):
|
||||
return f"https://yandex.ru/maps/43/kazan/?l=sat&ll={lat}%2C{lon}&z={zoom}"
|
||||
@@ -24,6 +25,7 @@ class YandexMap:
|
||||
self.initial_lat = initial_lat
|
||||
self.initial_lon = initial_lon
|
||||
self.initial_zoom = initial_zoom
|
||||
self.pixel_ratio = constants.YANDEX_PIXEL_RATIO[self.initial_zoom]
|
||||
|
||||
options = webdriver.ChromeOptions()
|
||||
# options.add_experimental_option("detach", True)
|
||||
@@ -32,9 +34,8 @@ class YandexMap:
|
||||
self.driver.maximize_window()
|
||||
sleep(2)
|
||||
|
||||
action = ActionChains(self.driver)
|
||||
|
||||
# Закрытие левой панели
|
||||
action = ActionChains(self.driver)
|
||||
action.click(self.driver.find_element(By.CLASS_NAME, 'sidebar-toggle-button'))
|
||||
action.perform()
|
||||
|
||||
@@ -43,8 +44,25 @@ class YandexMap:
|
||||
self.driver.execute_script("arguments[0].remove();", self.driver.find_element(By.XPATH, "//nav[@class='map-controls']"))
|
||||
self.driver.execute_script("arguments[0].remove();", self.driver.find_element(By.XPATH, "//footer"))
|
||||
|
||||
self.move(39, -9)
|
||||
|
||||
sleep(0.2)
|
||||
|
||||
def open(self, lat, lon, zoom):
|
||||
self.initial_lat = lat
|
||||
self.initial_lon = lon
|
||||
self.initial_zoom = zoom
|
||||
self.pixel_ratio = constants.YANDEX_PIXEL_RATIO[self.initial_zoom]
|
||||
self.driver.get(generateURL(lat, lon, zoom))
|
||||
sleep(2)
|
||||
|
||||
# Закрытие левой панели
|
||||
action = ActionChains(self.driver)
|
||||
action.click(self.driver.find_element(By.CLASS_NAME, 'sidebar-toggle-button'))
|
||||
action.perform()
|
||||
|
||||
self.move(39, -9)
|
||||
|
||||
def save_photo(self, filename: str) -> bytes:
|
||||
return self.driver.save_screenshot(filename)
|
||||
|
||||
@@ -68,51 +86,26 @@ class YandexMap:
|
||||
if i != count - 1:
|
||||
sleep(0.25)
|
||||
|
||||
def make_as_center(self, x: float, y: float):
|
||||
def move(self, dx: float, dy: float):
|
||||
html = self.driver.find_element(By.TAG_NAME, 'html')
|
||||
|
||||
action = ActionChains(self.driver)
|
||||
action.move_to_element_with_offset(html, 0, 0)
|
||||
action.click_and_hold()
|
||||
|
||||
dx = (x - 0.5) * self.get_size()[0]
|
||||
dy = (0.5 - y) * self.get_size()[1]
|
||||
print(dx, dy)
|
||||
action.move_by_offset(-dx, dy)
|
||||
action.release()
|
||||
action.perform()
|
||||
|
||||
def make_screenshot(self, x: float, y: float, width: float, height: float) -> cv2.typing.MatLike:
|
||||
# Сохраняем скриншот
|
||||
self.save_photo("temp.png")
|
||||
|
||||
# Загружаем изображение
|
||||
image = cv2.imread("temp.png")
|
||||
|
||||
if image is None:
|
||||
raise ValueError("Не удалось загрузить изображение temp.png")
|
||||
|
||||
# Получаем размеры исходного изображения
|
||||
img_height, img_width = image.shape[:2]
|
||||
|
||||
# Преобразуем относительные координаты в абсолютные пиксели
|
||||
center_x = int(x * img_width)
|
||||
center_y = int(y * img_height)
|
||||
crop_width = int(width * img_width)
|
||||
crop_height = int(height * img_height)
|
||||
|
||||
# Вычисляем координаты прямоугольника для кадрирования
|
||||
x1 = max(0, center_x - crop_width // 2)
|
||||
y1 = max(0, center_y - crop_height // 2)
|
||||
x2 = min(img_width, center_x + crop_width // 2)
|
||||
y2 = min(img_height, center_y + crop_height // 2)
|
||||
|
||||
# Проверяем, что прямоугольник имеет положительные размеры
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
raise ValueError("Некорректные размеры для кадрирования")
|
||||
|
||||
# Кадрируем изображение
|
||||
cropped_image = image[y1:y2, x1:x2]
|
||||
|
||||
# Если нужно вернуть изображение как результат функции:
|
||||
return cropped_image
|
||||
|
||||
def make_as_center(self, x: float, y: float):
|
||||
dx = (x - 0.5) * self.get_size()[0]
|
||||
dy = (0.5 - y) * self.get_size()[1]
|
||||
self.move(dx, dy)
|
||||
|
||||
def make_screenshot(self) -> Image.Image:
|
||||
png = self.driver.get_screenshot_as_png()
|
||||
im = Image.open(BytesIO(png))
|
||||
return utility.cv2_to_pil(np.array(im)[:, :])
|
||||
|
||||
def get_geolocation(self):
|
||||
current_url = self.driver.current_url
|
||||
return utility.parse_yandex_maps_url(current_url)
|
||||
|
||||
Reference in New Issue
Block a user