feat: working model

This commit is contained in:
2026-04-04 20:26:56 +03:00
parent 4b398f6c9a
commit 703ea8dbaf
6 changed files with 754 additions and 579 deletions

View File

@@ -1,34 +1,6 @@
# Module-level settings dictionary, read elsewhere via ``config[...]``.
# Most consumers of these keys are not visible in this file, so the grouping
# comments below are descriptive guesses based on the key names only.
config = dict(
    # Optimizer settings (beta1/beta2 are presumably Adam betas; TODO confirm).
    learning_rate=2e-4,
    beta1=0.5,
    beta2=0.999,
    # Training schedule.
    batch_size=32,
    epochs=100,
    # Loss / regularisation knobs (semantics defined by the training code).
    gan_mode="vanilla",
    lambda_L1=100.0,
    grad_clip=1.0,
    early_stopping_patience=20,
    # Output and logging.
    output_dir="runs/gan_training",
    log_interval=10,
    save_interval=5,
    # Data location and loading. NOTE: data_dir is a machine-specific
    # absolute Windows path.
    data_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    image_size=[256, 256],
    train_split=0.8,
    num_workers=0,
)
import os
from typing import Dict, List, Tuple
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
import random
from typing import Any, Dict, List, Optional, Tuple
from typing import Tuple
import cv2
import numpy as np
@@ -37,238 +9,84 @@ from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
class YaGoDataset(Dataset):
    """Paired Google/Yandex map images with random homography targets.

    The dataset scans ``root_dir`` for files named ``<idx>_google.png`` that
    have a sibling ``<idx>_yandex.png``. Each item loads both images, resizes
    them to ``image_size`` and, when ``augment`` is true, warps each image
    with its own random homography and returns the combined homography as
    the regression target.

    Args:
        root_dir: Directory containing the ``*_google.png`` / ``*_yandex.png``
            files.
        transform: Optional callable applied to each PIL image. When omitted,
            images are converted to float CHW tensors scaled to ``[0, 1]``.
        augment: Whether to apply random homography warps.
        image_size: ``(height, width)`` of the returned images.
        max_samples: Optional cap (int or None) on the number of pairs kept,
            taken from the start of the index-sorted pair list.
        cache_homographies: Accepted for compatibility with older callers;
            stored but not otherwise used by this class.
        device: Accepted for compatibility with older callers; stored but not
            otherwise used by this class.
    """

    def __init__(self, root_dir: str, transform=None, augment: bool = True,
                 image_size: Tuple[int, int] = (256, 256), *,
                 max_samples=None, cache_homographies: bool = True,
                 device=None):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        # Normalise to a tuple of ints: callers may pass a list (e.g. straight
        # from a config dict), and OpenCV/PIL size arguments expect tuples.
        self.image_size = (int(image_size[0]), int(image_size[1]))
        self.cache_homographies = cache_homographies
        self.device = device
        # Called as (image_size[1], image_size[0]), i.e. width first then
        # height under the (height, width) convention used here.
        self.K = get_camera_matrix(self.image_size[1], self.image_size[0])
        self.image_pairs = self._discover_image_pairs()
        if max_samples is not None:
            self.image_pairs = self.image_pairs[:max_samples]

    def _discover_image_pairs(self):
        """Return the list of complete image pairs found in ``root_dir``.

        Returns:
            A list of dicts with keys ``"idx"`` (int), ``"google"`` and
            ``"yandex"`` (file paths), sorted by ``idx``. Files whose prefix
            is not an integer, or that lack a Yandex counterpart, are skipped.
        """
        pairs = []
        for name in os.listdir(self.root_dir):
            if not name.endswith("_google.png"):
                continue
            prefix = name.split("_")[0]
            try:
                idx = int(prefix)
            except ValueError:
                # Unrelated file that merely ends in "_google.png".
                continue
            yandex_path = os.path.join(self.root_dir, f"{prefix}_yandex.png")
            if os.path.exists(yandex_path):
                pairs.append({
                    "idx": idx,
                    "google": os.path.join(self.root_dir, name),
                    "yandex": yandex_path,
                })
        # Secondary key keeps the order deterministic if two prefixes parse to
        # the same integer (os.listdir order is arbitrary).
        return sorted(pairs, key=lambda p: (p["idx"], p["google"]))

    def __len__(self):
        """Return the number of discovered image pairs."""
        return len(self.image_pairs)

    @staticmethod
    def _load_rgb(path, size_wh):
        """Load ``path`` as an RGB PIL image resized to ``size_wh`` (w, h)."""
        # The context manager closes the underlying file handle promptly;
        # convert() returns an independent, fully loaded image.
        with Image.open(path) as img:
            return img.convert("RGB").resize(size_wh, Image.BILINEAR)

    def __getitem__(self, idx):
        """Load, optionally warp, and tensorise the pair at position ``idx``.

        Returns:
            A dict with keys ``"google_img"``, ``"yandex_img"``, ``"idx"``
            (the numeric file prefix), ``"homography_matrix"`` (3x3 float
            tensor) and ``"homography_params"`` (float tensor; six zeros when
            augmentation is disabled).
        """
        pair = self.image_pairs[idx]
        height, width = self.image_size
        # PIL and OpenCV both take sizes as (width, height).
        size_wh = (width, height)

        google_img = self._load_rgb(pair["google"], size_wh)
        yandex_img = self._load_rgb(pair["yandex"], size_wh)

        if self.augment:
            # Keep this draw order (params1 then params2) so seeded runs
            # reproduce the same warps.
            params1 = generate_random_homography_params()
            params2 = generate_random_homography_params()
            H1 = homography_params_to_matrix(params1, self.K)
            H2 = homography_params_to_matrix(params2, self.K)
            # NOTE(review): the target is inv(H1) @ H2 while the images are
            # warped by H1 (yandex) and H2 (google); confirm this matches the
            # convention the model/loss expects (vs. H2 @ inv(H1)).
            H_combined = np.linalg.inv(H1) @ H2
            yandex_img = Image.fromarray(
                cv2.warpPerspective(np.array(yandex_img), H1, size_wh)
            )
            google_img = Image.fromarray(
                cv2.warpPerspective(np.array(google_img), H2, size_wh)
            )
            target_params = matrix_to_homography_params(H_combined, self.K)
            target_matrix = H_combined
        else:
            target_params = np.zeros(6, dtype=np.float32)
            target_matrix = np.eye(3, dtype=np.float32)

        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)
        else:
            google_img = torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
            yandex_img = torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "idx": torch.tensor(pair["idx"], dtype=torch.long),
            "homography_matrix": torch.from_numpy(
                np.asarray(target_matrix, dtype=np.float32)
            ).float(),
            "homography_params": torch.from_numpy(
                np.asarray(target_params, dtype=np.float32)
            ).float(),
        }
