"""Data loader for Yandex-to-Google image translation.""" import os from typing import Dict, List, Tuple import torch from PIL import Image from torch.utils.data import DataLoader, Dataset from torchvision import transforms class YaGoDataset(Dataset): """Dataset loading pairs of Yandex and Google map images.""" def __init__( self, root_dir: str, image_size: Tuple[int, int] = (256, 256), augment: bool = False, ): """ Args: root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png image_size: Target image size (height, width) augment: Whether to apply augmentation (not implemented for simplicity) """ self.root_dir = root_dir self.image_size = image_size self.augment = augment # Discover image pairs self.pairs = self._find_pairs() # Transform to tensor + normalization self.transform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ] ) def _find_pairs(self) -> List[Dict]: """Find all matching Google-Yandex image pairs.""" pairs = [] google_files = [f for f in os.listdir(self.root_dir) if f.endswith("_google.png")] for google_file in sorted(google_files): idx_str = google_file.split("_")[0] try: idx = int(idx_str) except ValueError: continue yandex_file = f"{idx:04d}_yandex.png" yandex_path = os.path.join(self.root_dir, yandex_file) if os.path.exists(yandex_path): pairs.append( { "idx": idx, "google_path": os.path.join(self.root_dir, google_file), "yandex_path": yandex_path, } ) return pairs def __len__(self) -> int: return len(self.pairs) def __getitem__(self, idx: int) -> dict: pair = self.pairs[idx] # Load images google_img = Image.open(pair["google_path"]).convert("RGB") yandex_img = Image.open(pair["yandex_path"]).convert("RGB") # Resize google_img = google_img.resize((self.image_size[1], self.image_size[0])) yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0])) # Apply transforms google_tensor = self.transform(google_img) yandex_tensor = self.transform(yandex_img) return { "google_img": google_tensor, "yandex_img": yandex_tensor, "idx": torch.tensor(pair["idx"], dtype=torch.long), } def create_data_loaders( root_dir: str, batch_size: int = 32, train_split: float = 0.8, num_workers: int = 0, image_size: Tuple[int, int] = (256, 256), ) -> Tuple[DataLoader, DataLoader]: """ Create train and validation data loaders. Args: root_dir: Directory with image pairs batch_size: Batch size train_split: Fraction for training (0.0-1.0) num_workers: DataLoader workers image_size: Target image size Returns: (train_loader, val_loader) """ # Full dataset dataset = YaGoDataset(root_dir=root_dir, image_size=image_size) # Split dataset_size = len(dataset) train_size = int(train_split * dataset_size) indices = torch.randperm(dataset_size).tolist() train_indices = indices[:train_size] val_indices = indices[train_size:] # Subsets from torch.utils.data import Subset train_dataset = Subset(dataset, train_indices) val_dataset = Subset(dataset, val_indices) # DataLoaders train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, ) return train_loader, val_loader if __name__ == "__main__": # Quick test from config import create_config config = create_config() train_loader, val_loader = create_data_loaders( root_dir=config["data_dir"], batch_size=4, image_size=tuple(config["image_size"]), ) batch = next(iter(train_loader)) print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}") print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")