1680 lines
78 KiB
Plaintext
1680 lines
78 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "92144cc0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Found 327 image pairs in C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\n",
|
|
"Dataset size: 327\n",
|
|
"Sample keys: ['google_img', 'yandex_img', 'homography', 'idx']\n",
|
|
"Google image shape: torch.Size([3, 700, 700])\n",
|
|
"Yandex image shape: torch.Size([3, 700, 700])\n",
|
|
"Homography shape: torch.Size([3, 3])\n",
|
|
"Found 327 image pairs in C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\n",
|
|
"Train batches: 17\n",
|
|
"Val batches: 5\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"import random\n",
|
|
"from typing import Any, Dict, List, Optional, Tuple\n",
|
|
"\n",
|
|
"import cv2\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"from PIL import Image\n",
|
|
"from torch.utils.data import DataLoader, Dataset\n",
|
|
"\n",
|
|
"\n",
|
|
class HomographyDataset(Dataset):
    """
    Dataset for homography estimation between Yandex and Google map image pairs.

    Pairs are discovered from ``<prefix>_google.png`` / ``<prefix>_yandex.png``
    files in ``root_dir``. Each item also carries a 3x3 "base" homography drawn
    by :meth:`generate_random_homography`. When ``cache_homographies`` is
    enabled, the base homography is saved to
    ``<root_dir>/homography_cache/{idx:04d}_homography.npy`` and reloaded on
    later accesses.

    NOTE(review): the base homography is never applied to either image. Unless
    the cache was pre-populated with real alignments, the label therefore cannot
    be recovered from the pixels. Confirm this is intended before training.
    """

    _GOOGLE_SUFFIX = "_google.png"
    _YANDEX_SUFFIX = "_yandex.png"

    def __init__(
        self,
        root_dir: str,
        transform=None,
        augment: bool = True,
        max_samples: Optional[int] = None,
        image_size: Tuple[int, int] = (700, 700),
        cache_homographies: bool = True,
    ):
        """
        Initialize the HomographyDataset.

        Args:
            root_dir: Directory containing image pairs
                (format: {idx:04d}_google.png, {idx:04d}_yandex.png).
            transform: Optional callable applied to each PIL image. When None,
                images are converted to float CHW tensors scaled to [0, 1].
            augment: Whether to apply homography-based data augmentation.
            max_samples: Maximum number of samples to load (None for all).
            image_size: Target size for images (height, width).
            cache_homographies: Whether to cache generated homography matrices
                to disk. When True, this creates ``<root_dir>/homography_cache``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.cache_homographies = cache_homographies

        # Find all image pairs
        self.image_pairs = self._discover_image_pairs()

        if max_samples is not None:
            self.image_pairs = self.image_pairs[:max_samples]

        print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")

        # Create directory for cached homographies if needed
        if cache_homographies:
            self.homography_cache_dir = os.path.join(root_dir, "homography_cache")
            os.makedirs(self.homography_cache_dir, exist_ok=True)

    def _discover_image_pairs(self) -> List[Dict[str, Any]]:
        """
        Discover all Google-Yandex image pairs in the dataset directory.

        The Yandex partner of ``<prefix>_google.png`` is ``<prefix>_yandex.png``
        with the same prefix. Names that are not zero-padded to exactly 4 digits
        therefore still pair up. Files whose prefix does not start with an
        integer index are skipped.

        Returns:
            List of dicts with keys "idx", "google_path" and "yandex_path",
            sorted by Google filename.
        """
        google_files = sorted(
            name
            for name in os.listdir(self.root_dir)
            if name.endswith(self._GOOGLE_SUFFIX)
        )

        image_pairs = []
        for google_file in google_files:
            prefix = google_file[: -len(self._GOOGLE_SUFFIX)]
            try:
                idx = int(prefix.split("_")[0])
            except ValueError:
                continue

            # Reuse the exact prefix rather than re-formatting idx, which would
            # miss e.g. "12_yandex.png" or "00012_yandex.png".
            yandex_path = os.path.join(self.root_dir, prefix + self._YANDEX_SUFFIX)
            if os.path.exists(yandex_path):
                image_pairs.append(
                    {
                        "idx": idx,
                        "google_path": os.path.join(self.root_dir, google_file),
                        "yandex_path": yandex_path,
                    }
                )

        return image_pairs

    def _load_image(self, path: str) -> Image.Image:
        """Open ``path``, convert it to RGB and resize it to ``self.image_size`` (H, W)."""
        height, width = self.image_size
        # The context manager closes the file; convert() returns an in-memory copy.
        with Image.open(path) as img:
            return img.convert("RGB").resize((width, height), Image.BILINEAR)

    @staticmethod
    def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
        """Convert an RGB PIL image to a float CHW tensor scaled to [0, 1]."""
        return torch.from_numpy(np.array(img)).float().permute(2, 0, 1) / 255.0

    def __len__(self) -> int:
        """Return the number of image pairs in the dataset."""
        return len(self.image_pairs)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a sample from the dataset.

        Returns a dictionary with:
            - 'google_img': Google map image tensor
            - 'yandex_img': Yandex map image tensor
            - 'homography': Homography matrix (3x3, float32)
            - 'idx': Sample index parsed from the filename
        """
        pair_info = self.image_pairs[idx]

        google_img = self._load_image(pair_info["google_path"])
        yandex_img = self._load_image(pair_info["yandex_path"])

        # Get or generate homography matrix
        homography_matrix = self._get_homography_matrix(pair_info["idx"])

        # Apply data augmentation if enabled
        if self.augment:
            google_img, yandex_img, homography_matrix = self._apply_augmentation(
                google_img, yandex_img, homography_matrix
            )

        # Convert images to tensors
        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)
        else:
            google_img = self._pil_to_tensor(google_img)
            yandex_img = self._pil_to_tensor(yandex_img)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": torch.from_numpy(homography_matrix).float(),
            "idx": torch.tensor(pair_info["idx"], dtype=torch.long),
        }

    def _get_homography_matrix(self, idx: int) -> np.ndarray:
        """
        Get the base homography matrix for a given pair index.

        With caching enabled, a previously saved matrix is loaded if present.
        Otherwise a new one is generated and saved. With caching disabled, a
        fresh random matrix is returned on every call.
        """
        if not self.cache_homographies:
            return self.generate_random_homography()

        cache_path = os.path.join(
            self.homography_cache_dir, f"{idx:04d}_homography.npy"
        )
        if os.path.exists(cache_path):
            return np.load(cache_path)

        homography_matrix = self.generate_random_homography()
        np.save(cache_path, homography_matrix)
        return homography_matrix

    def generate_random_homography(self, perspective_range: float = 1e-4) -> np.ndarray:
        """
        Generate a random homography matrix for data augmentation.

        The matrix combines the following components:
            - Rotation in [-30, 30] degrees and scale in [0.8, 1.2], both about
              the pixel origin (top-left corner).
            - Translation in [-50, 50] px along each axis.
            - Perspective terms h31, h32 drawn uniformly from
              [-perspective_range, perspective_range].

        The perspective terms multiply pixel coordinates. For a 700 px image,
        the default 1e-4 changes the homogeneous scale by at most about 14% at
        the far corner.

        Args:
            perspective_range: Half-width of the uniform range for the two
                projective entries of the bottom row.

        Returns:
            np.ndarray: 3x3 float64 homography matrix with H[2, 2] == 1.
        """
        angle = np.random.uniform(-30, 30)  # rotation in degrees
        scale = np.random.uniform(0.8, 1.2)  # scaling factor
        tx = np.random.uniform(-50, 50)  # translation in x
        ty = np.random.uniform(-50, 50)  # translation in y

        theta = np.radians(angle)
        cos_t, sin_t = np.cos(theta), np.sin(theta)

        homography_matrix = np.array(
            [
                [scale * cos_t, -scale * sin_t, tx],
                [scale * sin_t, scale * cos_t, ty],
                [0.0, 0.0, 1.0],
            ]
        )

        # Perspective lives in the bottom row. Perturbing rows 0-1 (as before)
        # only jitters the affine part and never yields a projective warp.
        homography_matrix[2, :2] = np.random.uniform(
            -perspective_range, perspective_range, 2
        )

        return homography_matrix

    def _apply_augmentation(
        self,
        google_img: Image.Image,
        yandex_img: Image.Image,
        base_homography: np.ndarray,
    ) -> Tuple[Image.Image, Image.Image, np.ndarray]:
        """
        Apply homography-based data augmentation to an image pair.

        Both images are warped by the same random homography A. A
        correspondence x_b = H x_a between the original images therefore
        becomes (A x_b) = (A H A^-1)(A x_a), so the label is conjugated by A.

        Args:
            google_img: Google map image.
            yandex_img: Yandex map image.
            base_homography: Base homography matrix.

        Returns:
            Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography).
        """
        aug_homography = self.generate_random_homography()

        # Left-multiplying alone (A @ H) would be wrong because both images moved.
        combined_homography = (
            aug_homography @ base_homography @ np.linalg.inv(aug_homography)
        )

        google_aug = self._apply_homography_to_image(google_img, aug_homography)
        yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography)

        return google_aug, yandex_aug, combined_homography

    def _apply_homography_to_image(
        self, img: Image.Image, homography: np.ndarray
    ) -> Image.Image:
        """
        Apply a homography to a single image, keeping its size.

        Pixels mapped from outside the source are filled by reflecting the
        image border (cv2.BORDER_REFLECT).

        Args:
            img: PIL Image to transform.
            homography: 3x3 homography matrix in pixel coordinates.

        Returns:
            Transformed PIL Image.
        """
        img_np = np.array(img)
        h, w = img_np.shape[:2]

        transformed = cv2.warpPerspective(
            img_np,
            homography,
            (w, h),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT,
        )

        return Image.fromarray(transformed)

    def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
        """
        Get a sample without data augmentation or tensor conversion.

        Useful for visualization and evaluation.

        Returns:
            Dict with the resized PIL images, the base homography (np.ndarray),
            the pair index and both source paths.
        """
        pair_info = self.image_pairs[idx]

        return {
            "google_img": self._load_image(pair_info["google_path"]),
            "yandex_img": self._load_image(pair_info["yandex_path"]),
            "homography": self._get_homography_matrix(pair_info["idx"]),
            "idx": pair_info["idx"],
            "google_path": pair_info["google_path"],
            "yandex_path": pair_info["yandex_path"],
        }
