ref: simplify and modularize GAN codebase
This commit is contained in:
162
models/GAN/dataloader.py
Normal file
162
models/GAN/dataloader.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Data loader for Yandex-to-Google image translation."""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class YaGoDataset(Dataset):
|
||||
"""Dataset loading pairs of Yandex and Google map images."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
image_size: Tuple[int, int] = (256, 256),
|
||||
augment: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
root_dir: Directory with images named {idx:04d}_google.png and {idx:04d}_yandex.png
|
||||
image_size: Target image size (height, width)
|
||||
augment: Whether to apply augmentation (not implemented for simplicity)
|
||||
"""
|
||||
self.root_dir = root_dir
|
||||
self.image_size = image_size
|
||||
self.augment = augment
|
||||
|
||||
# Discover image pairs
|
||||
self.pairs = self._find_pairs()
|
||||
|
||||
# Transform to tensor + normalization
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def _find_pairs(self) -> List[Dict]:
|
||||
"""Find all matching Google-Yandex image pairs."""
|
||||
pairs = []
|
||||
google_files = [f for f in os.listdir(self.root_dir) if f.endswith("_google.png")]
|
||||
|
||||
for google_file in sorted(google_files):
|
||||
idx_str = google_file.split("_")[0]
|
||||
try:
|
||||
idx = int(idx_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
yandex_file = f"{idx:04d}_yandex.png"
|
||||
yandex_path = os.path.join(self.root_dir, yandex_file)
|
||||
|
||||
if os.path.exists(yandex_path):
|
||||
pairs.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"google_path": os.path.join(self.root_dir, google_file),
|
||||
"yandex_path": yandex_path,
|
||||
}
|
||||
)
|
||||
|
||||
return pairs
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
pair = self.pairs[idx]
|
||||
|
||||
# Load images
|
||||
google_img = Image.open(pair["google_path"]).convert("RGB")
|
||||
yandex_img = Image.open(pair["yandex_path"]).convert("RGB")
|
||||
|
||||
# Resize
|
||||
google_img = google_img.resize((self.image_size[1], self.image_size[0]))
|
||||
yandex_img = yandex_img.resize((self.image_size[1], self.image_size[0]))
|
||||
|
||||
# Apply transforms
|
||||
google_tensor = self.transform(google_img)
|
||||
yandex_tensor = self.transform(yandex_img)
|
||||
|
||||
return {
|
||||
"google_img": google_tensor,
|
||||
"yandex_img": yandex_tensor,
|
||||
"idx": torch.tensor(pair["idx"], dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def create_data_loaders(
|
||||
root_dir: str,
|
||||
batch_size: int = 32,
|
||||
train_split: float = 0.8,
|
||||
num_workers: int = 0,
|
||||
image_size: Tuple[int, int] = (256, 256),
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
"""
|
||||
Create train and validation data loaders.
|
||||
|
||||
Args:
|
||||
root_dir: Directory with image pairs
|
||||
batch_size: Batch size
|
||||
train_split: Fraction for training (0.0-1.0)
|
||||
num_workers: DataLoader workers
|
||||
image_size: Target image size
|
||||
|
||||
Returns:
|
||||
(train_loader, val_loader)
|
||||
"""
|
||||
# Full dataset
|
||||
dataset = YaGoDataset(root_dir=root_dir, image_size=image_size)
|
||||
|
||||
# Split
|
||||
dataset_size = len(dataset)
|
||||
train_size = int(train_split * dataset_size)
|
||||
indices = torch.randperm(dataset_size).tolist()
|
||||
train_indices = indices[:train_size]
|
||||
val_indices = indices[train_size:]
|
||||
|
||||
# Subsets
|
||||
from torch.utils.data import Subset
|
||||
|
||||
train_dataset = Subset(dataset, train_indices)
|
||||
val_dataset = Subset(dataset, val_indices)
|
||||
|
||||
# DataLoaders
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Quick test
|
||||
from config import create_config
|
||||
|
||||
config = create_config()
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=4,
|
||||
image_size=tuple(config["image_size"]),
|
||||
)
|
||||
|
||||
batch = next(iter(train_loader))
|
||||
print(f"Batch shapes: google={batch['google_img'].shape}, yandex={batch['yandex_img'].shape}")
|
||||
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
|
||||
Reference in New Issue
Block a user