import os
import random
from typing import Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms

from utils import (
    config,
    generate_random_homography_params,
    get_camera_matrix,
    homography_params_to_matrix,
    matrix_to_homography_params,
)


class YaGoDataset(Dataset):
    """Paired ``<idx>_google.png`` / ``<idx>_yandex.png`` images from one folder.

    Each sample is a dict with both images plus a homography target. With
    ``augment=True`` both images are warped by independently drawn random
    homographies and the target is derived from them; with ``augment=False``
    the images are only resized and the target is the identity (zero params).

    Args:
        root_dir: Directory that is scanned for image pairs.
        transform: Optional callable applied to each PIL image before it is
            returned (e.g. ``transforms.ToTensor()``).
        augment: Whether to apply the random homography warps.
        image_size: Treated as ``(height, width)``: index 1 is passed as the
            width and index 0 as the height to PIL's ``resize``.
    """

    def __init__(
        self,
        root_dir: str,
        transform=None,
        augment: bool = True,
        image_size: Tuple[int, int] = (256, 256),
    ):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        # Called with (image_size[1], image_size[0]); presumably (width, height)
        # -- confirm against utils.get_camera_matrix.
        self.K = get_camera_matrix(image_size[1], image_size[0])
        self.image_pairs = self._discover_image_pairs()

    def _discover_image_pairs(self):
        """Return the google/yandex path pairs found in ``root_dir``.

        A pair is kept only when ``<idx>_yandex.png`` exists next to
        ``<idx>_google.png``. The result is sorted by the integer ``idx``.

        Raises:
            ValueError: If the text before the first underscore of a
                ``*_google.png`` file with a matching yandex file is not an
                integer.
        """
        pairs = []
        for f in os.listdir(self.root_dir):
            if f.endswith("_google.png"):
                idx = f.split("_")[0]
                yandex_path = os.path.join(self.root_dir, f"{idx}_yandex.png")
                if os.path.exists(yandex_path):
                    pairs.append(
                        {
                            "idx": int(idx),
                            "google": os.path.join(self.root_dir, f),
                            "yandex": yandex_path,
                        }
                    )
        return sorted(pairs, key=lambda x: x["idx"])

    def __len__(self):
        """Return the number of discovered image pairs."""
        return len(self.image_pairs)

    @staticmethod
    def _load_image(path, size_wh):
        """Load ``path`` as RGB and resize it to ``size_wh`` = (width, height)."""
        # The context manager releases the file handle as soon as the pixel
        # data has been converted into an independent image.
        with Image.open(path) as img:
            return img.convert("RGB").resize(size_wh, Image.BILINEAR)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            dict with keys ``"google_img"`` and ``"yandex_img"`` (PIL images,
            or whatever ``transform`` returns), ``"homography_matrix"``
            (float tensor built from a 3x3 matrix) and ``"homography_params"``
            (float tensor; six zeros when ``augment`` is off).
        """
        pair = self.image_pairs[idx]

        # PIL and OpenCV both take sizes as (width, height), whereas
        # image_size is stored as (height, width).
        size_wh = (self.image_size[1], self.image_size[0])

        google_img = self._load_image(pair["google"], size_wh)
        yandex_img = self._load_image(pair["yandex"], size_wh)

        if self.augment:
            params1 = generate_random_homography_params()
            params2 = generate_random_homography_params()
            H1 = homography_params_to_matrix(params1, self.K)
            H2 = homography_params_to_matrix(params2, self.K)
            # NOTE(review): cv2.warpPerspective maps source -> destination
            # with the given matrix, so warped-yandex -> warped-google would
            # be H2 @ inv(H1). The product below is kept as in the original;
            # confirm it matches the convention of the utils helpers.
            H_combined = np.linalg.inv(H1) @ H2

            # dsize must be (width, height); passing image_size directly
            # would transpose the output for non-square sizes.
            yandex_img = Image.fromarray(
                cv2.warpPerspective(np.array(yandex_img), H1, size_wh)
            )
            google_img = Image.fromarray(
                cv2.warpPerspective(np.array(google_img), H2, size_wh)
            )

            target_params = matrix_to_homography_params(H_combined, self.K)
            target_matrix = H_combined
        else:
            # No warp applied: identity homography, all-zero parameters.
            target_params = np.zeros(6, dtype=np.float32)
            target_matrix = np.eye(3, dtype=np.float32)

        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography_matrix": torch.from_numpy(target_matrix).float(),
            "homography_params": torch.from_numpy(target_params).float(),
        }


def create_data_loaders(
    root_dir,
    batch_size=32,
    train_split=0.8,
    num_workers=0,
    image_size=(256, 256),
    augment_train=True,
):
    """Build train and validation ``DataLoader`` objects over ``root_dir``.

    Two ``YaGoDataset`` instances are created over the same directory, one
    with augmentation and one without, so that the validation subset is never
    augmented. Indices are shuffled with the global ``random`` module (no
    seed is set here) and split at ``int(train_split * n)``.

    Args:
        root_dir: Directory containing the image pairs.
        batch_size: Batch size for both loaders.
        train_split: Fraction of samples assigned to the training subset.
        num_workers: Worker count passed to both loaders.
        image_size: Forwarded to ``YaGoDataset``.
        augment_train: Use the augmenting dataset for the training subset.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    full_ds = YaGoDataset(
        root_dir, transform=transform, augment=False, image_size=image_size
    )
    aug_ds = YaGoDataset(
        root_dir, transform=transform, augment=True, image_size=image_size
    )

    indices = list(range(len(full_ds)))
    random.shuffle(indices)
    split = int(train_split * len(indices))

    train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])
    val_ds = Subset(full_ds, indices[split:])

    return (
        DataLoader(
            train_ds,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        ),
        DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        ),
    )


if __name__ == "__main__":
    # Smoke check: load the configured dataset and print one sample's targets.
    ds = YaGoDataset(
        config["data_dir"], augment=True, image_size=config["image_size"]
    )
    print(f"Dataset size: {len(ds)}")
    s = ds[0]
    print(f"Keys: {list(s.keys())}")
    print(f"Params: {s['homography_params'].numpy()}")