feat: working model
This commit is contained in:
@@ -1,34 +1,6 @@
|
||||
config = {
|
||||
"learning_rate": 2e-4,
|
||||
"beta1": 0.5,
|
||||
"beta2": 0.999,
|
||||
"batch_size": 32,
|
||||
"epochs": 100,
|
||||
"gan_mode": "vanilla",
|
||||
"lambda_L1": 100.0,
|
||||
"grad_clip": 1.0,
|
||||
"early_stopping_patience": 20,
|
||||
"output_dir": "runs/gan_training",
|
||||
"log_interval": 10,
|
||||
"save_interval": 5,
|
||||
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
"image_size": [256, 256],
|
||||
"train_split": 0.8,
|
||||
"num_workers": 0,
|
||||
}
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -37,238 +9,84 @@ from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset, Subset
|
||||
from torchvision import transforms
|
||||
|
||||
from utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
|
||||
|
||||
|
||||
class YaGoDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
transform=None,
|
||||
augment: bool = True,
|
||||
max_samples: Optional[int] = None,
|
||||
image_size: Tuple[int, int] = (700, 700),
|
||||
cache_homographies: bool = True,
|
||||
device=None,
|
||||
):
|
||||
def __init__(self, root_dir: str, transform=None, augment: bool = True,
|
||||
image_size: Tuple[int, int] = (256, 256)):
|
||||
self.root_dir = root_dir
|
||||
self.transform = transform
|
||||
self.augment = augment
|
||||
self.image_size = image_size
|
||||
self.cache_homographies = cache_homographies
|
||||
self.device = device
|
||||
self.K = get_camera_matrix(image_size[1], image_size[0])
|
||||
self.image_pairs = self._discover_image_pairs()
|
||||
if max_samples is not None:
|
||||
self.image_pairs = self.image_pairs[:max_samples]
|
||||
|
||||
def _discover_image_pairs(self) -> List[Dict[str, Any]]:
|
||||
image_pairs = []
|
||||
google_files = [f for f in os.listdir(self.root_dir) if f.endswith("_google.png")]
|
||||
for google_file in sorted(google_files):
|
||||
idx_str = google_file.split("_")[0]
|
||||
try:
|
||||
idx = int(idx_str)
|
||||
except ValueError:
|
||||
continue
|
||||
yandex_file = f"{idx:04d}_yandex.png"
|
||||
yandex_path = os.path.join(self.root_dir, yandex_file)
|
||||
if os.path.exists(yandex_path):
|
||||
image_pairs.append({
|
||||
"idx": idx,
|
||||
"google_path": os.path.join(self.root_dir, google_file),
|
||||
"yandex_path": yandex_path,
|
||||
})
|
||||
return image_pairs
|
||||
def _discover_image_pairs(self):
|
||||
pairs = []
|
||||
for f in os.listdir(self.root_dir):
|
||||
if f.endswith("_google.png"):
|
||||
idx = f.split("_")[0]
|
||||
yandex_path = os.path.join(self.root_dir, f"{idx}_yandex.png")
|
||||
if os.path.exists(yandex_path):
|
||||
pairs.append({"idx": int(idx), "google": os.path.join(self.root_dir, f), "yandex": yandex_path})
|
||||
return sorted(pairs, key=lambda x: x["idx"])
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return len(self.image_pairs)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
pair_info = self.image_pairs[idx]
|
||||
google_path = pair_info["google_path"]
|
||||
yandex_path = pair_info["yandex_path"]
|
||||
same_domain = True
|
||||
|
||||
if np.random.rand() > 0.5:
|
||||
random_idx = np.random.randint(0, len(self))
|
||||
google_path = self.image_pairs[random_idx]["google_path"]
|
||||
same_domain = random_idx == idx
|
||||
def __getitem__(self, idx):
|
||||
pair = self.image_pairs[idx]
|
||||
google_img = Image.open(pair["google"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
|
||||
yandex_img = Image.open(pair["yandex"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
|
||||
|
||||
yandex_img = Image.open(yandex_path).convert("RGB")
|
||||
google_img = Image.open(google_path).convert("RGB")
|
||||
|
||||
google_img = google_img.resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
|
||||
yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
|
||||
|
||||
matrices = self._get_homography_matrix(pair_info["idx"])
|
||||
|
||||
if self.augment:
|
||||
google_img, yandex_img, homography_matrix = self._apply_augmentation(
|
||||
google_img, yandex_img, matrices
|
||||
)
|
||||
homography_tensor = torch.from_numpy(homography_matrix).float()
|
||||
params1 = generate_random_homography_params()
|
||||
params2 = generate_random_homography_params()
|
||||
H1 = homography_params_to_matrix(params1, self.K)
|
||||
H2 = homography_params_to_matrix(params2, self.K)
|
||||
H_combined = np.linalg.inv(H1) @ H2
|
||||
yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, self.image_size))
|
||||
google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, self.image_size))
|
||||
target_params = matrix_to_homography_params(H_combined, self.K)
|
||||
target_matrix = H_combined
|
||||
else:
|
||||
homography_tensor = torch.from_numpy(np.eye(3))
|
||||
|
||||
target_params = np.zeros(6, dtype=np.float32)
|
||||
target_matrix = np.eye(3, dtype=np.float32)
|
||||
|
||||
if self.transform:
|
||||
google_img = self.transform(google_img)
|
||||
yandex_img = self.transform(yandex_img)
|
||||
else:
|
||||
google_img = torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
|
||||
yandex_img = torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
|
||||
|
||||
|
||||
|
||||
return {
|
||||
"google_img": google_img,
|
||||
"yandex_img": yandex_img,
|
||||
"homography": homography_tensor,
|
||||
"same_domain": same_domain,
|
||||
"idx": torch.tensor(pair_info["idx"], dtype=torch.long),
|
||||
"homography_matrix": torch.from_numpy(target_matrix).float(),
|
||||
"homography_params": torch.from_numpy(target_params).float(),
|
||||
}
|
||||
|
||||
def _get_homography_matrix(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
homography_matrix_1 = self.generate_random_homography()
|
||||
homography_matrix_2 = self.generate_random_homography()
|
||||
homography_matrix_r = np.linalg.inv(homography_matrix_1) @ homography_matrix_2
|
||||
return (homography_matrix_1, homography_matrix_2, homography_matrix_r)
|
||||
|
||||
def generate_random_homography(self) -> np.ndarray:
|
||||
scale = np.random.uniform(0.8, 1.2)
|
||||
tx = np.random.uniform(-0.50, 0.50)
|
||||
ty = np.random.uniform(-0.50, 0.50)
|
||||
|
||||
angle_x = np.random.uniform(np.radians(-10), np.radians(10))
|
||||
angle_y = np.random.uniform(np.radians(-10), np.radians(10))
|
||||
angle_z = np.random.uniform(np.radians(-10), np.radians(10))
|
||||
|
||||
cy, sy = np.cos(angle_z), np.sin(angle_z)
|
||||
cp, sp = np.cos(angle_y), np.sin(angle_y)
|
||||
cr, sr = np.cos(angle_x), np.sin(angle_x)
|
||||
|
||||
Rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]])
|
||||
Ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]])
|
||||
Rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]])
|
||||
|
||||
T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]])
|
||||
K = self.get_camera_matrix()
|
||||
return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)
|
||||
|
||||
def get_camera_matrix(self) -> np.ndarray:
|
||||
w, h = config["image_size"]
|
||||
return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]])
|
||||
|
||||
def _apply_augmentation(
|
||||
self,
|
||||
google_img: Image.Image,
|
||||
yandex_img: Image.Image,
|
||||
matrices: Tuple[np.ndarray, np.ndarray, np.ndarray],
|
||||
) -> Tuple[Image.Image, Image.Image, np.ndarray]:
|
||||
combined_homography = matrices[2]
|
||||
yandex_aug = self._apply_homography_to_image(yandex_img, matrices[0])
|
||||
google_aug = self._apply_homography_to_image(google_img, matrices[1])
|
||||
print("F", combined_homography, np.linalg.inv(matrices[0]) @ matrices[1])
|
||||
return google_aug, yandex_aug, combined_homography
|
||||
|
||||
def _apply_homography_to_image(
|
||||
self, img: Image.Image, homography: np.ndarray
|
||||
) -> Image.Image:
|
||||
img_np = np.array(img)
|
||||
h, w = img_np.shape[:2]
|
||||
transformed = cv2.warpPerspective(
|
||||
img_np, homography, (w, h), flags=cv2.INTER_LINEAR
|
||||
)
|
||||
return Image.fromarray(transformed)
|
||||
|
||||
|
||||
def create_data_loaders(
|
||||
root_dir: str,
|
||||
batch_size: int = 32,
|
||||
train_split: float = 0.8,
|
||||
num_workers: int = 4,
|
||||
image_size: Tuple[int, int] = (256, 256),
|
||||
augment_train: bool = True,
|
||||
augment_val: bool = False,
|
||||
device=None,
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
|
||||
image_size=(256, 256), augment_train=True):
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
|
||||
full_dataset = YaGoDataset(
|
||||
root_dir=root_dir,
|
||||
transform=transform,
|
||||
augment=False,
|
||||
image_size=image_size,
|
||||
cache_homographies=True,
|
||||
device=device,
|
||||
)
|
||||
|
||||
aug_dataset = YaGoDataset(
|
||||
root_dir=root_dir,
|
||||
transform=transform,
|
||||
augment=True,
|
||||
image_size=image_size,
|
||||
cache_homographies=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
dataset_size = len(full_dataset)
|
||||
train_size = int(train_split * dataset_size)
|
||||
val_size = dataset_size - train_size
|
||||
|
||||
indices = list(range(dataset_size))
|
||||
full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)
|
||||
aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)
|
||||
|
||||
indices = list(range(len(full_ds)))
|
||||
random.shuffle(indices)
|
||||
train_indices = indices[:train_size]
|
||||
val_indices = indices[train_size:]
|
||||
split = int(train_split * len(indices))
|
||||
|
||||
train_dataset = Subset(full_dataset, train_indices)
|
||||
val_dataset = Subset(full_dataset, val_indices)
|
||||
|
||||
if augment_train:
|
||||
train_dataset = Subset(aug_dataset, train_indices)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
return train_loader, val_loader
|
||||
train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])
|
||||
val_ds = Subset(full_ds, indices[split:])
|
||||
|
||||
return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),
|
||||
DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))
|
||||
|
||||
|
||||
|
||||
# Example usage
|
||||
dataset = YaGoDataset(
|
||||
root_dir=config["data_dir"],
|
||||
augment=True,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Get a sample
|
||||
sample = dataset[0]
|
||||
print(f"Sample keys: {list(sample.keys())}")
|
||||
print(f"Google image shape: {sample['google_img'].shape}")
|
||||
print(f"Yandex image shape: {sample['yandex_img'].shape}")
|
||||
print(f"Homography shape: {sample['homography'].shape}")
|
||||
|
||||
# Create data loaders
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=16,
|
||||
train_split=0.8,
|
||||
)
|
||||
|
||||
print(f"Train batches: {len(train_loader)}")
|
||||
print(f"Val batches: {len(val_loader)}")
|
||||
if __name__ == "__main__":
|
||||
ds = YaGoDataset(config["data_dir"], augment=True, image_size=config["image_size"])
|
||||
print(f"Dataset size: {len(ds)}")
|
||||
s = ds[0]
|
||||
print(f"Keys: {list(s.keys())}")
|
||||
print(f"Params: {s['homography_params'].numpy()}")
|
||||
|
||||
Reference in New Issue
Block a user