93 lines
3.8 KiB
Python
93 lines
3.8 KiB
Python
import os
|
|
import random
|
|
from typing import Tuple
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader, Dataset, Subset
|
|
from torchvision import transforms
|
|
|
|
from utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
|
|
|
|
|
|
class YaGoDataset(Dataset):
    """Paired Google/Yandex image dataset with optional homography augmentation.

    Pairs are discovered in ``root_dir`` as ``<idx>_google.png`` and
    ``<idx>_yandex.png`` files and are ordered by their integer ``idx``.

    Args:
        root_dir: Directory that holds the image pairs.
        transform: Optional callable applied to each image before it is
            returned.
        augment: When True, each image of a pair is warped by its own random
            homography and the target is derived from the two matrices;
            otherwise the target is the identity.
        image_size: Target size as ``(height, width)``.
    """

    def __init__(self, root_dir: str, transform=None, augment: bool = True,
                 image_size: Tuple[int, int] = (256, 256)):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.K = get_camera_matrix(image_size[1], image_size[0])
        self.image_pairs = self._discover_image_pairs()

    def _size_wh(self) -> Tuple[int, int]:
        """Return the target size as ``(width, height)``.

        ``image_size`` is stored as ``(height, width)``, while both PIL's
        ``resize`` and OpenCV's ``dsize`` expect ``(width, height)``.
        """
        return self.image_size[1], self.image_size[0]

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image resized to the target size."""
        # convert() returns an independent image, so the file can be closed
        # as soon as the pixel data has been read.
        with Image.open(path) as img:
            return img.convert("RGB").resize(self._size_wh(), Image.BILINEAR)

    def _discover_image_pairs(self):
        """Return one dict per complete pair in ``root_dir``, sorted by index.

        A pair is kept only when both ``<idx>_google.png`` and
        ``<idx>_yandex.png`` exist. ``idx`` is the text before the first
        underscore of the Google file name.

        Raises:
            ValueError: If that prefix cannot be parsed by ``int()``.
        """
        pairs = []
        for f in os.listdir(self.root_dir):
            if not f.endswith("_google.png"):
                continue
            idx = f.split("_")[0]
            yandex_path = os.path.join(self.root_dir, f"{idx}_yandex.png")
            if os.path.exists(yandex_path):
                pairs.append({
                    "idx": int(idx),
                    "google": os.path.join(self.root_dir, f),
                    "yandex": yandex_path,
                })
        return sorted(pairs, key=lambda x: x["idx"])

    def __len__(self):
        """Return the number of discovered image pairs."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return the pair at position ``idx`` together with its target.

        Returns:
            A dict with keys ``google_img`` and ``yandex_img`` (the images,
            passed through ``transform`` when one is set),
            ``homography_matrix`` and ``homography_params`` (float tensors).
        """
        pair = self.image_pairs[idx]
        google_img = self._load_image(pair["google"])
        yandex_img = self._load_image(pair["yandex"])

        if self.augment:
            params1 = generate_random_homography_params()
            params2 = generate_random_homography_params()
            H1 = homography_params_to_matrix(params1, self.K)
            H2 = homography_params_to_matrix(params2, self.K)
            # NOTE(review): yandex is warped by H1 and google by H2, while the
            # target is inv(H1) @ H2 -- confirm this composition order matches
            # the convention the consumer of the target expects.
            H_combined = np.linalg.inv(H1) @ H2
            # OpenCV takes dsize as (width, height); passing image_size
            # directly would transpose the output for non-square sizes.
            dsize = self._size_wh()
            yandex_img = Image.fromarray(
                cv2.warpPerspective(np.array(yandex_img), H1, dsize))
            google_img = Image.fromarray(
                cv2.warpPerspective(np.array(google_img), H2, dsize))
            target_params = matrix_to_homography_params(H_combined, self.K)
            target_matrix = H_combined
        else:
            # No warp applied: the target is the identity transform.
            target_params = np.zeros(6, dtype=np.float32)
            target_matrix = np.eye(3, dtype=np.float32)

        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography_matrix": torch.from_numpy(target_matrix).float(),
            "homography_params": torch.from_numpy(target_params).float(),
        }
|
|
|
|
|
|
def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
                        image_size=(256, 256), augment_train=True, seed=None):
    """Build train and validation DataLoaders over a random split of the pairs.

    Args:
        root_dir: Directory passed to ``YaGoDataset``.
        batch_size: Batch size used by both loaders.
        train_split: Fraction of the pairs assigned to the training subset;
            the remainder forms the validation subset.
        num_workers: Worker count used by both loaders.
        image_size: Image size passed to ``YaGoDataset``.
        augment_train: When True, the training subset reads from the
            augmenting dataset. The validation subset never augments.
        seed: Optional seed that makes the train/validation split
            reproducible. With the default ``None`` the module-level
            ``random`` generator is used, so the split follows whatever
            global seeding the caller has done.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    transform = transforms.Compose([transforms.ToTensor()])

    full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)
    aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)

    indices = list(range(len(full_ds)))
    # A dedicated generator keeps a seeded split independent of global state.
    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(indices)
    split = int(train_split * len(indices))

    # Both datasets index the same sorted pair list, so one index list is
    # valid for either of them.
    train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])
    val_ds = Subset(full_ds, indices[split:])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    return train_loader, val_loader
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the augmenting dataset from the project config and
    # print its size plus the keys and target parameters of the first sample.
    dataset = YaGoDataset(
        config["data_dir"],
        augment=True,
        image_size=config["image_size"],
    )
    print(f"Dataset size: {len(dataset)}")

    sample = dataset[0]
    print(f"Keys: {list(sample.keys())}")
    print(f"Params: {sample['homography_params'].numpy()}")
|