Files
autopilot/models/SiaN/dataloader.py
2026-04-04 20:26:56 +03:00

93 lines
3.8 KiB
Python

import os
import random
from typing import Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
class YaGoDataset(Dataset):
    """Paired Google/Yandex image dataset for homography estimation.

    Samples are discovered in ``root_dir`` by filename: every
    ``<idx>_google.png`` with a matching ``<idx>_yandex.png`` forms one pair,
    where ``<idx>`` is the text before the first underscore and must parse as
    an integer.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to both PIL images of a sample
            (e.g. ``transforms.ToTensor()``). When ``None`` the returned
            sample holds PIL images.
        augment: If True, the Yandex and Google images are warped by two
            independent random homographies and the target is the homography
            mapping warped-Yandex pixel coordinates onto warped-Google pixel
            coordinates. If False, the raw pair is treated as aligned and the
            target is the identity.
        image_size: ``(height, width)`` that both images are resized to.
    """

    def __init__(self, root_dir: str, transform=None, augment: bool = True,
                 image_size: Tuple[int, int] = (256, 256)):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        # Called as (width, height) — presumably get_camera_matrix's
        # argument order; TODO confirm against utils.
        self.K = get_camera_matrix(image_size[1], image_size[0])
        self.image_pairs = self._discover_image_pairs()

    def _discover_image_pairs(self) -> list:
        """Return ``{"idx", "google", "yandex"}`` dicts for every complete pair.

        Google images without a Yandex counterpart are skipped. The result
        is sorted by the integer ``idx``.

        Raises:
            ValueError: if a matched pair's filename prefix is not an integer.
        """
        pairs = []
        for name in os.listdir(self.root_dir):
            if not name.endswith("_google.png"):
                continue
            idx = name.split("_")[0]
            yandex_path = os.path.join(self.root_dir, f"{idx}_yandex.png")
            if os.path.exists(yandex_path):
                pairs.append({
                    "idx": int(idx),
                    "google": os.path.join(self.root_dir, name),
                    "yandex": yandex_path,
                })
        return sorted(pairs, key=lambda pair: pair["idx"])

    def __len__(self):
        return len(self.image_pairs)

    def _load_image(self, path):
        """Open ``path``, convert it to RGB and resize it bilinearly to ``image_size``."""
        # PIL sizes are (width, height); image_size is (height, width).
        # The context manager releases the file handle; convert/resize
        # return independent, fully loaded images.
        with Image.open(path) as img:
            return img.convert("RGB").resize(
                (self.image_size[1], self.image_size[0]), Image.BILINEAR
            )

    @staticmethod
    def _relative_homography(h_yandex: np.ndarray, h_google: np.ndarray) -> np.ndarray:
        """Return the homography from warped-Yandex to warped-Google pixel coordinates.

        ``cv2.warpPerspective`` moves a source pixel ``p`` to ``H @ p``. A
        pixel of the (assumed aligned) original pair therefore lands at
        ``h_yandex @ p`` in the warped Yandex image and at ``h_google @ p``
        in the warped Google image, giving ``h_google @ inv(h_yandex)``.
        """
        return h_google @ np.linalg.inv(h_yandex)

    def __getitem__(self, idx):
        """Load, optionally augment, and transform one image pair.

        Returns:
            dict with keys:

            - ``"google_img"``, ``"yandex_img"``: the (transformed) images.
            - ``"homography_matrix"``: 3x3 float32 tensor.
            - ``"homography_params"``: float32 tensor from
              ``matrix_to_homography_params`` (6 zeros when not augmenting).
        """
        pair = self.image_pairs[idx]
        google_img = self._load_image(pair["google"])
        yandex_img = self._load_image(pair["yandex"])

        if self.augment:
            params1 = generate_random_homography_params()
            params2 = generate_random_homography_params()
            H1 = homography_params_to_matrix(params1, self.K)
            H2 = homography_params_to_matrix(params2, self.K)
            # OpenCV expects dsize as (width, height); image_size is (height, width).
            dsize = (self.image_size[1], self.image_size[0])
            yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, dsize))
            google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, dsize))
            target_matrix = self._relative_homography(H1, H2)
            target_params = matrix_to_homography_params(target_matrix, self.K)
        else:
            target_params = np.zeros(6, dtype=np.float32)
            target_matrix = np.eye(3, dtype=np.float32)

        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography_matrix": torch.as_tensor(target_matrix, dtype=torch.float32),
            "homography_params": torch.as_tensor(target_params, dtype=torch.float32),
        }
def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
                        image_size=(256, 256), augment_train=True, seed=None):
    """Split the image pairs in ``root_dir`` into train and validation DataLoaders.

    Both splits convert images with ``transforms.ToTensor()``. Validation
    samples are never augmented. Training samples are augmented only when
    ``augment_train`` is True.

    Args:
        root_dir: Directory containing ``<idx>_google.png`` / ``<idx>_yandex.png`` pairs.
        batch_size: Batch size for both loaders.
        train_split: Fraction of pairs, in ``[0, 1]``, assigned to training.
        num_workers: Worker processes per loader.
        image_size: ``(height, width)`` passed to the dataset.
        augment_train: Apply random homography augmentation to the training split.
        seed: Optional seed making the train/validation split reproducible.
            ``None`` shuffles with the global ``random`` state, as before.

    Returns:
        ``(train_loader, val_loader)``. The training loader reshuffles every epoch.

    Raises:
        ValueError: if ``train_split`` is outside ``[0, 1]`` or no pairs are found.
    """
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be in [0, 1], got {train_split!r}")

    transform = transforms.Compose([transforms.ToTensor()])
    full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)
    if len(full_ds) == 0:
        raise ValueError(f"no image pairs found in {root_dir!r}")

    # Both datasets scan the same directory and sort pairs by idx, so an index
    # selects the same pair in either. Only build the augmented copy if it is used.
    train_source = (
        YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)
        if augment_train
        else full_ds
    )

    indices = list(range(len(full_ds)))
    if seed is None:
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)
    split = int(train_split * len(indices))

    train_loader = DataLoader(Subset(train_source, indices[:split]), batch_size=batch_size,
                              shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(Subset(full_ds, indices[split:]), batch_size=batch_size,
                            shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_loader, val_loader
if __name__ == "__main__":
    # Smoke test: build an augmented dataset from the configured directory and
    # print one sample's keys and target parameters.
    # NOTE(review): assumes config provides "data_dir" and a 2-element
    # (height, width) "image_size" — confirm against utils.config.
    ds = YaGoDataset(config["data_dir"], augment=True, image_size=tuple(config["image_size"]))
    print(f"Dataset size: {len(ds)}")
    if len(ds) == 0:
        raise SystemExit(f"No <idx>_google.png/<idx>_yandex.png pairs found in {config['data_dir']!r}")
    s = ds[0]
    print(f"Keys: {list(s.keys())}")
    print(f"Params: {s['homography_params'].numpy()}")