Files
autopilot/models/GAN/dataloader.py

162 lines
4.6 KiB
Python

"""Data loader for Yandex-to-Google image translation."""
import os
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class YaGoDataset(Dataset):
    """Dataset loading pairs of Yandex and Google map images.

    Each sample is a dict holding the two images as normalized tensors and
    the integer index parsed from the Google file name.
    """

    # File-name suffixes that identify the two halves of a pair.
    _GOOGLE_SUFFIX = "_google.png"
    _YANDEX_SUFFIX = "_yandex.png"

    def __init__(
        self,
        root_dir: str,
        image_size: Tuple[int, int] = (256, 256),
        augment: bool = False,
    ):
        """
        Args:
            root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png
            image_size: Target image size (height, width)
            augment: Whether to apply augmentation. The flag is stored on the
                instance but no augmentation is implemented.
        """
        self.root_dir = root_dir
        self.image_size = image_size
        self.augment = augment
        # Discover image pairs once, at construction time.
        self.pairs = self._find_pairs()
        # Convert to tensor, then shift/scale every channel with mean 0.5 and std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def _find_pairs(self) -> List[Dict]:
        """Find all matching Google-Yandex image pairs.

        Only files named exactly ``<integer>_google.png`` are considered. The
        partner name is rebuilt from the parsed index as
        ``{idx:04d}_yandex.png``; Google files without an existing partner are
        skipped.

        Returns:
            A list of dicts with keys ``idx``, ``google_path`` and
            ``yandex_path``, ordered by Google file name.
        """
        pairs = []
        google_files = sorted(
            name for name in os.listdir(self.root_dir) if name.endswith(self._GOOGLE_SUFFIX)
        )
        for google_file in google_files:
            stem = google_file[: -len(self._GOOGLE_SUFFIX)]
            # A stem with extra components (e.g. "0001_old") is not a canonical
            # pair file. Parsing only its first component would pair it with
            # 0001_yandex.png and yield a duplicate, mismatched sample.
            if "_" in stem:
                continue
            try:
                idx = int(stem)
            except ValueError:
                continue
            yandex_path = os.path.join(self.root_dir, f"{idx:04d}{self._YANDEX_SUFFIX}")
            if os.path.exists(yandex_path):
                pairs.append(
                    {
                        "idx": idx,
                        "google_path": os.path.join(self.root_dir, google_file),
                        "yandex_path": yandex_path,
                    }
                )
        return pairs

    def _load_image(self, path: str) -> torch.Tensor:
        """Load one image as RGB, resize it to ``image_size`` and apply the transform."""
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of leaving that to the garbage collector.
        with Image.open(path) as raw:
            rgb = raw.convert("RGB")
        # PIL's resize takes (width, height); image_size is stored as (height, width).
        resized = rgb.resize((self.image_size[1], self.image_size[0]))
        return self.transform(resized)

    def __len__(self) -> int:
        """Return the number of discovered image pairs."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Return the pair at position ``idx`` in ``self.pairs``.

        Returns:
            A dict with ``google_img`` and ``yandex_img`` (transformed image
            tensors) and ``idx`` (the file index as a ``torch.long`` tensor).
        """
        pair = self.pairs[idx]
        return {
            "google_img": self._load_image(pair["google_path"]),
            "yandex_img": self._load_image(pair["yandex_path"]),
            "idx": torch.tensor(pair["idx"], dtype=torch.long),
        }
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 0,
    image_size: Tuple[int, int] = (256, 256),
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders.

    The dataset is shuffled once and cut into a training part and a
    validation part; the training loader additionally reshuffles every epoch.

    Args:
        root_dir: Directory with image pairs
        batch_size: Batch size
        train_split: Fraction for training (0.0-1.0)
        num_workers: DataLoader workers
        image_size: Target image size
        seed: Optional seed for the train/validation split. When given, the
            same samples land in each subset on every call; when ``None`` the
            global torch RNG is used. It does not affect the per-epoch
            shuffling of the training loader.
    Returns:
        (train_loader, val_loader)
    Raises:
        ValueError: If ``train_split`` is not within [0.0, 1.0].
    """
    # Validate before touching the disk: a fraction above 1 would silently
    # leave validation empty and a negative one would slice from the end.
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be in [0.0, 1.0], got {train_split!r}")

    # Full dataset
    dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)

    # Split
    dataset_size = len(dataset)
    train_size = int(train_split * dataset_size)
    generator = None
    if seed is not None:
        # A dedicated generator keeps the split reproducible without
        # reseeding the global RNG.
        generator = torch.Generator()
        generator.manual_seed(seed)
    indices = torch.randperm(dataset_size, generator=generator).tolist()
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    # Subsets
    from torch.utils.data import Subset

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    # DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return train_loader, val_loader
if __name__ == "__main__":
    # Smoke test: build both loaders from the project config, pull a single
    # training batch and report tensor shapes and batch counts.
    from config import create_config

    cfg = create_config()
    loaders = create_data_loaders(
        root_dir=cfg["data_dir"],
        batch_size=4,
        image_size=tuple(cfg["image_size"]),
    )
    train_loader, val_loader = loaders

    # Fetching only the first batch is enough to exercise loading end to end.
    batch = next(iter(train_loader))
    print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")