520 lines
16 KiB
Python
520 lines
16 KiB
Python
# Global training / data configuration shared by the code in this file.
config = {
    # Optimizer parameters
    "learning_rate": 2e-4,
    "beta1": 0.5,
    "beta2": 0.999,
    # Training parameters
    # NOTE: "batch_size" used to be listed twice in this literal (4 here and
    # 32 in the data section). In a dict literal the last duplicate wins, so
    # the effective value was always 32; it is now stated once.
    "batch_size": 32,
    "epochs": 100,
    # GAN parameters
    "gan_mode": "vanilla",  # "vanilla", "lsgan", or "wgangp"
    "lambda_L1": 100.0,  # weight of the L1 loss
    # Regularization
    "grad_clip": 1.0,
    # Early stopping
    "early_stopping_patience": 20,
    # Output
    "output_dir": "runs/gan_training",
    # Logging
    "log_interval": 10,  # log every N batches
    "save_interval": 5,  # save a checkpoint every N epochs
    # Dataset / data loading
    "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    "image_size": [256, 256],
    "train_split": 0.8,
    "num_workers": 0,
}
|
||
|
||
|
||
import os
|
||
import random
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from PIL import Image
|
||
from torch.utils.data import DataLoader, Dataset
|
||
|
||
|
||
class YaGoDataset(Dataset):
    """
    Dataset of paired Yandex / Google map images for homography estimation.

    Pairs are discovered on disk as ``{idx:04d}_google.png`` /
    ``{idx:04d}_yandex.png``. When augmentation is enabled, each image of a
    sample is warped with its own random homography and a matrix derived from
    the two warps is returned as the target.

    In ``__getitem__`` the Google image is, with probability 0.5, replaced by
    the Google image of a randomly chosen pair; the ``same_domain`` field of
    the sample records whether both images still belong to the same pair.
    """

    def __init__(
        self,
        root_dir: str,
        transform=None,
        augment: bool = True,
        max_samples: Optional[int] = None,
        image_size: Tuple[int, int] = (700, 700),
        cache_homographies: bool = True,
        device: Optional[torch.device] = None,
    ):
        """
        Initialize the YaGoDataset.

        Args:
            root_dir: Directory containing image pairs
                (format: {idx:04d}_google.png, {idx:04d}_yandex.png)
            transform: Optional callable applied to each PIL image (e.g.
                torchvision transforms); when omitted, images are converted
                to float CHW tensors scaled to [0, 1]
            augment: Whether to apply homography-based data augmentation
            max_samples: Maximum number of samples to load (None for all)
            image_size: Target size for images (height, width)
            cache_homographies: Stored on the instance; no method of this
                class currently reads it
            device: Stored on the instance; no method of this class
                currently reads it
        """
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.cache_homographies = cache_homographies
        self.device = device

        # Find all image pairs
        self.image_pairs = self._discover_image_pairs()

        if max_samples is not None:
            self.image_pairs = self.image_pairs[:max_samples]

        print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")

    def _discover_image_pairs(self) -> List[Dict[str, Any]]:
        """Discover all Google-Yandex image pairs in the dataset directory."""
        image_pairs = []

        google_files = [
            f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
        ]

        # Sorted so that sample order is deterministic across runs.
        for google_file in sorted(google_files):
            # The numeric index is the part of the file name before the
            # first underscore; files without a numeric prefix are skipped.
            idx_str = google_file.split("_")[0]
            try:
                idx = int(idx_str)
            except ValueError:
                continue

            # Only keep the pair if the matching Yandex image exists.
            yandex_file = f"{idx:04d}_yandex.png"
            yandex_path = os.path.join(self.root_dir, yandex_file)

            if os.path.exists(yandex_path):
                image_pairs.append(
                    {
                        "idx": idx,
                        "google_path": os.path.join(self.root_dir, google_file),
                        "yandex_path": yandex_path,
                    }
                )

        return image_pairs

    def __len__(self) -> int:
        """Return the number of image pairs in the dataset."""
        return len(self.image_pairs)

    def _load_rgb(self, path: str) -> "Image.Image":
        """Load *path* as an RGB image resized to ``self.image_size``."""
        # The context manager closes the underlying file as soon as the
        # pixel data has been decoded by ``convert``.
        with Image.open(path) as raw:
            rgb = raw.convert("RGB")
        # PIL expects (width, height); image_size is (height, width).
        return rgb.resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a sample from the dataset.

        Returns a dictionary with:
            - 'google_img': Google map image tensor
            - 'yandex_img': Yandex map image tensor
            - 'homography': float32 3x3 matrix; the relative homography of
              the two augmentation warps, or the identity when augmentation
              is disabled
            - 'same_domain': bool, False when the Google image was swapped
              for the one of a different pair
            - 'idx': Sample index (the numeric file-name prefix) as a tensor
        """
        pair_info = self.image_pairs[idx]
        google_path = pair_info["google_path"]
        yandex_path = pair_info["yandex_path"]
        same_domain = True

        # With probability 0.5 pair the Yandex image with the Google image
        # of a random sample (which may happen to be the same one).
        if np.random.rand() > 0.5:
            random_idx = np.random.randint(0, len(self))
            google_path = self.image_pairs[random_idx]["google_path"]
            same_domain = random_idx == idx

        # Load and resize images
        yandex_img = self._load_rgb(yandex_path)
        google_img = self._load_rgb(google_path)

        if self.augment:
            # Random warps are only generated when they are actually used.
            matrices = self._get_homography_matrix(pair_info["idx"])
            google_img, yandex_img, homography_matrix = self._apply_augmentation(
                google_img, yandex_img, matrices
            )
            homography_tensor = torch.from_numpy(homography_matrix).float()
        else:
            # float32 identity, so the dtype matches the augmented branch
            # (np.eye alone would yield a float64 tensor).
            homography_tensor = torch.from_numpy(np.eye(3)).float()

        # Convert images to tensors
        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)
        else:
            # Default conversion: HWC uint8 -> CHW float in [0, 1]
            google_img = (
                torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
            )
            yandex_img = (
                torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
            )

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": homography_tensor,
            "same_domain": same_domain,
            "idx": torch.tensor(pair_info["idx"], dtype=torch.long),
        }

    def _get_homography_matrix(
        self, idx: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Generate a fresh triple of homography matrices.

        Nothing is cached: every call draws two new random homographies and
        ``idx`` is not used.

        Returns:
            Tuple ``(H1, H2, Hr)`` where ``Hr = inv(H1) @ H2``.
        """
        homography_matrix_1 = self.generate_random_homography()
        homography_matrix_2 = self.generate_random_homography()
        # NOTE(review): H1 and H2 are applied as forward warps by
        # _apply_homography_to_image; under that convention the map from the
        # H1-warped image to the H2-warped image would be H2 @ inv(H1).
        # Confirm which convention the consumer of this label expects.
        homography_matrix_r = np.linalg.inv(homography_matrix_1) @ homography_matrix_2

        return (homography_matrix_1, homography_matrix_2, homography_matrix_r)

    def generate_random_homography(self) -> np.ndarray:
        """
        Generate a random homography matrix for data augmentation.

        The matrix is ``K @ Rx @ Ry @ Rz @ T @ inv(K)`` with ``K`` from
        ``get_camera_matrix``, rotations of up to 10 degrees about each
        axis, and ``T`` carrying a random translation and scale.

        Returns:
            np.ndarray: 3x3 homography matrix.
        """
        scale = np.random.uniform(0.8, 1.2)  # scaling factor
        tx = np.random.uniform(-0.50, 0.50)  # translation in x
        ty = np.random.uniform(-0.50, 0.50)  # translation in y

        # Rotation angles (radians) about the x, y and z axes
        angle_x = np.random.uniform(np.radians(-10), np.radians(10))
        angle_y = np.random.uniform(np.radians(-10), np.radians(10))
        angle_z = np.random.uniform(np.radians(-10), np.radians(10))

        cy, sy = np.cos(angle_z), np.sin(angle_z)
        cp, sp = np.cos(angle_y), np.sin(angle_y)
        cr, sr = np.cos(angle_x), np.sin(angle_x)

        Rz = np.array(
            [
                [cy, -sy, 0],
                [sy, cy, 0],
                [0, 0, 1],
            ]
        )

        Ry = np.array(
            [
                [cp, 0, sp],
                [0, 1, 0],
                [-sp, 0, cp],
            ]
        )

        Rx = np.array(
            [
                [1, 0, 0],
                [0, cr, -sr],
                [0, sr, cr],
            ]
        )

        # Translation plus scale. The scale sits in the homogeneous slot, so
        # after normalization coordinates are divided by it. tx / ty act on
        # coordinates already multiplied by inv(K), not on raw pixels.
        T = np.array(
            [
                [1, 0, tx],
                [0, 1, ty],
                [0, 0, scale],
            ]
        )

        K = self.get_camera_matrix()

        return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)

    def get_camera_matrix(self) -> np.ndarray:
        """
        Return the 3x3 intrinsic-style matrix used to build homographies.

        It is derived from this dataset's own ``image_size`` (height, width),
        so it matches the size the images are resized to: focal lengths and
        principal point are ``(w / 2, h / 2)``.
        """
        h, w = self.image_size

        K = np.array(
            [
                [w / 2, 0, w / 2],
                [0, h / 2, h / 2],
                [0, 0, 1],
            ]
        )

        return K

    def _apply_augmentation(
        self,
        google_img: "Image.Image",
        yandex_img: "Image.Image",
        matrices: Tuple[np.ndarray, np.ndarray, np.ndarray],
    ) -> Tuple["Image.Image", "Image.Image", np.ndarray]:
        """
        Apply homography-based data augmentation to image pair.

        Args:
            google_img: Google map image
            yandex_img: Yandex map image
            matrices: ``(H1, H2, Hr)`` as returned by
                ``_get_homography_matrix``

        Returns:
            Tuple of (google image warped by H2, yandex image warped by H1, Hr)
        """
        combined_homography = matrices[2]

        yandex_aug = self._apply_homography_to_image(yandex_img, matrices[0])
        google_aug = self._apply_homography_to_image(google_img, matrices[1])

        return google_aug, yandex_aug, combined_homography

    def _apply_homography_to_image(
        self, img: "Image.Image", homography: np.ndarray
    ) -> "Image.Image":
        """
        Apply homography transformation to a single image.

        Args:
            img: PIL Image to transform
            homography: 3x3 homography matrix

        Returns:
            Transformed PIL Image of the same size as the input
        """
        img_np = np.array(img)
        h, w = img_np.shape[:2]

        # No borderMode is passed, so OpenCV's default border handling is
        # used for pixels that map outside the source image.
        transformed = cv2.warpPerspective(
            img_np,
            homography,
            (w, h),
            flags=cv2.INTER_LINEAR,
        )

        return Image.fromarray(transformed)

    def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
        """
        Get a sample without data augmentation.

        Useful for visualization and evaluation. Images are returned as
        resized PIL images (no transform, no random Google swap, no warp).

        Note that 'homography' is the full ``(H1, H2, Hr)`` tuple of freshly
        generated random matrices from ``_get_homography_matrix``; none of
        them has been applied to the returned images.
        """
        pair_info = self.image_pairs[idx]

        # Load and resize images
        google_img = self._load_rgb(pair_info["google_path"])
        yandex_img = self._load_rgb(pair_info["yandex_path"])

        homography_matrix = self._get_homography_matrix(pair_info["idx"])

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": homography_matrix,
            "idx": pair_info["idx"],
            "google_path": pair_info["google_path"],
            "yandex_path": pair_info["yandex_path"],
        }