|
|
"\n",
|
|
"\n",
|
|
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 4,
    image_size: Tuple[int, int] = (256, 256),
    augment_train: bool = True,
    augment_val: bool = False,
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders for homography estimation.

    Images are converted with ToTensor + ImageNet Normalize. Augmentation is
    delegated to HomographyDataset, which warps the images and updates the
    label together. Each split gets its own ``augment`` flag.

    Args:
        root_dir: Directory containing image pairs.
        batch_size: Batch size for data loaders.
        train_split: Fraction of data to use for training, in (0, 1].
        num_workers: Number of worker processes for data loading.
            NOTE: on platforms that spawn workers (e.g. Windows), the dataset
            class must be importable from a module. Classes defined only in a
            notebook may fail to unpickle when num_workers > 0.
        image_size: Target image size (height, width).
        augment_train: Whether to augment training data.
        augment_val: Whether to augment validation data.
        seed: Seed for the train/val split. None keeps the previous behaviour
            of shuffling with the global ``random`` state.

    Returns:
        Tuple of (train_loader, val_loader).

    Raises:
        ValueError: If ``train_split`` is not in (0, 1].
    """
    import copy

    from torch.utils.data import Subset
    from torchvision import transforms

    if not 0.0 < train_split <= 1.0:
        raise ValueError(f"train_split must be in (0, 1], got {train_split}")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    train_base = HomographyDataset(
        root_dir=root_dir,
        transform=transform,
        augment=augment_train,
        image_size=image_size,
        cache_homographies=True,
    )
    # Shallow copy: shares the discovered pair list and cache dir (read-only),
    # but carries its own augment flag, and avoids re-scanning the directory.
    val_base = copy.copy(train_base)
    val_base.augment = augment_val

    dataset_size = len(train_base)
    train_size = int(train_split * dataset_size)

    indices = list(range(dataset_size))
    if seed is None:
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)

    # Module-level Subset (instead of a function-local subclass) stays picklable
    # for multi-process loading.
    train_dataset = Subset(train_base, indices[:train_size])
    val_dataset = Subset(val_base, indices[train_size:])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader
|
|
"\n",
|
|
"\n",
|
|
# Example usage
# Dataset location. Set the YA_GO_MAPS_DIR environment variable to override it
# instead of editing this machine-specific default.
DATA_DIR = os.environ.get(
    "YA_GO_MAPS_DIR",
    r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
)
SEED = 42

# Seed every RNG the pipeline draws from: the split shuffle (random), the
# random homographies (np.random) and torch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = HomographyDataset(
    root_dir=DATA_DIR,
    augment=True,
    image_size=(700, 700),
)

print(f"Dataset size: {len(dataset)}")

# Get a sample
sample = dataset[0]
print(f"Sample keys: {list(sample.keys())}")
print(f"Google image shape: {sample['google_img'].shape}")
print(f"Yandex image shape: {sample['yandex_img'].shape}")
print(f"Homography shape: {sample['homography'].shape}")

# Create data loaders
train_loader, val_loader = create_data_loaders(
    root_dir=DATA_DIR,
    batch_size=16,
    train_split=0.8,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "bf3b0524",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using device: cpu\n",
|
|
"Model created with 9,013,385 parameters\n",
|
|
"\n",
|
|
"Testing forward pass...\n",
|
|
"Output shape: torch.Size([4, 3, 3])\n",
|
|
"Sample output:\n",
|
|
"tensor([[-1.3744e+01, 9.4431e+00, 1.8618e+01],\n",
|
|
" [ 9.8099e+00, 5.7875e+00, -2.4102e+01],\n",
|
|
" [ 9.3618e-03, 3.3153e+00, 1.0000e+00]], grad_fn=<SelectBackward0>)\n",
|
|
"\n",
|
|
"Testing prediction...\n",
|
|
"Prediction shape: torch.Size([4, 3, 3])\n",
|
|
"Last element (should be ~1): 1.000000\n",
|
|
"\n",
|
|
"Testing loss function...\n",
|
|
"Loss value: 582368034816.000000\n",
|
|
"\n",
|
|
"Testing metrics...\n",
|
|
"matrix_mse: 4946.436523\n",
|
|
"corner_error: 5.424056\n",
|
|
"corner_error_px: 694.279114\n",
|
|
"\n",
|
|
"Testing model factory...\n",
|
|
"Model2 created with 1,779,145 parameters\n",
|
|
"\n",
|
|
"All tests completed successfully!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from typing import Optional, Tuple\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"\n",
|
|
"\n",
|
|
class HomographyCNN(nn.Module):
    """
    CNN model for homography estimation between two images.

    Each image (Google, Yandex) passes through its own encoder. There are two
    separate encoder instances, so weights are not shared. The feature maps are
    concatenated along channels, fused by two 3x3 convolutions, globally
    average-pooled and regressed to ``output_size`` numbers. With the default
    ``output_size=9``, these are reshaped to a 3x3 matrix normalised so that
    ``H[2, 2] == 1``.
    """

    def __init__(
        self,
        input_channels: int = 3,
        hidden_channels: int = 64,
        num_blocks: int = 4,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
        output_size: int = 9,  # Flattened 3x3 homography matrix
    ):
        """
        Initialize the HomographyCNN model.

        Args:
            input_channels: Number of input channels per image (3 for RGB).
            hidden_channels: Base number of channels in the network.
            num_blocks: Number of residual blocks per encoder (0 allowed).
            dropout_rate: Dropout rate for regularization.
            use_batch_norm: Whether to use batch normalization.
            output_size: Size of output vector (9 for flattened 3x3 matrix).
        """
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # Feature extraction for each image separately
        self.google_encoder = self._build_encoder()
        self.yandex_encoder = self._build_encoder()

        # Fusion layers to combine features from both images
        self.fusion_layers = self._build_fusion_layers()

        # Regression head for homography estimation
        self.regression_head = self._build_regression_head(output_size)

        self._initialize_weights()

    def _encoder_out_channels(self) -> int:
        """Number of channels produced by one encoder built by ``_build_encoder``."""
        # Only the first two residual blocks double the width.
        return self.hidden_channels * 2 ** min(self.num_blocks, 2)

    def _build_encoder(self) -> nn.Module:
        """Build the encoder network for a single image."""
        layers = []
        in_channels = self.input_channels
        out_channels = self.hidden_channels

        # Stem: 7x7 stride-2 conv followed by a stride-2 max-pool (4x downsample)
        layers.append(
            nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)
        )
        if self.use_batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Residual blocks: the first keeps resolution, later ones halve it.
        # Only the first two double the channel count.
        for i in range(self.num_blocks):
            block_in_channels = out_channels
            block_out_channels = out_channels * 2 if i < 2 else out_channels

            layers.append(
                ResidualBlock(
                    in_channels=block_in_channels,
                    out_channels=block_out_channels,
                    stride=1 if i == 0 else 2,
                    dropout_rate=self.dropout_rate,
                    use_batch_norm=self.use_batch_norm,
                )
            )

            if i < 2:
                out_channels = block_out_channels

        return nn.Sequential(*layers)

    def _build_fusion_layers(self) -> nn.Module:
        """Build layers that fuse the concatenated features of both images."""
        # Computed from the encoder config so num_blocks < 2 also works
        # (a fixed hidden_channels * 8 is only right for num_blocks >= 2).
        fused_channels = 2 * self._encoder_out_channels()

        layers = [
            # Reduce dimensionality
            nn.Conv2d(
                fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
            ),
            nn.BatchNorm2d(self.hidden_channels * 4)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Further processing
            nn.Conv2d(
                self.hidden_channels * 4,
                self.hidden_channels * 2,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(self.hidden_channels * 2)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Global pooling
            nn.AdaptiveAvgPool2d((1, 1)),
        ]

        return nn.Sequential(*layers)

    def _build_regression_head(self, output_size: int) -> nn.Module:
        """Build the MLP regression head (hidden*2 -> 512 -> 256 -> 128 -> output_size)."""
        input_features = self.hidden_channels * 2

        layers = [
            nn.Flatten(),
            nn.Linear(input_features, 512),
            nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(128, output_size),
        ]

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """
        Initialize model weights.

        Convs use Kaiming-normal init. Norm layers start as identity. Linear
        layers use N(0, 0.01) with zero bias, except that a 9-output final
        layer gets the flattened identity as its bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # Start from the identity homography. With a zero bias the head predicts
        # H[2, 2] ~ 0, and forward() then divides by ~epsilon, which makes
        # early outputs and losses explode.
        final_layer = self.regression_head[-1]
        if isinstance(final_layer, nn.Linear) and final_layer.out_features == 9:
            with torch.no_grad():
                final_layer.bias.copy_(torch.eye(3).flatten())

    @staticmethod
    def _normalize_homography(
        matrix: torch.Tensor, epsilon: float = 1e-8
    ) -> torch.Tensor:
        """
        Divide each (3, 3) matrix in a (B, 3, 3) batch by its H[2, 2] entry.

        |H[2, 2]| is clamped to at least ``epsilon`` while keeping its sign.
        A value of 0 or exactly ``-epsilon`` therefore cannot cause a division
        by zero (plain ``h22 + epsilon`` could).
        """
        h22 = matrix[:, 2, 2]
        denom = torch.where(
            h22 >= 0, h22.clamp(min=epsilon), h22.clamp(max=-epsilon)
        )
        return matrix / denom.view(-1, 1, 1)

    def forward(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        return_matrix: bool = True,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            google_img: Google map image tensor of shape (B, C, H, W).
            yandex_img: Yandex map image tensor of shape (B, C, H, W).
            return_matrix: If True, return the normalised 3x3 matrix. If False,
                return the raw (un-normalised) flattened head output.

        Returns:
            Homography matrix tensor of shape (B, 3, 3), or the flattened
            vector of shape (B, output_size).
        """
        google_features = self.google_encoder(google_img)
        yandex_features = self.yandex_encoder(yandex_img)

        # Concatenate features along the channel dimension and fuse them
        combined_features = torch.cat([google_features, yandex_features], dim=1)
        fused_features = self.fusion_layers(combined_features)

        homography_flat = self.regression_head(fused_features)

        if not return_matrix:
            return homography_flat

        homography_matrix = homography_flat.view(homography_flat.shape[0], 3, 3)
        return self._normalize_homography(homography_matrix)

    def predict_homography(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrices in eval mode without tracking gradients.

        The module's previous train/eval mode is restored afterwards, so this
        can be called in the middle of a training loop.

        Args:
            google_img: Google map image tensor.
            yandex_img: Yandex map image tensor.
            normalize: Whether to re-normalise so that H[2, 2] == 1.

        Returns:
            Predicted homography matrix of shape (B, 3, 3).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                homography = self.forward(google_img, yandex_img, return_matrix=True)
        finally:
            self.train(was_training)

        if normalize:
            homography = self._normalize_homography(homography)

        return homography
|
|
"\n",
|
|
"\n",
|
|
class ResidualBlock(nn.Module):
    """
    Basic two-convolution residual block with an optional projection shortcut.

    The main path is conv3x3(stride) -> norm -> ReLU -> dropout, followed by
    conv3x3 -> norm. It is summed with the shortcut and then passed through
    ReLU -> dropout. The shortcut is a 1x1 conv (+ norm) whenever the stride or
    the channel count changes, and a pass-through otherwise. Norm layers are
    BatchNorm2d or Identity depending on ``use_batch_norm``. Dropout layers
    are Dropout2d, or Identity when ``dropout_rate`` is 0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        def make_norm() -> nn.Module:
            if use_batch_norm:
                return nn.BatchNorm2d(out_channels)
            return nn.Identity()

        def make_dropout() -> nn.Module:
            if dropout_rate > 0:
                return nn.Dropout2d(dropout_rate)
            return nn.Identity()

        # First stage: the only convolution on the main path that may downsample.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = make_norm()
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = make_dropout()

        # Second stage: keeps spatial resolution and channel count.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = make_norm()
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = make_dropout()

        # Skip path must match the main path's shape before the addition.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                make_norm(),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = self.shortcut(x)

        main = self.dropout1(self.relu1(self.bn1(self.conv1(x))))
        main = self.bn2(self.conv2(main))

        return self.dropout2(self.relu2(main + skip))
|
|
"\n",
|
|
"\n",
|
|
"class HomographyLoss(nn.Module):\n",
|
|
" \"\"\"\n",
|
|
" Custom loss function for homography estimation.\n",
|
|
"\n",
|
|
" Combines multiple loss terms:\n",
|
|
" 1. Matrix element-wise L2 loss\n",
|
|
" 2. Geometric consistency loss (warping error)\n",
|
|
" 3. Determinant regularization (to prevent degenerate matrices)\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" def __init__(\n",
|
|
" self,\n",
|
|
" matrix_weight: float = 1.0,\n",
|
|
" geometric_weight: float = 0.5,\n",
|
|
" reg_weight: float = 0.1,\n",
|
|
" grid_size: int = 8,\n",
|
|
" ):\n",
|
|
" super().__init__()\n",
|
|
" self.matrix_weight = matrix_weight\n",
|
|
" self.geometric_weight = geometric_weight\n",
|
|
" self.reg_weight = reg_weight\n",
|
|
" self.grid_size = grid_size\n",
|
|
"\n",
|
|
" # Create grid of points for geometric loss\n",
|
|
" self.register_buffer(\n",
|
|
" \"grid_points\",\n",
|
|
" self._create_grid_points(grid_size),\n",
|
|
" persistent=False,\n",
|
|
" )\n",
|
|
"\n",
|
|
" def _create_grid_points(self, grid_size: int) -> torch.Tensor:\n",
|
|
" \"\"\"Create a grid of points for geometric consistency loss.\"\"\"\n",
|
|
" x = torch.linspace(-1, 1, grid_size)\n",
|
|
" y = torch.linspace(-1, 1, grid_size)\n",
|
|
" grid_y, grid_x = torch.meshgrid(y, x, indexing=\"ij\")\n",
|
|
" grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)\n",
|
|
" # Add homogeneous coordinate\n",
|
|
" ones = torch.ones(grid_points.shape[0], 1)\n",
|
|
" grid_points = torch.cat([grid_points, ones], dim=1)\n",
|
|
" return grid_points.T # Shape: (3, grid_size*grid_size)\n",
|
|
"\n",
|
|
" def forward(\n",
|
|
" self,\n",
|
|
" pred_homography: torch.Tensor,\n",
|
|
" target_homography: torch.Tensor,\n",
|
|
" google_img: Optional[torch.Tensor] = None,\n",
|
|
" yandex_img: Optional[torch.Tensor] = None,\n",
|
|
" ) -> torch.Tensor:\n",
|
|
" \"\"\"\n",
|
|
" Compute homography loss.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" pred_homography: Predicted homography matrices (B, 3, 3)\n",
|
|
" target_homography: Target homography matrices (B, 3, 3)\n",
|
|
" google_img: Google images (optional, for geometric loss)\n",
|
|
" yandex_img: Yandex images (optional, for geometric loss)\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" Combined loss value\n",
|
|
" \"\"\"\n",
|
|
" batch_size = pred_homography.shape[0]\n",
|
|
"\n",
|
|
" # 1. Matrix element-wise L2 loss\n",
|
|
" matrix_loss = F.mse_loss(pred_homography, target_homography)\n",
|
|
"\n",
|
|
" # 2. Geometric consistency loss (if images provided)\n",
|
|
" geometric_loss = torch.tensor(0.0, device=pred_homography.device)\n",
|
|
" if google_img is not None and yandex_img is not None:\n",
|
|
" # Warp grid points with predicted homography\n",
|
|
" grid_points = self.grid_points.unsqueeze(0).expand(batch_size, -1, -1)\n",
|
|
" warped_points = torch.bmm(pred_homography, grid_points)\n",
|
|
"\n",
|
|
" # Normalize homogeneous coordinates\n",
|
|
" warped_points = warped_points / (warped_points[:, 2:3, :] + 1e-8)\n",
|
|
"\n",
|
|
" # Warp with target homography for comparison\n",
|
|
" target_warped_points = torch.bmm(target_homography, grid_points)\n",
|
|
" target_warped_points = target_warped_points / (\n",
|
|
" target_warped_points[:, 2:3, :] + 1e-8\n",
|
|
" )\n",
|
|
"\n",
|
|
" # Compute point-wise distance\n",
|
|
" geometric_loss = F.mse_loss(\n",
|
|
" warped_points[:, :2, :], target_warped_points[:, :2, :]\n",
|
|
" )\n",
|
|
"\n",
|
|
" # 3. Regularization loss (prevent degenerate matrices)\n",
|
|
" # Encourage determinant to be close to 1\n",
|
|
" pred_det = torch.det(pred_homography)\n",
|
|
" reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))\n",
|
|
"\n",
|
|
" # Combine losses\n",
|
|
" total_loss = (\n",
|
|
" self.matrix_weight * matrix_loss\n",
|
|
" + self.geometric_weight * geometric_loss\n",
|
|
" + self.reg_weight * reg_loss\n",
|
|
" )\n",
|
|
"\n",
|
|
" return total_loss\n",
|
|
"\n",
|
|
"    def compute_metrics(\n",
"        self,\n",
"        pred_homography: torch.Tensor,\n",
"        target_homography: torch.Tensor,\n",
"        image_size: Tuple[int, int] = (256, 256),\n",
"        eps: float = 1e-8,\n",
"    ) -> dict:\n",
"        \"\"\"\n",
"        Compute evaluation metrics for homography estimation.\n",
"\n",
"        Both homographies are first scaled so that H[2, 2] == 1, then the four\n",
"        corners of the normalized [-1, 1] x [-1, 1] square are warped by each\n",
"        and compared.\n",
"\n",
"        Args:\n",
"            pred_homography: Predicted homography matrices (B, 3, 3)\n",
"            target_homography: Target homography matrices (B, 3, 3)\n",
"            image_size: (height, width) in pixels of the images the corners\n",
"                refer to. Only affects \"corner_error_px\": a normalized offset\n",
"                is multiplied by width / 2 along x and height / 2 along y. The\n",
"                default (256, 256) reproduces the former fixed 128 px scale.\n",
"            eps: Minimum magnitude used when dividing by H[2, 2] or by the\n",
"                homogeneous coordinate, so degenerate predictions do not\n",
"                produce inf/nan.\n",
"\n",
"        Returns:\n",
"            Dictionary of batch-averaged metrics:\n",
"                matrix_mse: MSE between the normalized matrices\n",
"                corner_error: mean corner distance in normalized units\n",
"                corner_error_px: mean corner distance in pixels\n",
"        \"\"\"\n",
"\n",
"        def _safe(denominator: torch.Tensor) -> torch.Tensor:\n",
"            # Push |denominator| up to at least eps while keeping its sign.\n",
"            return torch.where(\n",
"                denominator >= 0,\n",
"                denominator.clamp(min=eps),\n",
"                denominator.clamp(max=-eps),\n",
"            )\n",
"\n",
"        with torch.no_grad():\n",
"            target_homography = target_homography.to(pred_homography.dtype)\n",
"\n",
"            # Normalize matrices so that H[2, 2] == 1\n",
"            pred_norm = pred_homography / _safe(pred_homography[:, 2, 2]).view(-1, 1, 1)\n",
"            target_norm = target_homography / _safe(target_homography[:, 2, 2]).view(\n",
"                -1, 1, 1\n",
"            )\n",
"\n",
"            # Matrix L2 error (per sample)\n",
"            matrix_error = F.mse_loss(pred_norm, target_norm, reduction=\"none\").mean(\n",
"                dim=(1, 2)\n",
"            )\n",
"\n",
"            # Corner error (warp the 4 corners of the normalized image square)\n",
"            corners = torch.tensor(\n",
"                [[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],\n",
"                dtype=pred_norm.dtype,\n",
"                device=pred_norm.device,\n",
"            ).T  # Shape: (3, 4); rows are x, y, w\n",
"            corners = corners.unsqueeze(0).expand(pred_norm.shape[0], -1, -1)\n",
"\n",
"            pred_corners = torch.bmm(pred_norm, corners)\n",
"            pred_corners = pred_corners / _safe(pred_corners[:, 2:3, :])\n",
"\n",
"            target_corners = torch.bmm(target_norm, corners)\n",
"            target_corners = target_corners / _safe(target_corners[:, 2:3, :])\n",
"\n",
"            # (B, 2, 4) corner offsets in normalized coordinates\n",
"            offsets = pred_corners[:, :2, :] - target_corners[:, :2, :]\n",
"            corner_error = torch.norm(offsets, dim=1).mean(dim=1)\n",
"\n",
"            # Scale x by width / 2 and y by height / 2 to convert to pixels\n",
"            height, width = image_size\n",
"            half_extent = offsets.new_tensor([width / 2.0, height / 2.0]).view(1, 2, 1)\n",
"            corner_error_px = torch.norm(offsets * half_extent, dim=1).mean(dim=1)\n",
"\n",
"        return {\n",
"            \"matrix_mse\": matrix_error.mean().item(),\n",
"            \"corner_error\": corner_error.mean().item(),\n",
"            \"corner_error_px\": corner_error_px.mean().item(),\n",
"        }\n",
|
|
"\n",
|
|
"\n",
|
|
"def create_homography_model(\n",
"    model_type: str = \"cnn\",\n",
"    input_size: Tuple[int, int] = (256, 256),\n",
"    **kwargs,\n",
") -> nn.Module:\n",
"    \"\"\"\n",
"    Factory function to create homography estimation model.\n",
"\n",
"    Args:\n",
"        model_type: Type of model to create. Only \"cnn\" (HomographyCNN) is\n",
"            currently registered; any other value raises ValueError.\n",
"        input_size: Input image size (height, width). Not used by the\n",
"            current model; accepted so existing call sites keep working.\n",
"        **kwargs: Additional arguments passed to model constructor\n",
"\n",
"    Returns:\n",
"        Homography estimation model\n",
"\n",
"    Raises:\n",
"        ValueError: If model_type is not a registered model type.\n",
"    \"\"\"\n",
"    # Registry of supported architectures; add new entries here.\n",
"    model_builders = {\"cnn\": HomographyCNN}\n",
"    if model_type not in model_builders:\n",
"        raise ValueError(f\"Unknown model type: {model_type}\")\n",
"    return model_builders[model_type](**kwargs)\n",
|
|
"\n",
|
|
"\n",
|
|
"# Test the model\n",
"# Smoke test: build the model and run it, the loss and the metrics once on\n",
"# random tensors. The calls run under torch.no_grad(): the results are only\n",
"# printed, and otherwise the notebook globals `output` and `loss` would keep\n",
"# the autograd graph of a 4x3x700x700 forward pass alive for the session.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Create model\n",
"model = HomographyCNN(\n",
"    input_channels=3,\n",
"    hidden_channels=64,\n",
"    num_blocks=4,\n",
"    dropout_rate=0.3,\n",
"    use_batch_norm=True,\n",
").to(device)\n",
"\n",
"print(\n",
"    f\"Model created with {sum(p.numel() for p in model.parameters()):,} parameters\"\n",
")\n",
"\n",
"# Create dummy input\n",
"batch_size = 4\n",
"height, width = 700, 700\n",
"\n",
"google_img = torch.randn(batch_size, 3, height, width).to(device)\n",
"yandex_img = torch.randn(batch_size, 3, height, width).to(device)\n",
"\n",
"with torch.no_grad():\n",
"    # Test forward pass\n",
"    print(\"\\nTesting forward pass...\")\n",
"    output = model(google_img, yandex_img, return_matrix=True)\n",
"    print(f\"Output shape: {output.shape}\")  # Should be (4, 3, 3)\n",
"    print(f\"Sample output:\\n{output[0]}\")\n",
"\n",
"    # Test prediction\n",
"    print(\"\\nTesting prediction...\")\n",
"    pred = model.predict_homography(google_img, yandex_img)\n",
"    print(f\"Prediction shape: {pred.shape}\")\n",
"    print(f\"Last element (should be ~1): {pred[0, 2, 2]:.6f}\")\n",
"\n",
"    # Test loss function\n",
"    print(\"\\nTesting loss function...\")\n",
"    target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)\n",
"    loss_fn = HomographyLoss(\n",
"        matrix_weight=1.0,\n",
"        geometric_weight=0.5,\n",
"        reg_weight=0.1,\n",
"        grid_size=8,\n",
"    ).to(device)\n",
"\n",
"    loss = loss_fn(output, target_homography, google_img, yandex_img)\n",
"    print(f\"Loss value: {loss.item():.6f}\")\n",
"\n",
"    # Test metrics\n",
"    print(\"\\nTesting metrics...\")\n",
"    metrics = loss_fn.compute_metrics(output, target_homography)\n",
"    for key, value in metrics.items():\n",
"        print(f\"{key}: {value:.6f}\")\n",
"\n",
"# Test model factory (construction only, no forward pass)\n",
"print(\"\\nTesting model factory...\")\n",
"model2 = create_homography_model(\n",
"    model_type=\"cnn\",\n",
"    input_size=(256, 256),\n",
"    input_channels=3,\n",
"    hidden_channels=32,\n",
"    num_blocks=3,\n",
").to(device)\n",
"\n",
"print(\n",
"    f\"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters\"\n",
")\n",
"\n",
"print(\"\\nAll tests completed successfully!\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "d7979efa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using device: cpu\n",
|
|
"Creating data loaders...\n",
|
|
"Found 327 image pairs in C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\n",
|
|
"Train batches: 9\n",
|
|
"Val batches: 3\n",
|
|
"Creating model...\n",
|
|
"Training configuration saved to runs\\homography\\config.json\n",
|
|
"Model has 8,999,817 parameters\n",
|
|
"Starting training for 100 epochs...\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1: 0%| | 0/9 [00:05<?, ?it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "RuntimeError",
|
|
"evalue": "DataLoader worker (pid(s) 29616) exited unexpectedly",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
"\u001b[31mEmpty\u001b[39m Traceback (most recent call last)",
|
|
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1310\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._try_get_data\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 1309\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1310\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_data_queue\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1311\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\multiprocessing\\queues.py:114\u001b[39m, in \u001b[36mQueue.get\u001b[39m\u001b[34m(self, block, timeout)\u001b[39m\n\u001b[32m 113\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._poll(timeout):\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._poll():\n",
|
|
"\u001b[31mEmpty\u001b[39m: ",
|
|
"\nThe above exception was the direct cause of the following exception:\n",
|
|
"\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
|
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 533\u001b[39m\n\u001b[32m 530\u001b[39m trainer.evaluate()\n\u001b[32m 531\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 532\u001b[39m \u001b[38;5;66;03m# Train the model\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m533\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m=\u001b[49m\u001b[43margs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 535\u001b[39m \u001b[38;5;66;03m# Final evaluation\u001b[39;00m\n\u001b[32m 536\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mPerforming final evaluation...\u001b[39m\u001b[33m\"\u001b[39m)\n",
|
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 299\u001b[39m, in \u001b[36mHomographyTrainer.train\u001b[39m\u001b[34m(self, num_epochs)\u001b[39m\n\u001b[32m 296\u001b[39m \u001b[38;5;28mself\u001b[39m.current_epoch = epoch\n\u001b[32m 298\u001b[39m \u001b[38;5;66;03m# Train for one epoch\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m299\u001b[39m train_loss = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 301\u001b[39m \u001b[38;5;66;03m# Validate\u001b[39;00m\n\u001b[32m 302\u001b[39m val_loss, val_metrics = \u001b[38;5;28mself\u001b[39m.validate()\n",
|
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 143\u001b[39m, in \u001b[36mHomographyTrainer.train_epoch\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 140\u001b[39m num_batches = \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m.train_loader)\n\u001b[32m 142\u001b[39m progress_bar = tqdm(\u001b[38;5;28mself\u001b[39m.train_loader, desc=\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.current_epoch\u001b[38;5;250m \u001b[39m+\u001b[38;5;250m \u001b[39m\u001b[32m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m143\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 144\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Move data to device\u001b[39;49;00m\n\u001b[32m 145\u001b[39m \u001b[43m \u001b[49m\u001b[43mgoogle_img\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mgoogle_img\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 146\u001b[39m \u001b[43m \u001b[49m\u001b[43myandex_img\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43myandex_img\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\tqdm\\std.py:1181\u001b[39m, in \u001b[36mtqdm.__iter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1178\u001b[39m time = \u001b[38;5;28mself\u001b[39m._time\n\u001b[32m 1180\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1181\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 1182\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\n\u001b[32m 1183\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Update and possibly print the progressbar.\u001b[39;49;00m\n\u001b[32m 1184\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;49;00m\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:741\u001b[39m, in \u001b[36m_BaseDataLoaderIter.__next__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 738\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 739\u001b[39m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[32m 740\u001b[39m \u001b[38;5;28mself\u001b[39m._reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m741\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 742\u001b[39m \u001b[38;5;28mself\u001b[39m._num_yielded += \u001b[32m1\u001b[39m\n\u001b[32m 743\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 744\u001b[39m \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable\n\u001b[32m 745\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 746\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._num_yielded > \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called\n\u001b[32m 747\u001b[39m ):\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1524\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._next_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1520\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._shutdown \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._tasks_outstanding <= \u001b[32m0\u001b[39m:\n\u001b[32m 1521\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\n\u001b[32m 1522\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mInvalid iterator state: shutdown or no outstanding tasks when fetching next data\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1523\u001b[39m )\n\u001b[32m-> \u001b[39m\u001b[32m1524\u001b[39m idx, data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1525\u001b[39m \u001b[38;5;28mself\u001b[39m._tasks_outstanding -= \u001b[32m1\u001b[39m\n\u001b[32m 1526\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable:\n\u001b[32m 1527\u001b[39m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1483\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._get_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1479\u001b[39m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[32m 1480\u001b[39m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[32m 1481\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1482\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1483\u001b[39m success, data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1484\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[32m 1485\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
|
|
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1323\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._try_get_data\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 1321\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) > \u001b[32m0\u001b[39m:\n\u001b[32m 1322\u001b[39m pids_str = \u001b[33m\"\u001b[39m\u001b[33m, \u001b[39m\u001b[33m\"\u001b[39m.join(\u001b[38;5;28mstr\u001b[39m(w.pid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[32m-> \u001b[39m\u001b[32m1323\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[32m 1324\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m) exited unexpectedly\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1325\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m 1326\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue.Empty):\n\u001b[32m 1327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
|
|
"\u001b[31mRuntimeError\u001b[39m: DataLoader worker (pid(s) 29616) exited unexpectedly"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"\"\"\"\n",
|
|
"Training script for homography estimation between Google and Yandex map images.\n",
|
|
"\n",
|
|
"This script trains a CNN model to estimate homography matrices that align\n",
|
|
"Google map images with Yandex map images.\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"import argparse\n",
|
|
"import json\n",
|
|
"import os\n",
|
|
"import time\n",
|
|
"from datetime import datetime\n",
|
|
"from pathlib import Path\n",
|
|
"from typing import Dict, List, Optional, Tuple\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.optim as optim\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
|
"from tqdm import tqdm\n",
|
|
"\n",
|
|
"\n",
|
|
"class HomographyTrainer:\n",
"    \"\"\"Trainer class for homography estimation model.\n",
"\n",
"    Bundles the loss, optimizer, LR scheduler, TensorBoard writer and\n",
"    checkpoint handling for a single run rooted at config[\"output_dir\"].\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        model: nn.Module,\n",
"        train_loader: DataLoader,\n",
"        val_loader: DataLoader,\n",
"        device: torch.device,\n",
"        config: Dict,\n",
"    ):\n",
"        \"\"\"\n",
"        Initialize the trainer.\n",
"\n",
"        Args:\n",
"            model: Homography estimation model (moved to device)\n",
"            train_loader: Training data loader\n",
"            val_loader: Validation data loader\n",
"            device: Device to run training on\n",
"            config: Training configuration dictionary. It is written to\n",
"                <output_dir>/config.json, so it must be JSON-serializable.\n",
"\n",
"        Raises:\n",
"            ValueError: If config[\"optimizer\"] is not adam, adamw or sgd.\n",
"        \"\"\"\n",
"        self.model = model.to(device)\n",
"        self.train_loader = train_loader\n",
"        self.val_loader = val_loader\n",
"        self.device = device\n",
"        self.config = config\n",
"\n",
"        # Loss function\n",
"        self.criterion = HomographyLoss(\n",
"            matrix_weight=config.get(\"matrix_weight\", 1.0),\n",
"            geometric_weight=config.get(\"geometric_weight\", 0.5),\n",
"            reg_weight=config.get(\"reg_weight\", 0.1),\n",
"            grid_size=config.get(\"grid_size\", 8),\n",
"        ).to(device)\n",
"\n",
"        # Optimizer\n",
"        optimizer_name = config.get(\"optimizer\", \"adam\").lower()\n",
"        lr = config.get(\"learning_rate\", 1e-3)\n",
"        weight_decay = config.get(\"weight_decay\", 1e-4)\n",
"\n",
"        if optimizer_name == \"adam\":\n",
"            self.optimizer = optim.Adam(\n",
"                self.model.parameters(), lr=lr, weight_decay=weight_decay\n",
"            )\n",
"        elif optimizer_name == \"adamw\":\n",
"            self.optimizer = optim.AdamW(\n",
"                self.model.parameters(), lr=lr, weight_decay=weight_decay\n",
"            )\n",
"        elif optimizer_name == \"sgd\":\n",
"            self.optimizer = optim.SGD(\n",
"                self.model.parameters(),\n",
"                lr=lr,\n",
"                momentum=config.get(\"momentum\", 0.9),\n",
"                weight_decay=weight_decay,\n",
"            )\n",
"        else:\n",
"            raise ValueError(f\"Unknown optimizer: {optimizer_name}\")\n",
"\n",
"        # Learning rate scheduler; an unrecognized name disables scheduling\n",
"        scheduler_name = config.get(\"scheduler\", \"plateau\").lower()\n",
"        if scheduler_name == \"plateau\":\n",
"            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
"                self.optimizer,\n",
"                mode=\"min\",\n",
"                factor=config.get(\"scheduler_factor\", 0.5),\n",
"                patience=config.get(\"scheduler_patience\", 5),\n",
"            )\n",
"        elif scheduler_name == \"cosine\":\n",
"            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(\n",
"                self.optimizer,\n",
"                T_max=config.get(\"epochs\", 100),\n",
"                eta_min=config.get(\"min_lr\", 1e-6),\n",
"            )\n",
"        elif scheduler_name == \"step\":\n",
"            self.scheduler = optim.lr_scheduler.StepLR(\n",
"                self.optimizer,\n",
"                step_size=config.get(\"step_size\", 30),\n",
"                gamma=config.get(\"gamma\", 0.1),\n",
"            )\n",
"        else:\n",
"            self.scheduler = None\n",
"\n",
"        # Training state\n",
"        self.current_epoch = 0\n",
"        # First epoch index train() runs. load_checkpoint() moves it past the\n",
"        # saved epoch so a resumed run continues instead of restarting at 0.\n",
"        self.start_epoch = 0\n",
"        self.best_val_loss = float(\"inf\")\n",
"        self.train_losses: List[float] = []\n",
"        self.val_losses: List[float] = []\n",
"        self.val_metrics: List[Dict] = []\n",
"\n",
"        # Create output directory\n",
"        self.output_dir = Path(config.get(\"output_dir\", \"runs/homography\"))\n",
"        self.output_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"        # TensorBoard writer\n",
"        self.writer = SummaryWriter(log_dir=self.output_dir / \"tensorboard\")\n",
"\n",
"        # Save configuration\n",
"        config_path = self.output_dir / \"config.json\"\n",
"        with open(config_path, \"w\") as f:\n",
"            json.dump(config, f, indent=2)\n",
"\n",
"        print(f\"Training configuration saved to {config_path}\")\n",
"        print(\n",
"            f\"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters\"\n",
"        )\n",
"\n",
"    def _move_batch(\n",
"        self, batch: Dict\n",
"    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
"        \"\"\"Return (google_img, yandex_img, homography) from a batch, on self.device.\"\"\"\n",
"        return (\n",
"            batch[\"google_img\"].to(self.device),\n",
"            batch[\"yandex_img\"].to(self.device),\n",
"            batch[\"homography\"].to(self.device),\n",
"        )\n",
"\n",
"    @staticmethod\n",
"    def _aggregate_metrics(all_metrics: List[Dict]) -> Dict[str, float]:\n",
"        \"\"\"Average a non-empty list of per-batch metric dicts key-wise.\n",
"\n",
"        Values are cast to built-in float (np.mean returns numpy scalars) so\n",
"        the result is JSON-friendly and checkpoints that store it can be read\n",
"        back with torch.load(..., weights_only=True).\n",
"        \"\"\"\n",
"        return {\n",
"            key: float(np.mean([m[key] for m in all_metrics]))\n",
"            for key in all_metrics[0]\n",
"        }\n",
"\n",
"    def train_epoch(self) -> float:\n",
"        \"\"\"\n",
"        Train for one epoch.\n",
"\n",
"        Returns:\n",
"            Average training loss for the epoch\n",
"\n",
"        Raises:\n",
"            ValueError: If the training loader has no batches.\n",
"        \"\"\"\n",
"        num_batches = len(self.train_loader)\n",
"        if num_batches == 0:\n",
"            raise ValueError(\"Training loader is empty; cannot run an epoch.\")\n",
"\n",
"        self.model.train()\n",
"        total_loss = 0.0\n",
"        # Max gradient norm; a value <= 0 disables clipping\n",
"        grad_clip = self.config.get(\"grad_clip\", 1.0)\n",
"\n",
"        progress_bar = tqdm(self.train_loader, desc=f\"Epoch {self.current_epoch + 1}\")\n",
"        for batch_idx, batch in enumerate(progress_bar):\n",
"            google_img, yandex_img, target_homography = self._move_batch(batch)\n",
"\n",
"            # Forward pass\n",
"            self.optimizer.zero_grad()\n",
"            pred_homography = self.model(google_img, yandex_img, return_matrix=True)\n",
"\n",
"            # Compute loss\n",
"            loss = self.criterion(\n",
"                pred_homography,\n",
"                target_homography,\n",
"                google_img,\n",
"                yandex_img,\n",
"            )\n",
"\n",
"            # Backward pass\n",
"            loss.backward()\n",
"\n",
"            # Gradient clipping\n",
"            if grad_clip > 0:\n",
"                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)\n",
"\n",
"            # Optimizer step\n",
"            self.optimizer.step()\n",
"\n",
"            # Read the scalar once (on GPU every .item() is a host sync)\n",
"            batch_loss = loss.item()\n",
"            total_loss += batch_loss\n",
"            progress_bar.set_postfix({\"loss\": batch_loss})\n",
"\n",
"            # Log batch loss to TensorBoard\n",
"            global_step = self.current_epoch * num_batches + batch_idx\n",
"            self.writer.add_scalar(\"train/batch_loss\", batch_loss, global_step)\n",
"\n",
"        avg_loss = total_loss / num_batches\n",
"        self.train_losses.append(avg_loss)\n",
"\n",
"        return avg_loss\n",
"\n",
"    @torch.no_grad()\n",
"    def validate(self) -> Tuple[float, Dict]:\n",
"        \"\"\"\n",
"        Validate the model.\n",
"\n",
"        Returns:\n",
"            Tuple of (average validation loss, validation metrics), where the\n",
"            metrics are the per-batch compute_metrics() results averaged\n",
"            key-wise.\n",
"\n",
"        Raises:\n",
"            ValueError: If the validation loader has no batches.\n",
"        \"\"\"\n",
"        self.model.eval()\n",
"        total_loss = 0.0\n",
"        all_metrics = []\n",
"\n",
"        progress_bar = tqdm(self.val_loader, desc=\"Validation\")\n",
"        for batch in progress_bar:\n",
"            google_img, yandex_img, target_homography = self._move_batch(batch)\n",
"\n",
"            # Forward pass\n",
"            pred_homography = self.model(google_img, yandex_img, return_matrix=True)\n",
"\n",
"            # Compute loss and metrics\n",
"            loss = self.criterion(\n",
"                pred_homography,\n",
"                target_homography,\n",
"                google_img,\n",
"                yandex_img,\n",
"            )\n",
"            metrics = self.criterion.compute_metrics(pred_homography, target_homography)\n",
"\n",
"            # Update statistics\n",
"            batch_loss = loss.item()\n",
"            total_loss += batch_loss\n",
"            all_metrics.append(metrics)\n",
"            progress_bar.set_postfix({\"loss\": batch_loss})\n",
"\n",
"        if not all_metrics:\n",
"            raise ValueError(\"Validation loader is empty; cannot validate.\")\n",
"\n",
"        avg_loss = total_loss / len(all_metrics)\n",
"        self.val_losses.append(avg_loss)\n",
"\n",
"        avg_metrics = self._aggregate_metrics(all_metrics)\n",
"        self.val_metrics.append(avg_metrics)\n",
"\n",
"        return avg_loss, avg_metrics\n",
"\n",
"    def save_checkpoint(self, is_best: bool = False):\n",
"        \"\"\"Save model checkpoint.\n",
"\n",
"        Always overwrites <output_dir>/checkpoint_latest.pth; when is_best is\n",
"        True the same checkpoint is also written to checkpoint_best.pth.\n",
"        \"\"\"\n",
"        checkpoint = {\n",
"            \"epoch\": self.current_epoch,\n",
"            \"model_state_dict\": self.model.state_dict(),\n",
"            \"optimizer_state_dict\": self.optimizer.state_dict(),\n",
"            \"train_losses\": self.train_losses,\n",
"            \"val_losses\": self.val_losses,\n",
"            \"val_metrics\": self.val_metrics,\n",
"            \"best_val_loss\": self.best_val_loss,\n",
"            \"config\": self.config,\n",
"        }\n",
"\n",
"        if self.scheduler is not None:\n",
"            checkpoint[\"scheduler_state_dict\"] = self.scheduler.state_dict()\n",
"\n",
"        # Save latest checkpoint\n",
"        checkpoint_path = self.output_dir / \"checkpoint_latest.pth\"\n",
"        torch.save(checkpoint, checkpoint_path)\n",
"\n",
"        # Save best checkpoint\n",
"        if is_best:\n",
"            best_path = self.output_dir / \"checkpoint_best.pth\"\n",
"            torch.save(checkpoint, best_path)\n",
"            print(f\"Best model saved to {best_path}\")\n",
"\n",
"    def load_checkpoint(self, checkpoint_path: str):\n",
"        \"\"\"Load a checkpoint written by save_checkpoint() and restore state.\n",
"\n",
"        Restores model, optimizer, scheduler (if present) and history, and\n",
"        sets start_epoch so the next train() call resumes after the saved\n",
"        epoch.\n",
"        \"\"\"\n",
"        # NOTE: torch.load unpickles the file; only load trusted checkpoints.\n",
"        checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
"\n",
"        self.model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"        self.optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
"\n",
"        if self.scheduler is not None and \"scheduler_state_dict\" in checkpoint:\n",
"            self.scheduler.load_state_dict(checkpoint[\"scheduler_state_dict\"])\n",
"\n",
"        self.current_epoch = checkpoint[\"epoch\"]\n",
"        self.start_epoch = self.current_epoch + 1\n",
"        self.train_losses = checkpoint[\"train_losses\"]\n",
"        self.val_losses = checkpoint[\"val_losses\"]\n",
"        self.val_metrics = checkpoint[\"val_metrics\"]\n",
"        self.best_val_loss = checkpoint[\"best_val_loss\"]\n",
"\n",
"        print(f\"Loaded checkpoint from epoch {self.current_epoch}\")\n",
"\n",
"    def _step_scheduler(self, val_loss: float):\n",
"        \"\"\"Advance the LR scheduler once per epoch (plateau uses val_loss).\"\"\"\n",
"        if self.scheduler is None:\n",
"            return\n",
"        if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):\n",
"            self.scheduler.step(val_loss)\n",
"        else:\n",
"            self.scheduler.step()\n",
"\n",
"    def _log_epoch(\n",
"        self,\n",
"        epoch: int,\n",
"        num_epochs: int,\n",
"        train_loss: float,\n",
"        val_loss: float,\n",
"        val_metrics: Dict,\n",
"    ):\n",
"        \"\"\"Write epoch scalars to TensorBoard and print a console summary.\"\"\"\n",
"        self.writer.add_scalar(\"train/epoch_loss\", train_loss, epoch)\n",
"        self.writer.add_scalar(\"val/epoch_loss\", val_loss, epoch)\n",
"        for metric_name, metric_value in val_metrics.items():\n",
"            self.writer.add_scalar(f\"val/{metric_name}\", metric_value, epoch)\n",
"\n",
"        print(f\"\\nEpoch {epoch + 1}/{num_epochs}:\")\n",
"        print(f\" Train Loss: {train_loss:.6f}\")\n",
"        print(f\" Val Loss: {val_loss:.6f}\")\n",
"        print(\" Val Metrics:\")\n",
"        for metric_name, metric_value in val_metrics.items():\n",
"            print(f\" {metric_name}: {metric_value:.6f}\")\n",
"\n",
"    def _should_stop_early(self) -> bool:\n",
"        \"\"\"Return True once val loss has not improved for the configured patience.\n",
"\n",
"        Measured as epochs since the best entry of self.val_losses, so it stays\n",
"        correct when the history was restored from a checkpoint. A patience\n",
"        of 0 (the default) disables early stopping.\n",
"        \"\"\"\n",
"        patience = self.config.get(\"early_stopping_patience\", 0)\n",
"        if patience <= 0 or not self.val_losses:\n",
"            return False\n",
"        epochs_since_best = len(self.val_losses) - 1 - int(np.argmin(self.val_losses))\n",
"        return epochs_since_best >= patience\n",
"\n",
"    def train(self, num_epochs: int):\n",
"        \"\"\"\n",
"        Train the model up to num_epochs total epochs.\n",
"\n",
"        A fresh trainer runs epochs 0 .. num_epochs - 1. After\n",
"        load_checkpoint() training resumes at the epoch following the saved\n",
"        one, so num_epochs is the total epoch budget, not a count of\n",
"        additional epochs.\n",
"\n",
"        Args:\n",
"            num_epochs: Total number of epochs to train\n",
"        \"\"\"\n",
"        print(f\"Starting training for {num_epochs} epochs...\")\n",
"        start_time = time.time()\n",
"\n",
"        try:\n",
"            for epoch in range(self.start_epoch, num_epochs):\n",
"                self.current_epoch = epoch\n",
"\n",
"                train_loss = self.train_epoch()\n",
"                val_loss, val_metrics = self.validate()\n",
"                self._step_scheduler(val_loss)\n",
"                self._log_epoch(epoch, num_epochs, train_loss, val_loss, val_metrics)\n",
"\n",
"                # Save checkpoint; \"best\" is judged on validation loss\n",
"                is_best = val_loss < self.best_val_loss\n",
"                if is_best:\n",
"                    self.best_val_loss = val_loss\n",
"                self.save_checkpoint(is_best=is_best)\n",
"\n",
"                if self._should_stop_early():\n",
"                    print(f\"Early stopping at epoch {epoch + 1}\")\n",
"                    break\n",
"        finally:\n",
"            # Flush TensorBoard events even if training is interrupted\n",
"            self.writer.close()\n",
"\n",
"        # Training completed\n",
"        training_time = time.time() - start_time\n",
"        print(f\"\\nTraining completed in {training_time:.2f} seconds\")\n",
"        print(f\"Best validation loss: {self.best_val_loss:.6f}\")\n",
"\n",
"        # Save final model\n",
"        final_model_path = self.output_dir / \"model_final.pth\"\n",
"        torch.save(self.model.state_dict(), final_model_path)\n",
"        print(f\"Final model saved to {final_model_path}\")\n",
"\n",
"    def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:\n",
"        \"\"\"\n",
"        Evaluate the model on test data.\n",
"\n",
"        Args:\n",
"            test_loader: Test data loader (uses validation loader if None)\n",
"\n",
"        Returns:\n",
"            Dictionary of evaluation metrics, also written to\n",
"            <output_dir>/evaluation_results.json\n",
"\n",
"        Raises:\n",
"            ValueError: If the loader has no batches.\n",
"        \"\"\"\n",
"        if test_loader is None:\n",
"            test_loader = self.val_loader\n",
"\n",
"        self.model.eval()\n",
"        all_metrics = []\n",
"\n",
"        print(\"Evaluating model...\")\n",
"        with torch.no_grad():\n",
"            for batch in tqdm(test_loader):\n",
"                google_img, yandex_img, target_homography = self._move_batch(batch)\n",
"                pred_homography = self.model(google_img, yandex_img, return_matrix=True)\n",
"                all_metrics.append(\n",
"                    self.criterion.compute_metrics(pred_homography, target_homography)\n",
"                )\n",
"\n",
"        if not all_metrics:\n",
"            raise ValueError(\"Evaluation loader is empty; nothing to evaluate.\")\n",
"\n",
"        avg_metrics = self._aggregate_metrics(all_metrics)\n",
"\n",
"        # Print evaluation results\n",
"        print(\"\\nEvaluation Results:\")\n",
"        for metric_name, metric_value in avg_metrics.items():\n",
"            print(f\" {metric_name}: {metric_value:.6f}\")\n",
"\n",
"        # Save evaluation results\n",
"        eval_path = self.output_dir / \"evaluation_results.json\"\n",
"        with open(eval_path, \"w\") as f:\n",
"            json.dump(avg_metrics, f, indent=2)\n",
"        print(f\"Evaluation results saved to {eval_path}\")\n",
"\n",
"        return avg_metrics\n",
|
|
"\n",
|
|
"\n",
|
|
"import os\n",
|
|
"from types import SimpleNamespace\n",
|
|
"\n",
|
|
"# Default parameter values for this run.\n",
|
|
"# data_dir can be overridden with the HOMOGRAPHY_DATA_DIR environment variable,\n",
|
|
"# so the notebook can run on another machine without editing this cell.\n",
|
|
"args = SimpleNamespace(\n",
|
|
"    # Data arguments\n",
|
|
"    data_dir=(\n",
|
|
"        os.environ.get(\"HOMOGRAPHY_DATA_DIR\")\n",
|
|
"        or r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\"\n",
|
|
"    ),\n",
|
|
"    batch_size=32,\n",
|
|
"    image_size=[256, 256],\n",
|
|
"    train_split=0.8,\n",
|
|
"    num_workers=1,\n",
|
|
"\n",
|
|
"    # Model arguments\n",
|
|
"    model_type=\"cnn\",\n",
|
|
"    hidden_channels=64,\n",
|
|
"    num_blocks=4,\n",
|
|
"    dropout_rate=0.3,\n",
|
|
"    use_batch_norm=False,\n",
|
|
"\n",
|
|
"    # Training arguments\n",
|
|
"    epochs=100,\n",
|
|
"    lr=1e-3,\n",
|
|
"    weight_decay=1e-4,\n",
|
|
"    optimizer=\"adam\",\n",
|
|
"    scheduler=\"plateau\",\n",
|
|
"    grad_clip=1.0,\n",
|
|
"\n",
|
|
"    # Loss arguments\n",
|
|
"    matrix_weight=1.0,\n",
|
|
"    geometric_weight=0.5,\n",
|
|
"    reg_weight=0.1,\n",
|
|
"\n",
|
|
"    # Other arguments\n",
|
|
"    output_dir=\"runs/homography\",\n",
|
|
"    resume=None,\n",
|
|
"    eval_only=False,\n",
|
|
"    seed=42,\n",
|
|
")\n",
|
|
"\n",
|
|
"\n",
|
|
"# Set random seeds for reproducibility. Python's `random` is seeded as well,\n",
|
|
"# since the dataset cell imports it alongside NumPy and PyTorch.\n",
|
|
"import random\n",
|
|
"\n",
|
|
"random.seed(args.seed)\n",
|
|
"torch.manual_seed(args.seed)\n",
|
|
"np.random.seed(args.seed)\n",
|
|
"if torch.cuda.is_available():\n",
|
|
"    torch.cuda.manual_seed(args.seed)\n",
|
|
"    torch.cuda.manual_seed_all(args.seed)\n",
|
|
"\n",
|
|
"# Set device\n",
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"print(f\"Using device: {device}\")\n",
|
|
"\n",
|
|
"# Build the train/validation loaders; augmentation is off for both splits.\n",
|
|
"print(\"Creating data loaders...\")\n",
|
|
"train_loader, val_loader = create_data_loaders(\n",
|
|
"    root_dir=args.data_dir,\n",
|
|
"    image_size=tuple(args.image_size),\n",
|
|
"    batch_size=args.batch_size,\n",
|
|
"    num_workers=args.num_workers,\n",
|
|
"    train_split=args.train_split,\n",
|
|
"    augment_train=False,\n",
|
|
"    augment_val=False,\n",
|
|
")\n",
|
|
"\n",
|
|
"print(\"Train batches:\", len(train_loader))\n",
|
|
"print(\"Val batches:\", len(val_loader))\n",
|
|
"\n",
|
|
"# Build the homography network for 3-channel images of size args.image_size.\n",
|
|
"print(\"Creating model...\")\n",
|
|
"model = create_homography_model(\n",
|
|
"    model_type=args.model_type,\n",
|
|
"    input_channels=3,\n",
|
|
"    input_size=tuple(args.image_size),\n",
|
|
"    hidden_channels=args.hidden_channels,\n",
|
|
"    num_blocks=args.num_blocks,\n",
|
|
"    use_batch_norm=args.use_batch_norm,\n",
|
|
"    dropout_rate=args.dropout_rate,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Trainer configuration: every hyper-parameter, grouped by concern.\n",
|
|
"# dict(...) keeps insertion order, so keys appear in the order listed here.\n",
|
|
"config = dict(\n",
|
|
"    # Model config\n",
|
|
"    model_type=args.model_type,\n",
|
|
"    hidden_channels=args.hidden_channels,\n",
|
|
"    num_blocks=args.num_blocks,\n",
|
|
"    dropout_rate=args.dropout_rate,\n",
|
|
"    use_batch_norm=args.use_batch_norm,\n",
|
|
"    image_size=args.image_size,\n",
|
|
"    # Training config (args.lr is stored under the key learning_rate)\n",
|
|
"    epochs=args.epochs,\n",
|
|
"    batch_size=args.batch_size,\n",
|
|
"    learning_rate=args.lr,\n",
|
|
"    weight_decay=args.weight_decay,\n",
|
|
"    optimizer=args.optimizer,\n",
|
|
"    scheduler=args.scheduler,\n",
|
|
"    grad_clip=args.grad_clip,\n",
|
|
"    # Loss config (grid_size is fixed here; it is not exposed in args)\n",
|
|
"    matrix_weight=args.matrix_weight,\n",
|
|
"    geometric_weight=args.geometric_weight,\n",
|
|
"    reg_weight=args.reg_weight,\n",
|
|
"    grid_size=8,\n",
|
|
"    # Data config\n",
|
|
"    data_dir=args.data_dir,\n",
|
|
"    train_split=args.train_split,\n",
|
|
"    num_workers=args.num_workers,\n",
|
|
"    # Output config\n",
|
|
"    output_dir=args.output_dir,\n",
|
|
"    seed=args.seed,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Bundle model, loaders, device and configuration into the trainer\n",
|
|
"trainer = HomographyTrainer(\n",
|
|
"    model=model,\n",
|
|
"    train_loader=train_loader,\n",
|
|
"    val_loader=val_loader,\n",
|
|
"    device=device,\n",
|
|
"    config=config,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Resume from a checkpoint when one is specified\n",
|
|
"if args.resume:\n",
|
|
"    print(f\"Resuming from checkpoint: {args.resume}\")\n",
|
|
"    trainer.load_checkpoint(args.resume)\n",
|
|
"\n",
|
|
"# Train and then evaluate, unless only an evaluation was requested\n",
|
|
"if not args.eval_only:\n",
|
|
"    trainer.train(num_epochs=args.epochs)\n",
|
|
"\n",
|
|
"    print(\"\\nPerforming final evaluation...\")\n",
|
|
"    trainer.evaluate()\n",
|
|
"\n",
|
|
"    print(\"\\nTraining completed successfully!\")\n",
|
|
"    print(f\"Results saved to: {args.output_dir}\")\n",
|
|
"else:\n",
|
|
"    trainer.evaluate()\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d1cd4bb8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|