def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
                        image_size=(256, 256), augment_train=True,
                        augment_val=False, device=None):
    """Build shuffled train and ordered validation loaders over YaGoDataset.

    The pair indices are shuffled with the global ``random`` module (seed it
    beforehand for a reproducible split) and cut at ``train_split``.

    Args:
        root_dir: Directory handed to ``YaGoDataset``.
        batch_size: Batch size for both loaders.
        train_split: Fraction of pairs assigned to the training subset.
        num_workers: Worker process count for both loaders.
        image_size: ``(height, width)`` passed through to the dataset.
        augment_train: Draw training samples from the augmenting dataset.
        augment_val: Draw validation samples from the augmenting dataset.
        device: Accepted for compatibility with older callers; unused.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If the training subset would be empty (no image pairs
            found, or ``train_split`` too small for the dataset size).
    """
    transform = transforms.Compose([transforms.ToTensor()])

    full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)
    # Only pay for a second directory scan when augmentation is requested.
    aug_ds = None
    if augment_train or augment_val:
        aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)

    indices = list(range(len(full_ds)))
    random.shuffle(indices)
    split = int(train_split * len(indices))
    train_indices = indices[:split]
    val_indices = indices[split:]

    if not train_indices:
        # A shuffling DataLoader over an empty subset fails with an opaque
        # sampler error; report the actual cause instead.
        raise ValueError(
            f"Training split is empty: found {len(indices)} image pair(s) in "
            f"{root_dir!r} with train_split={train_split}"
        )

    train_ds = Subset(aug_ds if augment_train else full_ds, train_indices)
    val_ds = Subset(aug_ds if augment_val else full_ds, val_indices)

    # Pinned memory only helps host-to-GPU copies and warns on CPU-only hosts.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, val_loader
if __name__ == "__main__":
    # Manual smoke test: build the augmenting dataset from the configured
    # directory and print one sample's keys and homography parameters.
    # Kept under the __main__ guard so importing this module does no I/O.
    ds = YaGoDataset(
        config["data_dir"],
        augment=True,
        image_size=tuple(config["image_size"]),
    )
    print(f"Dataset size: {len(ds)}")
    if len(ds):
        s = ds[0]
        print(f"Keys: {list(s.keys())}")
        print(f"Params: {s['homography_params'].numpy()}")
    else:
        # Indexing an empty dataset would raise IndexError; say why instead.
        print("No image pairs found; nothing to sample.")