|
||
|
||
|
||
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 4,
    image_size: Tuple[int, int] = (256, 256),
    augment_train: bool = True,
    augment_val: bool = False,
    device: Optional[torch.device] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders for homography estimation.

    The image pairs under ``root_dir`` are shuffled and split once. The
    training and validation loaders read from two views of the same
    ``YaGoDataset`` that differ only in their ``augment`` flag, so
    augmentation is performed by the dataset itself (images are warped and
    the returned homography corresponds to those warps).

    Batches are produced on the CPU; move them to the target device in the
    training loop (``pin_memory=True`` is set on both loaders).

    Args:
        root_dir: Directory containing image pairs
        batch_size: Batch size for data loaders
        train_split: Fraction of data to use for training
        num_workers: Number of worker processes for data loading
        image_size: Target image size (height, width)
        augment_train: Whether to augment training data
        augment_val: Whether to augment validation data
        device: Forwarded to ``YaGoDataset`` (optional); tensors are not
            moved to it by the loaders

    Returns:
        Tuple of (train_loader, val_loader)

    Raises:
        ValueError: If no image pairs are found in ``root_dir``.
    """
    import copy

    from torch.utils.data import Subset
    from torchvision import transforms

    # ImageNet mean / std normalization applied after conversion to tensor.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    train_source = YaGoDataset(
        root_dir=root_dir,
        transform=transform,
        augment=augment_train,
        image_size=image_size,
        cache_homographies=True,
        device=device,
    )

    dataset_size = len(train_source)
    if dataset_size == 0:
        raise ValueError(f"No image pairs found in {root_dir!r}")

    # Shallow copy: shares the discovered pair list (so indices refer to the
    # same files) but carries its own augmentation flag.
    val_source = copy.copy(train_source)
    val_source.augment = augment_val

    # Random, disjoint train / validation split
    train_size = int(train_split * dataset_size)
    indices = list(range(dataset_size))
    random.shuffle(indices)

    train_dataset = Subset(train_source, indices[:train_size])
    val_dataset = Subset(val_source, indices[train_size:])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader
|
||
|
||
|
||
# Example usage. Guarded so that importing this module has no side effects;
# the guard is also required when DataLoader worker processes are started by
# re-importing the main module (e.g. on Windows).
if __name__ == "__main__":
    dataset = YaGoDataset(
        root_dir=config["data_dir"],
        augment=False,
        image_size=(256, 256),
    )

    print(f"Dataset size: {len(dataset)}")

    # Get a sample
    sample = dataset[0]
    print(f"Sample keys: {list(sample.keys())}")
    print(f"Google image shape: {sample['google_img'].shape}")
    print(f"Yandex image shape: {sample['yandex_img'].shape}")
    print(f"Homography shape: {sample['homography'].shape}")

    # Create data loaders
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=16,
        train_split=0.8,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
|