162 lines
4.6 KiB
Python
162 lines
4.6 KiB
Python
"""Data loader for Yandex-to-Google image translation."""
|
|
|
|
import os
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
|
|
|
|
|
|
class YaGoDataset(Dataset):
    """Dataset loading pairs of Yandex and Google map images.

    A pair consists of two files in ``root_dir`` that share a numeric index:
    ``{idx:04d}_google.png`` and ``{idx:04d}_yandex.png``. Every image is
    converted to RGB, resized to ``image_size`` and passed through
    ``ToTensor`` followed by ``Normalize`` with mean 0.5 and std 0.5 per
    channel (so ``ToTensor`` output in [0, 1] becomes [-1, 1]).
    """

    # Filename suffixes that identify the two members of a pair.
    _GOOGLE_SUFFIX = "_google.png"
    _YANDEX_SUFFIX = "_yandex.png"

    def __init__(
        self,
        root_dir: str,
        image_size: Tuple[int, int] = (256, 256),
        augment: bool = False,
    ):
        """
        Args:
            root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png
            image_size: Target image size (height, width)
            augment: Whether to apply augmentation. The flag is stored on the
                instance but is currently unused: no augmentation is implemented.
        """
        super().__init__()
        self.root_dir = root_dir
        self.image_size = image_size
        self.augment = augment

        # Discover image pairs once, up front; __len__/__getitem__ index into this list.
        self.pairs = self._find_pairs()

        # Transform to tensor + normalization
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def _find_pairs(self) -> List[Dict]:
        """Find all matching Google-Yandex image pairs.

        Returns:
            One dict per pair with the keys ``idx`` (int), ``google_path`` and
            ``yandex_path``, ordered by Google filename. Google images without
            a Yandex counterpart are skipped.
        """
        pairs = []
        for google_file in sorted(os.listdir(self.root_dir)):
            if not google_file.endswith(self._GOOGLE_SUFFIX):
                continue

            # Accept only the exact "<idx>_google.png" form. A stray file such
            # as "0001_old_google.png" would otherwise parse to the same index
            # as "0001_google.png" and be added as a duplicate, mismatched pair.
            idx_str = google_file[: -len(self._GOOGLE_SUFFIX)]
            if "_" in idx_str:
                continue
            try:
                idx = int(idx_str)
            except ValueError:
                # Non-numeric prefix: not a dataset image.
                continue

            yandex_path = os.path.join(self.root_dir, f"{idx:04d}{self._YANDEX_SUFFIX}")
            if not os.path.exists(yandex_path):
                continue

            pairs.append(
                {
                    "idx": idx,
                    "google_path": os.path.join(self.root_dir, google_file),
                    "yandex_path": yandex_path,
                }
            )

        return pairs

    def _load_image(self, path: str) -> torch.Tensor:
        """Load one image file and return it as a resized, normalized RGB tensor."""
        # Image.open is lazy and holds the file open; the context manager
        # releases it deterministically. convert() returns a separate image,
        # so it stays usable after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert("RGB")

        # image_size is (height, width) while PIL's resize expects (width, height).
        rgb = rgb.resize((self.image_size[1], self.image_size[0]))
        return self.transform(rgb)

    def __len__(self) -> int:
        """Return the number of discovered image pairs."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Load the pair at position ``idx`` of the discovered pairs.

        Returns:
            Dict with ``google_img`` and ``yandex_img`` (transformed image
            tensors) and ``idx`` (the pair's filename index as a long tensor).
        """
        pair = self.pairs[idx]
        return {
            "google_img": self._load_image(pair["google_path"]),
            "yandex_img": self._load_image(pair["yandex_path"]),
            "idx": torch.tensor(pair["idx"], dtype=torch.long),
        }
|
|
|
|
|
|
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 0,
    image_size: Tuple[int, int] = (256, 256),
    *,
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders.

    The dataset indices are randomly permuted once; the first
    ``int(train_split * len(dataset))`` indices form the training subset and
    the remainder the validation subset.

    Args:
        root_dir: Directory with image pairs
        batch_size: Batch size
        train_split: Fraction for training (0.0-1.0)
        num_workers: DataLoader workers
        image_size: Target image size
        seed: Optional seed for the train/validation split. With the default
            ``None`` the split is drawn from torch's global RNG, as before;
            with an integer the same split is produced on every call.

    Returns:
        (train_loader, val_loader)
    """
    # Full dataset
    dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)

    # Split
    dataset_size = len(dataset)
    train_size = int(train_split * dataset_size)

    # A dedicated generator makes a seeded split reproducible without
    # reseeding (and thereby disturbing) the global RNG.
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(dataset_size, generator=generator).tolist()

    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    # Arguments shared by both loaders. Pinned host memory is only requested
    # when CUDA is present; asking for it otherwise just makes DataLoader warn.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": torch.cuda.is_available(),
    }

    # Only the training loader reshuffles each epoch; validation order is fixed.
    train_loader = DataLoader(Subset(dataset, train_indices), shuffle=True, **loader_kwargs)
    val_loader = DataLoader(Subset(dataset, val_indices), shuffle=False, **loader_kwargs)

    return train_loader, val_loader
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build both loaders from the project config, pull a single
    # training batch and report tensor shapes and batch counts.
    from config import create_config

    config = create_config()
    loader_args = {
        "root_dir": config["data_dir"],
        "batch_size": 4,
        "image_size": tuple(config["image_size"]),
    }
    train_loader, val_loader = create_data_loaders(**loader_args)

    batch_iter = iter(train_loader)
    batch = next(batch_iter)
    print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")