From fc072d798e24f6b709e7c19001848bc140017386 Mon Sep 17 00:00:00 2001 From: russian_proger Date: Tue, 14 Apr 2026 12:37:09 +0300 Subject: [PATCH] try custom cnn --- models/SiaN/notebook.gen.ipynb | 2092 ++++++++++++++++++-------------- models/SiaN/src/main.py | 18 +- models/SiaN/src/model.py | 123 ++ models/SiaN/src/train.py | 4 +- 4 files changed, 1349 insertions(+), 888 deletions(-) diff --git a/models/SiaN/notebook.gen.ipynb b/models/SiaN/notebook.gen.ipynb index 3895176..2554560 100644 --- a/models/SiaN/notebook.gen.ipynb +++ b/models/SiaN/notebook.gen.ipynb @@ -1,880 +1,1212 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import random\n", - "import logging\n", - "from typing import Tuple\n", - "import cv2\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "import matplotlib.pyplot as plt\n", - "from PIL import Image\n", - "from torch.utils.data import DataLoader, Dataset, Subset\n", - "from torch.utils.tensorboard import SummaryWriter\n", - "from torchvision import transforms, models\n", - "from tqdm import tqdm\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "config = {\n", - " \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", - " \"image_size\": (256, 256),\n", - " \"batch_size\": 32,\n", - " \"train_split\": 0.8,\n", - " \"num_workers\": 0,\n", - " \"epochs\": 10,\n", - " \"learning_rate\": 2e-4,\n", - " \"dropout_rate\": 0.5,\n", - " \"backbone\": \"resnet18\",\n", - " \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n", - " \"save_every_n_epochs\": 15,\n", - "}\n", - "\n", - "\n", - "def get_camera_matrix(w, h):\n", - " return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]], dtype=np.float32)\n", - "\n", - "\n", - "def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n", - " scale = np.random.uniform(*scale_range)\n", - " tx = np.random.uniform(-translation_range, translation_range)\n", - " ty = np.random.uniform(-translation_range, translation_range)\n", - " rx = np.radians(np.random.uniform(-angle_range, angle_range))\n", - " ry = np.radians(np.random.uniform(-angle_range, angle_range))\n", - " rz = np.radians(np.random.uniform(-angle_range, angle_range))\n", - " return np.array([tx, ty, rx, ry, rz, scale])\n", - "\n", - "\n", - "def homography_params_to_matrix(params, K):\n", - " tx, ty, rx, ry, rz, scale = params\n", - " cy, sy = np.cos(rz), np.sin(rz)\n", - " cp, sp = np.cos(ry), np.sin(ry)\n", - " cr, sr = np.cos(rx), np.sin(rx)\n", - " Rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)\n", - " Ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)\n", - " Rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)\n", - " T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]], dtype=np.float32)\n", - " return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)\n", - "\n", - "\n", - "def matrix_to_homography_params(H, K):\n", - " if hasattr(H, 'numpy'):\n", - " H = H.numpy()\n", - " K_inv = np.linalg.inv(K)\n", - " E = K_inv @ H @ K\n", - " scale = E[2, 2]\n", - " R_normalized = E / scale\n", - " rz = np.arctan2(R_normalized[1, 0], R_normalized[0, 0])\n", - " ry = np.arctan2(-R_normalized[2, 0], np.sqrt(R_normalized[2, 1]**2 + R_normalized[2, 2]**2))\n", - " rx = np.arctan2(R_normalized[2, 1], R_normalized[2, 2])\n", - " A = R_normalized[:2, :2]\n", - " correction = scale * np.array([R_normalized[0, 2], R_normalized[1, 2]])\n", - " tx, ty = np.linalg.solve(A, E[:2, 2] - correction)\n", - " return np.array([tx, ty, rx, ry, rz, scale], dtype=np.float32)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "\n", - "\n", - "class YaGoDataset(Dataset):\n", - " def __init__(self, root_dir: str, transform=None, augment: bool = True, \n", - " image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):\n", - " self.root_dir = root_dir\n", - " self.transform = transform\n", - " self.augment = augment\n", - " self.image_size = image_size\n", - " self.cache_level = cache_level\n", - " self.K = get_camera_matrix(image_size[1], image_size[0])\n", - " self.image_pairs = self._discover_image_pairs()\n", - " self._load_images_to_memory()\n", - " self._init_cache()\n", - "\n", - " def _discover_image_pairs(self):\n", - " pairs = []\n", - " for f in os.listdir(self.root_dir):\n", - " if f.endswith(\"_google.png\"):\n", - " idx = f.split(\"_\")[0]\n", - " yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n", - " if os.path.exists(yandex_path):\n", - " pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n", - " return sorted(pairs, key=lambda x: x[\"idx\"])\n", - "\n", - " def _load_images_to_memory(self):\n", - " self._google_images = []\n", - " self._yandex_images = []\n", - " for pair in self.image_pairs:\n", - " google_img = cv2.imread(pair[\"google\"])\n", - " google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)\n", - " google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", - " \n", - " yandex_img = cv2.imread(pair[\"yandex\"])\n", - " yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)\n", - " yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", - " \n", - " self._google_images.append(google_img)\n", - " self._yandex_images.append(yandex_img)\n", - "\n", - " def _init_cache(self):\n", - " self._access_counts = [0] * len(self.image_pairs)\n", - " self._cached_google = [None] * len(self.image_pairs)\n", - " self._cached_yandex = [None] * len(self.image_pairs)\n", - " self._cached_homography = [None] * len(self.image_pairs)\n", - " self._cached_params = [None] * len(self.image_pairs)\n", - "\n", - " def _generate_augmented(self, idx):\n", - " google_img = self._google_images[idx].copy()\n", - " yandex_img = self._yandex_images[idx].copy()\n", - "\n", - " params1 = generate_random_homography_params()\n", - " params2 = generate_random_homography_params()\n", - " H1 = homography_params_to_matrix(params1, self.K)\n", - " H2 = homography_params_to_matrix(params2, self.K)\n", - " \n", - " yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))\n", - " google_warped = cv2.warpPerspective(google_img, H2 @ H1, (self.image_size[1], self.image_size[0]))\n", - " \n", - " return google_warped, yandex_warped, H2, params2\n", - "\n", - " def __len__(self):\n", - " return len(self.image_pairs)\n", - "\n", - " def __getitem__(self, idx):\n", - " self._access_counts[idx] += 1\n", - " \n", - " use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0\n", - " \n", - " if use_cache:\n", - " google_img = self._cached_google[idx]\n", - " yandex_img = self._cached_yandex[idx]\n", - " target_matrix = self._cached_homography[idx]\n", - " target_params = self._cached_params[idx]\n", - " elif self.augment:\n", - " google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n", - " if self.cache_level > 0:\n", - " self._cached_google[idx] = google_img\n", - " self._cached_yandex[idx] = yandex_img\n", - " self._cached_homography[idx] = target_matrix\n", - " self._cached_params[idx] = target_params\n", - " else:\n", - " google_img = self._google_images[idx]\n", - " yandex_img = self._yandex_images[idx]\n", - " target_params = np.array([0, 0, 0, 0, 0, 1], dtype=np.float32)\n", - " target_matrix = np.eye(3, dtype=np.float32)\n", - "\n", - " google_img = Image.fromarray(google_img)\n", - " yandex_img = Image.fromarray(yandex_img)\n", - "\n", - " if self.transform:\n", - " google_img = self.transform(google_img)\n", - " yandex_img = self.transform(yandex_img)\n", - "\n", - " return {\n", - " \"google_img\": google_img,\n", - " \"yandex_img\": yandex_img,\n", - " \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n", - " \"homography_params\": torch.from_numpy(target_params).float(),\n", - " }\n", - "\n", - "\n", - "def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0, \n", - " image_size=(256, 256), augment_train=True, cache_level=5):\n", - " transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", - " ])\n", - " \n", - " full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)\n", - " aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)\n", - "\n", - " indices = list(range(len(full_ds)))\n", - " random.shuffle(indices)\n", - " split = int(train_split * len(indices))\n", - " \n", - " train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])\n", - " val_ds = Subset(aug_ds, indices[split:])\n", - "\n", - " return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),\n", - " DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))\n", - "\n", - "\n", - "def get_dataset_info():\n", - " ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n", - " return {\n", - " \"size\": len(ds),\n", - " \"sample_keys\": list(ds[0].keys()),\n", - " \"sample_params\": ds[0][\"homography_params\"].numpy()\n", - " }\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "train_loader, val_loader = create_data_loaders(config['data_dir'])\n", - "batch = next(iter(train_loader))\n", - "google_img = batch['google_img'][0]\n", - "yandex_img = batch['yandex_img'][0]\n", - "\n", - "# google_img.permute((1, 2, 0)) * 255\n", - "batch['homography_params'].mean(axis=0)\n", - "\n", - "print(batch['homography_matrix'][0])\n", - "print(batch['homography_params'][0])\n", - "K = get_camera_matrix(config['image_size'][0], config['image_size'][1])\n", - "print(homography_params_to_matrix(batch['homography_params'][0], K))\n", - "print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K))\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "def angular_difference(pred_angles, target_angles):\n", - " diff = pred_angles - target_angles\n", - " diff = torch.atan2(torch.sin(diff), torch.cos(diff))\n", - " return torch.abs(diff)\n", - "\n", - "\n", - "class HomographyCNN6(nn.Module):\n", - " def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n", - " super().__init__()\n", - " backbone = getattr(models, backbone_name)(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n", - " self.feature_dim = backbone.fc.in_features\n", - " backbone.fc = nn.Identity()\n", - " self.backbone = backbone\n", - "\n", - " self.head = nn.Sequential(\n", - " nn.Linear(self.feature_dim * 4, 1024),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(dropout_rate),\n", - " nn.Linear(1024, 512),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(dropout_rate),\n", - " nn.Linear(512, 256),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(dropout_rate),\n", - " nn.Linear(256, 6),\n", - " )\n", - "\n", - " def _normalize_sin_cos(self, _sin, _cos):\n", - " _len = torch.sqrt(_sin ** 2 + _cos ** 2)\n", - " return _sin / _len, _cos / _len\n", - "\n", - " def forward(self, img1, img2):\n", - " f1 = self.backbone(img1)\n", - " f2 = self.backbone(img2)\n", - " combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n", - "\n", - " output = self.head(combined)\n", - "\n", - " output = torch.tanh(output) # [-1; 1]\n", - " modified = output.clone()\n", - " modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi]\n", - "\n", - " return modified\n", - "\n", - " def decode_output(self, output):\n", - " tx = output[:, 0]\n", - " ty = output[:, 1]\n", - " scale = output[:, 5]\n", - " angle1 = output[:, 2]\n", - " angle2 = output[:, 3]\n", - " angle3 = output[:, 4]\n", - "\n", - " return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n", - "\n", - " def get_components(self, output):\n", - " decoded = self.decode_output(output)\n", - " return {\n", - " \"tx\": decoded[:, 0],\n", - " \"ty\": decoded[:, 1],\n", - " \"rx\": decoded[:, 2],\n", - " \"ry\": decoded[:, 3],\n", - " \"rz\": decoded[:, 4],\n", - " \"scale\": decoded[:, 5],\n", - " }\n", - "\n", - "\n", - "class HomographyLoss6(nn.Module):\n", - " def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):\n", - " super().__init__()\n", - " self.criterion = nn.MSELoss()\n", - " self.angle_loss_weight = angle_loss_weight\n", - " self.trans_loss_weight = trans_loss_weight\n", - " self.scale_loss_weight = scale_loss_weight\n", - "\n", - " @staticmethod\n", - " def dot_angles(src, dest):\n", - " sin_src = torch.sin(src)\n", - " cos_src = torch.cos(src)\n", - " sin_dest = torch.sin(dest)\n", - " cos_dest = torch.cos(dest)\n", - " return sin_src * sin_dest + cos_src * cos_dest\n", - "\n", - " def forward(self, pred, target):\n", - " tx_loss = self.criterion(pred[:, 0], target[:, 0])\n", - " ty_loss = self.criterion(pred[:, 1], target[:, 1])\n", - "\n", - " dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])\n", - " dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])\n", - " dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])\n", - "\n", - " rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))\n", - " ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))\n", - " rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))\n", - "\n", - " scale_loss = self.criterion(pred[:, 5], target[:, 5])\n", - "\n", - " total_loss = (\n", - " self.trans_loss_weight * (tx_loss + ty_loss) +\n", - " self.angle_loss_weight * (rx_loss + ry_loss + rz_loss) +\n", - " self.scale_loss_weight * scale_loss\n", - " )\n", - "\n", - " return total_loss\n", - "\n", - " def compute_mse_components(self, decoded, target):\n", - " tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n", - " ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n", - " \n", - " dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])\n", - " dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])\n", - " dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])\n", - "\n", - " rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()\n", - " ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()\n", - " rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()\n", - "\n", - " scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n", - "\n", - " avg_angle_loss = (rx_mse + ry_mse + rz_mse) / 3\n", - "\n", - " return {\n", - " 'trans': (tx_mse + ty_mse) / 2,\n", - " 'angle': avg_angle_loss,\n", - " 'scale': scale_mse\n", - " }\n", - "\n", - "\n", - "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters())\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "\n", - "\n", - "class HomographyTrainer:\n", - " def __init__(self, model, train_loader, val_loader, device):\n", - " self.model = model.to(device)\n", - " self.train_loader = train_loader\n", - " self.val_loader = val_loader\n", - " self.device = device\n", - " self.criterion = HomographyLoss6()\n", - " self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"], weight_decay=1e-4)\n", - " self.writer = None\n", - " self.best_val_loss = float(\"inf\")\n", - " self.train_losses = []\n", - " self.val_losses = []\n", - " self.train_mse_trans = []\n", - " self.train_mse_angle = []\n", - " self.train_mse_scale = []\n", - " self.val_mse_trans = []\n", - " self.val_mse_angle = []\n", - " self.val_mse_scale = []\n", - "\n", - " def train_epoch(self, epoch):\n", - " self.model.train()\n", - " total_loss, total_samples = 0, 0\n", - " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", - " pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n", - " for batch_idx, batch in enumerate(pbar):\n", - " google_img = batch[\"google_img\"].to(self.device)\n", - " yandex_img = batch[\"yandex_img\"].to(self.device)\n", - " target = batch[\"homography_params\"].to(self.device)\n", - "\n", - " self.optimizer.zero_grad()\n", - " output = self.model(google_img, yandex_img)\n", - " loss = self.criterion(output, target)\n", - " loss.backward()\n", - " self.optimizer.step()\n", - "\n", - " total_loss += loss.item() * google_img.size(0)\n", - " total_samples += google_img.size(0)\n", - " \n", - " decoded_output = self.model.decode_output(output)\n", - " mse_components = self.criterion.compute_mse_components(decoded_output, target)\n", - " mse_trans_sum += mse_components['trans'] * google_img.size(0)\n", - " mse_angle_sum += mse_components['angle'] * google_img.size(0)\n", - " mse_scale_sum += mse_components['scale'] * google_img.size(0)\n", - " \n", - " pbar.set_postfix({\"loss\": loss.item()})\n", - "\n", - " self.train_mse_trans.append(mse_trans_sum / total_samples)\n", - " self.train_mse_angle.append(mse_angle_sum / total_samples)\n", - " self.train_mse_scale.append(mse_scale_sum / total_samples)\n", - " \n", - " return {\"loss\": total_loss / total_samples}\n", - "\n", - " def validate(self):\n", - " self.model.eval()\n", - " total_loss, total_samples = 0, 0\n", - " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", - " with torch.no_grad():\n", - " for batch in tqdm(self.val_loader, desc=\"Validation\"):\n", - " google_img = batch[\"google_img\"].to(self.device)\n", - " yandex_img = batch[\"yandex_img\"].to(self.device)\n", - " target = batch[\"homography_params\"].to(self.device)\n", - " output = self.model(google_img, yandex_img)\n", - " decoded_output = self.model.decode_output(output)\n", - " loss = self.criterion(output, target)\n", - " total_loss += loss.item() * google_img.size(0)\n", - " total_samples += google_img.size(0)\n", - " \n", - " mse_components = self.criterion.compute_mse_components(decoded_output, target)\n", - " mse_trans_sum += mse_components['trans'] * google_img.size(0)\n", - " mse_angle_sum += mse_components['angle'] * google_img.size(0)\n", - " mse_scale_sum += mse_components['scale'] * google_img.size(0)\n", - " \n", - " self.val_mse_trans.append(mse_trans_sum / total_samples)\n", - " self.val_mse_angle.append(mse_angle_sum / total_samples)\n", - " self.val_mse_scale.append(mse_scale_sum / total_samples)\n", - " \n", - " return {\"loss\": total_loss / total_samples}\n", - "\n", - " def train(self, num_epochs):\n", - " log_dir = config[\"output_dir\"]\n", - " os.makedirs(log_dir, exist_ok=True)\n", - " self.writer = SummaryWriter(log_dir)\n", - "\n", - " for epoch in range(1, num_epochs + 1):\n", - " train_metrics = self.train_epoch(epoch)\n", - " val_metrics = self.validate()\n", - " self.train_losses.append(train_metrics[\"loss\"])\n", - " self.val_losses.append(val_metrics[\"loss\"])\n", - " print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n", - " print(f\" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}\")\n", - "\n", - " if val_metrics[\"loss\"] < self.best_val_loss:\n", - " self.best_val_loss = val_metrics[\"loss\"]\n", - " self.save_checkpoint(epoch, is_best=True)\n", - " print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n", - "\n", - " if epoch % config[\"save_every_n_epochs\"] == 0:\n", - " self.save_checkpoint(epoch, is_best=False)\n", - " print(f\"Checkpoint saved at epoch {epoch}\")\n", - "\n", - " self.writer.close()\n", - "\n", - " def save_checkpoint(self, epoch, is_best=False):\n", - " ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n", - " os.makedirs(ckpt_dir, exist_ok=True)\n", - " ckpt = {\"epoch\": epoch, \"model_state_dict\": self.model.state_dict(), \"val_loss\": self.best_val_loss}\n", - " torch.save(ckpt, os.path.join(ckpt_dir, f\"checkpoint_epoch_{epoch}.pt\"))\n", - " if is_best:\n", - " torch.save(ckpt, os.path.join(ckpt_dir, \"best_model.pt\"))\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n", - "os.makedirs(IMG_DIR, exist_ok=True)\n", - "\n", - "\n", - "def analyze_training(trainer):\n", - " print(\"=== Training Analysis ===\\n\")\n", - "\n", - " if trainer.writer:\n", - " print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n", - "\n", - " print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n", - "\n", - " best_model_path = os.path.join(config[\"output_dir\"], \"checkpoints\", \"best_model.pt\")\n", - " if os.path.exists(best_model_path):\n", - " checkpoint = torch.load(best_model_path, map_location=trainer.device)\n", - " trainer.model.load_state_dict(checkpoint[\"model_state_dict\"])\n", - " print(f\"\\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})\")\n", - " \n", - " trainer.model.eval()\n", - " \n", - " n_samples = 50\n", - " names = [\"tx\", \"ty\", \"rx\", \"ry\", \"rz\", \"scale\"]\n", - " \n", - " _, val_loader_for_analysis = create_data_loaders(\n", - " root_dir=config[\"data_dir\"],\n", - " batch_size=config[\"batch_size\"],\n", - " train_split=config[\"train_split\"],\n", - " num_workers=config[\"num_workers\"],\n", - " image_size=config[\"image_size\"],\n", - " augment_train=True,\n", - " cache_level=0,\n", - " )\n", - " \n", - " with torch.no_grad():\n", - " all_errors = [[] for _ in range(6)]\n", - " all_targets = [[] for _ in range(6)]\n", - " all_preds = [[] for _ in range(6)]\n", - " \n", - " sample_count = 0\n", - " for batch in val_loader_for_analysis:\n", - " if sample_count >= n_samples:\n", - " break\n", - " \n", - " google_img = batch[\"google_img\"].to(trainer.device)\n", - " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", - " target_params = batch[\"homography_params\"].to(trainer.device)\n", - " pred_params = trainer.model(google_img, yandex_img)\n", - " decoded_pred = trainer.model.decode_output(pred_params)\n", - " \n", - " batch_size = google_img.size(0)\n", - " for i in range(batch_size):\n", - " if sample_count >= n_samples:\n", - " break\n", - " \n", - " tx_error = torch.abs(decoded_pred[i, 0] - target_params[i, 0]).item()\n", - " ty_error = torch.abs(decoded_pred[i, 1] - target_params[i, 1]).item()\n", - " rx_error = angular_difference(decoded_pred[i, 2], target_params[i, 2]).item()\n", - " ry_error = angular_difference(decoded_pred[i, 3], target_params[i, 3]).item()\n", - " rz_error = angular_difference(decoded_pred[i, 4], target_params[i, 4]).item()\n", - " scale_error = torch.abs(decoded_pred[i, 5] - target_params[i, 5]).item()\n", - " \n", - " errors = [tx_error, ty_error, rx_error, ry_error, rz_error, scale_error]\n", - " target_reordered = target_params[i].cpu().numpy()\n", - " pred_reordered = decoded_pred[i].cpu().numpy()\n", - " \n", - " for j in range(6):\n", - " all_errors[j].append(errors[j])\n", - " all_targets[j].append(target_reordered[j])\n", - " all_preds[j].append(pred_reordered[j])\n", - " \n", - " sample_count += 1\n", - " \n", - " mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n", - " std_errors = [np.std(all_errors[i]) for i in range(6)]\n", - " \n", - " if len(trainer.train_losses) > 0:\n", - " epochs = range(1, len(trainer.train_losses) + 1)\n", - " fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", - " \n", - " axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n", - " axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", - " axes[0, 0].set_xlabel(\"Epoch\")\n", - " axes[0, 0].set_ylabel(\"Loss\")\n", - " axes[0, 0].set_title(\"Training & Validation Loss\")\n", - " axes[0, 0].legend()\n", - " axes[0, 0].grid(True, alpha=0.3)\n", - " \n", - " axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", - " axes[0, 1].set_xlabel(\"Epoch\")\n", - " axes[0, 1].set_ylabel(\"Loss\")\n", - " axes[0, 1].set_title(\"Validation Loss\")\n", - " axes[0, 1].legend()\n", - " axes[0, 1].grid(True, alpha=0.3)\n", - " \n", - " axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n", - " axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n", - " axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n", - " axes[1, 0].set_xlabel(\"Epoch\")\n", - " axes[1, 0].set_ylabel(\"MSE\")\n", - " axes[1, 0].set_title(\"Validation MSE by Category\")\n", - " axes[1, 0].legend()\n", - " axes[1, 0].grid(True, alpha=0.3)\n", - " \n", - " x_pos = np.arange(6)\n", - " axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", - " axes[1, 1].set_xticks(x_pos)\n", - " axes[1, 1].set_xticklabels(names)\n", - " axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n", - " axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\")\n", - " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150)\n", - " print(\"Saved training_loss_plots.png\")\n", - " plt.show()\n", - " \n", - " fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", - " for j in range(6):\n", - " row = j // 3\n", - " col = j % 3\n", - " axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color=\"steelblue\", alpha=0.7)\n", - " axes[row, col].set_xlabel(\"Sample\")\n", - " axes[row, col].set_ylabel(\"Absolute Error\")\n", - " axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\")\n", - " axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n", - " plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14)\n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150)\n", - " print(\"Saved mae_per_parameter.png\")\n", - " plt.show()\n", - " \n", - " fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n", - " \n", - " x_pos = np.arange(6)\n", - " axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", - " axes[0].set_xticks(x_pos)\n", - " axes[0].set_xticklabels(names)\n", - " axes[0].set_ylabel(\"Mean Absolute Error\")\n", - " axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\")\n", - " axes[0].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n", - " colors = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n", - " for patch, color in zip(bp[\"boxes\"], colors):\n", - " patch.set_facecolor(color)\n", - " patch.set_alpha(0.7)\n", - " axes[1].set_ylabel(\"Absolute Error\")\n", - " axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\")\n", - " axes[1].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150)\n", - " print(\"Saved mae_boxplot.png\")\n", - " plt.show()\n", - " \n", - " print(\"\\n=== Sample Predictions (20 pairs) ===\")\n", - " n_vis_samples = 20\n", - " \n", - " with torch.no_grad():\n", - " vis_count = 0\n", - " for batch in val_loader_for_analysis:\n", - " if vis_count >= n_vis_samples:\n", - " break\n", - " batch_size = batch[\"google_img\"].size(0)\n", - " \n", - " for i in range(batch_size):\n", - " if vis_count >= n_vis_samples:\n", - " break\n", - " \n", - " google_img = batch[\"google_img\"][i:i+1].to(trainer.device)\n", - " yandex_img = batch[\"yandex_img\"][i:i+1].to(trainer.device)\n", - " target_params = batch[\"homography_params\"][i:i+1].to(trainer.device)\n", - " pred_params = trainer.model(google_img, yandex_img)\n", - " decoded_pred = trainer.model.decode_output(pred_params)\n", - " \n", - " tx_error = torch.abs(decoded_pred[0, 0] - target_params[0, 0]).item()\n", - " ty_error = torch.abs(decoded_pred[0, 1] - target_params[0, 1]).item()\n", - " rx_error = angular_difference(decoded_pred[0, 2], target_params[0, 2]).item()\n", - " ry_error = angular_difference(decoded_pred[0, 3], target_params[0, 3]).item()\n", - " rz_error = angular_difference(decoded_pred[0, 4], target_params[0, 4]).item()\n", - " scale_error = torch.abs(decoded_pred[0, 5] - target_params[0, 5]).item()\n", - " \n", - " errors = np.array([tx_error, ty_error, rx_error, ry_error, rz_error, scale_error])\n", - " targets = target_params[0].cpu().numpy()\n", - " preds = decoded_pred[0].cpu().numpy()\n", - " \n", - " fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n", - " \n", - " axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n", - " axes[0, 0].set_title(f\"Google Image\")\n", - " axes[0, 0].axis(\"off\")\n", - " \n", - " axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n", - " axes[0, 1].set_title(f\"Yandex Image\")\n", - " axes[0, 1].axis(\"off\")\n", - " \n", - " x_pos = np.arange(6)\n", - " width = 0.35\n", - " axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n", - " axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n", - " axes[1, 0].set_xticks(x_pos)\n", - " axes[1, 0].set_xticklabels(names)\n", - " axes[1, 0].set_ylabel(\"Parameter Value\")\n", - " axes[1, 0].set_title(\"Target vs Predicted\")\n", - " axes[1, 0].legend()\n", - " axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " axes[1, 1].bar(x_pos, errors, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", - " axes[1, 1].set_xticks(x_pos)\n", - " axes[1, 1].set_xticklabels(names)\n", - " axes[1, 1].set_ylabel(\"Absolute Error\")\n", - " axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\")\n", - " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", - " for i_e, e in enumerate(errors):\n", - " axes[1, 1].text(i_e, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n", - " \n", - " plt.suptitle(f\"Sample {vis_count + 1}\", fontsize=14)\n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{vis_count + 1:02d}.png\"), dpi=100)\n", - " plt.show()\n", - " print(f\"Saved prediction_sample_{vis_count + 1:02d}.png\")\n", - " \n", - " vis_count += 1\n", - " \n", - " print(f\"\\nPrediction errors over {n_samples} samples:\")\n", - " print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n", - " print(\"-\" * 52)\n", - " for i in range(6):\n", - " mean_err = np.mean(all_errors[i])\n", - " std_err = np.std(all_errors[i])\n", - " min_err = np.min(all_errors[i])\n", - " max_err = np.max(all_errors[i])\n", - " print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}\")\n", - "\n", - " return {\n", - " \"best_val_loss\": trainer.best_val_loss,\n", - " \"train_losses\": trainer.train_losses,\n", - " \"val_losses\": trainer.val_losses,\n", - " \"val_mse_trans\": trainer.val_mse_trans,\n", - " \"val_mse_angle\": trainer.val_mse_angle,\n", - " \"val_mse_scale\": trainer.val_mse_scale,\n", - " }\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "\n", - "\n", - "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n", - "logger = logging.getLogger(__name__)\n", - "\n", - "logger.info(\"=\" * 50)\n", - "logger.info(\"SiaN Training Pipeline\")\n", - "logger.info(\"=\" * 50)\n", - "\n", - "dataset_info = get_dataset_info()\n", - "logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n", - "\n", - "train_loader, val_loader = create_data_loaders(\n", - " root_dir=config[\"data_dir\"],\n", - " batch_size=config[\"batch_size\"],\n", - " train_split=config[\"train_split\"],\n", - " num_workers=config[\"num_workers\"],\n", - " image_size=config[\"image_size\"],\n", - ")\n", - "logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n", - "\n", - "model = HomographyCNN6(\n", - " input_channels=3,\n", - " backbone_name=config[\"backbone\"],\n", - " pretrained=True,\n", - " dropout_rate=config[\"dropout_rate\"]\n", - ")\n", - "logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "logger.info(f\"Using device: {device}\")\n", - "\n", - "trainer = HomographyTrainer(model, train_loader, val_loader, device)\n", - "logger.info(\"Starting training...\")\n", - "trainer.train(config[\"epochs\"])\n", - "logger.info(\"Training completed\")\n", - "\n", - "logger.info(\"Analyzing model...\")\n", - "results = analyze_training(trainer)\n", - "logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n", - "\n", - "logger.info(\"=\" * 50)\n", - "logger.info(\"Pipeline completed successfully\")\n", - "logger.info(\"=\" * 50)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*\n", - "\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 35, + "id": "4230e9cd", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "import logging\n", + "from typing import Tuple\n", + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "import seaborn as sns\n", + "from torch.utils.data import DataLoader, Dataset, Subset\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from torchvision import transforms, models\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "markdown", + "id": "d41ef314", + "metadata": {}, + "source": [ + "# Configuration\n", + "\n", + "Global settings for:\n", + "- Data paths and image parameters\n", + "- Training hyperparameters\n", + "- Model architecture options\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "4463d9d3", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "config = {\n", + " \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", + " \"image_size\": (256, 256),\n", + " \"batch_size\": 32,\n", + " \"train_split\": 0.8,\n", + " \"num_workers\": 0,\n", + " \"epochs\": 10,\n", + " \"learning_rate\": 2e-4,\n", + " \"dropout_rate\": 0.5,\n", + " \"backbone\": \"resnet18\",\n", + " \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n", + " \"save_every_n_epochs\": 15,\n", + "}\n", + "\n", + "\n", + "def get_camera_matrix(w, h):\n", + " return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]], dtype=np.float32)\n", + "\n", + "\n", + "def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n", + " scale = np.random.uniform(*scale_range)\n", + " tx = np.random.uniform(-translation_range, translation_range)\n", + " ty = np.random.uniform(-translation_range, translation_range)\n", + " rx = np.radians(np.random.uniform(-angle_range, angle_range))\n", + " ry = np.radians(np.random.uniform(-angle_range, angle_range))\n", + " rz = np.radians(np.random.uniform(-angle_range, angle_range))\n", + " return np.array([tx, ty, rx, ry, rz, scale])\n", + "\n", + "\n", + "def homography_params_to_matrix(params, K):\n", + " tx, ty, rx, ry, rz, scale = params\n", + " cy, sy = np.cos(rz), np.sin(rz)\n", + " cp, sp = np.cos(ry), np.sin(ry)\n", + " cr, sr = np.cos(rx), np.sin(rx)\n", + " Rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)\n", + " Ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)\n", + " Rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)\n", + " T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]], dtype=np.float32)\n", + " return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)\n", + "\n", + "\n", + "def matrix_to_homography_params(H, K):\n", + " if hasattr(H, 'numpy'):\n", + " H = H.numpy()\n", + " K_inv = np.linalg.inv(K)\n", + " E = K_inv @ H @ K\n", + " scale = E[2, 2]\n", + " R_normalized = E / scale\n", + " rz = np.arctan2(R_normalized[1, 0], R_normalized[0, 0])\n", + " ry = np.arctan2(-R_normalized[2, 0], np.sqrt(R_normalized[2, 1]**2 + R_normalized[2, 2]**2))\n", + " rx = np.arctan2(R_normalized[2, 1], R_normalized[2, 2])\n", + " A = R_normalized[:2, :2]\n", + " correction = scale * np.array([R_normalized[0, 2], R_normalized[1, 2]])\n", + " tx, ty = np.linalg.solve(A, E[:2, 2] - correction)\n", + " return np.array([tx, ty, rx, ry, rz, scale], dtype=np.float32)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "e5a40be4", + "metadata": {}, + "source": [ + "## Dataset\n", + "\n", + "Google/Yandex image pair loader with homography augmentation.\n", + "\n", + "**Features:**\n", + "- Loads paired images from dual camera sources\n", + "- Applies random homography transformations\n", + "- Supports configurable train/val split\n", + "\n", + "**Returns:**\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "37358bfe", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "\n", + "\n", + "class YaGoDataset(Dataset):\n", + " def __init__(self, root_dir: str, transform=None, augment: bool = True, \n", + " image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):\n", + " self.root_dir = root_dir\n", + " self.transform = transform\n", + " self.augment = augment\n", + " self.image_size = image_size\n", + " self.cache_level = cache_level\n", + " self.K = get_camera_matrix(image_size[1], image_size[0])\n", + " self.image_pairs = self._discover_image_pairs()\n", + " self._load_images_to_memory()\n", + " self._init_cache()\n", + "\n", + " def _discover_image_pairs(self):\n", + " pairs = []\n", + " for f in os.listdir(self.root_dir):\n", + " if f.endswith(\"_google.png\"):\n", + " idx = f.split(\"_\")[0]\n", + " yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n", + " if os.path.exists(yandex_path):\n", + " pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n", + " return sorted(pairs, key=lambda x: x[\"idx\"])\n", + "\n", + " def _load_images_to_memory(self):\n", + " self._google_images = []\n", + " self._yandex_images = []\n", + " for pair in self.image_pairs:\n", + " google_img = cv2.imread(pair[\"google\"])\n", + " google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)\n", + " google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", + " \n", + " yandex_img = cv2.imread(pair[\"yandex\"])\n", + " yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)\n", + " yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", + " \n", + " self._google_images.append(google_img)\n", + " self._yandex_images.append(yandex_img)\n", + "\n", + " def _init_cache(self):\n", + " self._access_counts = [0] * len(self.image_pairs)\n", + " self._cached_google = [None] * len(self.image_pairs)\n", + " self._cached_yandex = [None] * len(self.image_pairs)\n", + " self._cached_homography = [None] * len(self.image_pairs)\n", + " self._cached_params = [None] * len(self.image_pairs)\n", + "\n", + " def _generate_augmented(self, idx):\n", + " google_img = self._google_images[idx].copy()\n", + " yandex_img = self._yandex_images[idx].copy()\n", + "\n", + " params1 = generate_random_homography_params()\n", + " params2 = generate_random_homography_params()\n", + " H1 = homography_params_to_matrix(params1, self.K)\n", + " H2 = homography_params_to_matrix(params2, self.K)\n", + " \n", + " yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))\n", + " google_warped = cv2.warpPerspective(google_img, H2 @ H1, (self.image_size[1], self.image_size[0]))\n", + " \n", + " return google_warped, yandex_warped, H2, params2\n", + "\n", + " def __len__(self):\n", + " return len(self.image_pairs)\n", + "\n", + " def __getitem__(self, idx):\n", + " self._access_counts[idx] += 1\n", + " \n", + " use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0\n", + " \n", + " if use_cache:\n", + " google_img = self._cached_google[idx]\n", + " yandex_img = self._cached_yandex[idx]\n", + " target_matrix = self._cached_homography[idx]\n", + " target_params = self._cached_params[idx]\n", + " elif self.augment:\n", + " google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n", + " if self.cache_level > 0:\n", + " self._cached_google[idx] = google_img\n", + " self._cached_yandex[idx] = yandex_img\n", + " self._cached_homography[idx] = target_matrix\n", + " self._cached_params[idx] = target_params\n", + " else:\n", + " google_img = self._google_images[idx]\n", + " yandex_img = self._yandex_images[idx]\n", + " target_params = np.array([0, 0, 0, 0, 0, 1], dtype=np.float32)\n", + " target_matrix = np.eye(3, dtype=np.float32)\n", + "\n", + " google_img = Image.fromarray(google_img)\n", + " yandex_img = Image.fromarray(yandex_img)\n", + "\n", + " if self.transform:\n", + " google_img = self.transform(google_img)\n", + " yandex_img = self.transform(yandex_img)\n", + "\n", + " return {\n", + " \"google_img\": google_img,\n", + " \"yandex_img\": yandex_img,\n", + " \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n", + " \"homography_params\": torch.from_numpy(target_params).float(),\n", + " }\n", + "\n", + "\n", + "def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0, \n", + " image_size=(256, 256), augment_train=True, cache_level=5):\n", + " transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", + " ])\n", + " \n", + " full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)\n", + " aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)\n", + "\n", + " indices = list(range(len(full_ds)))\n", + " random.shuffle(indices)\n", + " split = int(train_split * len(indices))\n", + " \n", + " train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])\n", + " val_ds = Subset(aug_ds, indices[split:])\n", + "\n", + " return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),\n", + " DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))\n", + "\n", + "\n", + "def get_dataset_info():\n", + " ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n", + " return {\n", + " \"size\": len(ds),\n", + " \"sample_keys\": list(ds[0].keys()),\n", + " \"sample_params\": ds[0][\"homography_params\"].numpy()\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "9fee48b8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 1.0661e+00, -4.5570e-02, -2.1649e+01],\n", + " [ 3.7228e-02, 9.1285e-01, 1.3237e+01],\n", + " [ 5.3873e-04, -6.4843e-04, 9.6602e-01]])\n", + "tensor([-0.0379, 0.0183, -0.0856, -0.0661, -0.0375, 0.9617])\n", + "[[ 1.0660727e+00 -4.5569740e-02 -2.1648926e+01]\n", + " [ 3.7227791e-02 9.1284692e-01 1.3237175e+01]\n", + " [ 5.3873385e-04 -6.4843398e-04 9.6602309e-01]]\n", + "[ 0. 0. -0.08696619 -0.0720376 -0.03181122 0.9519815 ]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n", + " super().__init__(loader)\n", + "C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:32: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n", + " cy, sy = np.cos(rz), np.sin(rz)\n", + "C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:33: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n", + " cp, sp = np.cos(ry), np.sin(ry)\n", + "C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:34: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n", + " cr, sr = np.cos(rx), np.sin(rx)\n" + ] + } + ], + "source": [ + "\n", + "train_loader, val_loader = create_data_loaders(config['data_dir'])\n", + "batch = next(iter(train_loader))\n", + "google_img = batch['google_img'][0]\n", + "yandex_img = batch['yandex_img'][0]\n", + "\n", + "# google_img.permute((1, 2, 0)) * 255\n", + "batch['homography_params'].mean(axis=0)\n", + "\n", + "print(batch['homography_matrix'][0])\n", + "print(batch['homography_params'][0])\n", + "K = get_camera_matrix(config['image_size'][0], config['image_size'][1])\n", + "print(homography_params_to_matrix(batch['homography_params'][0], K))\n", + "print(matrix_to_homography_params(batch['homography_matrix'][0].numpy(), K))\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "e8072ee6", + "metadata": {}, + "source": [ + "## Model\n", + "\n", + "`HomographyCNN6` — CNN architecture for homography estimation.\n", + "\n", + "**Output:** 6 parameters\n", + "- `rx, ry, rz` — rotation angles (radians)\n", + "- `tx, ty` — translation offsets\n", + "- `scale` — isotropic scale factor\n", + "\n", + "**Architecture:**\n", + "- Dual-branch CNN (Google + Yandex images)\n", + "- Shared backbone (configurable: resnet18/34/50)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "c246531d", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "def angular_difference(pred_angles, target_angles):\n", + " diff = pred_angles - target_angles\n", + " diff = torch.atan2(torch.sin(diff), torch.cos(diff))\n", + " return torch.abs(diff)\n", + "\n", + "\n", + "class HomographyCNN6(nn.Module):\n", + " def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n", + " super().__init__()\n", + " backbone = getattr(models, backbone_name)(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n", + " self.feature_dim = backbone.fc.in_features\n", + " backbone.fc = nn.Identity()\n", + " self.backbone = backbone\n", + "\n", + " self.head = nn.Sequential(\n", + " nn.Linear(self.feature_dim * 4, 1024),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(1024, 512),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(512, 256),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(256, 6),\n", + " )\n", + " self._init_weights()\n", + "\n", + " def _normalize_sin_cos(self, _sin, _cos):\n", + " _len = torch.sqrt(_sin ** 2 + _cos ** 2)\n", + " return _sin / _len, _cos / _len\n", + "\n", + " def _init_weights(self):\n", + " for module in self.head.modules():\n", + " if isinstance(module, nn.Linear):\n", + " nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n", + " if module.bias is not None:\n", + " nn.init.zeros_(module.bias)\n", + "\n", + " def forward(self, img1, img2):\n", + " f1 = self.backbone(img1)\n", + " f2 = self.backbone(img2)\n", + " combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n", + "\n", + " output = self.head(combined)\n", + "\n", + " output = torch.tanh(output) # [-1; 1]\n", + " modified = output.clone()\n", + " modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) # [-pi; pi]\n", + "\n", + " return modified\n", + "\n", + " def decode_output(self, output):\n", + " tx = output[:, 0]\n", + " ty = output[:, 1]\n", + " scale = output[:, 5]\n", + " angle1 = output[:, 2]\n", + " angle2 = output[:, 3]\n", + " angle3 = output[:, 4]\n", + "\n", + " return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n", + "\n", + " def get_components(self, output):\n", + " decoded = self.decode_output(output)\n", + " return {\n", + " \"tx\": decoded[:, 0],\n", + " \"ty\": decoded[:, 1],\n", + " \"rx\": decoded[:, 2],\n", + " \"ry\": decoded[:, 3],\n", + " \"rz\": decoded[:, 4],\n", + " \"scale\": decoded[:, 5],\n", + " }\n", + "\n", + "\n", + "class HomographyHybridCNN(nn.Module):\n", + " def __init__(self, input_channels=3, use_resnet_layers=2, dropout_rate=0.3):\n", + " super().__init__()\n", + " \n", + " if use_resnet_layers == 1:\n", + " resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n", + " self.conv1 = resnet.conv1\n", + " self.bn1 = resnet.bn1\n", + " self.relu = resnet.relu\n", + " self.maxpool = resnet.maxpool\n", + " conv_out_channels = 64\n", + " elif use_resnet_layers == 2:\n", + " resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n", + " self.conv1 = resnet.conv1\n", + " self.bn1 = resnet.bn1\n", + " self.relu = resnet.relu\n", + " self.maxpool = resnet.maxpool\n", + " self.conv2 = resnet.layer1[0].conv1\n", + " self.bn2 = resnet.layer1[0].bn1\n", + " self.conv2_2 = resnet.layer1[0].conv2\n", + " self.bn2_2 = resnet.layer1[0].bn2\n", + " self.relu2 = resnet.layer1[0].relu\n", + " self.maxpool2 = resnet.maxpool\n", + " conv_out_channels = 64\n", + " else:\n", + " raise ValueError(\"use_resnet_layers must be 1 or 2\")\n", + " \n", + " self.use_resnet_layers = use_resnet_layers\n", + " self.feature_map_size = 64\n", + " \n", + " self.conv_head = nn.Sequential(\n", + " nn.Conv2d(conv_out_channels, 128, kernel_size=3, padding=1),\n", + " nn.BatchNorm2d(128),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(128, 256, kernel_size=3, padding=1),\n", + " nn.BatchNorm2d(256),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2),\n", + " )\n", + " \n", + " self.global_pool = nn.AdaptiveAvgPool2d((1, 1))\n", + " \n", + " feature_dim = 256 * 4\n", + " \n", + " self.head = nn.Sequential(\n", + " nn.Linear(feature_dim, 1024),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(1024, 512),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(512, 256),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(256, 6),\n", + " )\n", + " self._init_weights()\n", + "\n", + " def _init_weights(self):\n", + " for module in self.head.modules():\n", + " if isinstance(module, nn.Linear):\n", + " nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n", + " if module.bias is not None:\n", + " nn.init.zeros_(module.bias)\n", + "\n", + " def forward(self, img1, img2):\n", + " x1 = self._extract_features(img1)\n", + " x2 = self._extract_features(img2)\n", + " \n", + " combined = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], dim=1)\n", + " output = self.head(combined)\n", + " \n", + " output = torch.tanh(output)\n", + " modified = output.clone()\n", + " modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi)\n", + " \n", + " return modified\n", + "\n", + " def _extract_features(self, x):\n", + " x = self.conv1(x)\n", + " x = self.bn1(x)\n", + " x = self.relu(x)\n", + " x = self.maxpool(x)\n", + " \n", + " if self.use_resnet_layers >= 2:\n", + " x = self.conv2(x)\n", + " x = self.bn2(x)\n", + " x = self.relu(x)\n", + " x = self.conv2_2(x)\n", + " x = self.bn2_2(x)\n", + " x = self.relu2(x)\n", + " x = self.maxpool2(x)\n", + " \n", + " x = self.conv_head(x)\n", + " x = self.global_pool(x)\n", + " x = x.view(x.size(0), -1)\n", + " \n", + " return x\n", + "\n", + " def decode_output(self, output):\n", + " tx = output[:, 0]\n", + " ty = output[:, 1]\n", + " scale = output[:, 5]\n", + " angle1 = output[:, 2]\n", + " angle2 = output[:, 3]\n", + " angle3 = output[:, 4]\n", + " return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n", + "\n", + " def get_components(self, output):\n", + " decoded = self.decode_output(output)\n", + " return {\n", + " \"tx\": decoded[:, 0],\n", + " \"ty\": decoded[:, 1],\n", + " \"rx\": decoded[:, 2],\n", + " \"ry\": decoded[:, 3],\n", + " \"rz\": decoded[:, 4],\n", + " \"scale\": decoded[:, 5],\n", + " }\n", + "\n", + "\n", + "class HomographyLoss6(nn.Module):\n", + " def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):\n", + " super().__init__()\n", + " self.criterion = nn.MSELoss()\n", + " self.angle_loss_weight = angle_loss_weight\n", + " self.trans_loss_weight = trans_loss_weight\n", + " self.scale_loss_weight = scale_loss_weight\n", + "\n", + " @staticmethod\n", + " def dot_angles(src, dest):\n", + " sin_src = torch.sin(src)\n", + " cos_src = torch.cos(src)\n", + " sin_dest = torch.sin(dest)\n", + " cos_dest = torch.cos(dest)\n", + " return sin_src * sin_dest + cos_src * cos_dest\n", + "\n", + " def forward(self, pred, target):\n", + " tx_loss = self.criterion(pred[:, 0], target[:, 0])\n", + " ty_loss = self.criterion(pred[:, 1], target[:, 1])\n", + "\n", + " dot_rx = HomographyLoss6.dot_angles(pred[:, 2], target[:, 2])\n", + " dot_ry = HomographyLoss6.dot_angles(pred[:, 3], target[:, 3])\n", + " dot_rz = HomographyLoss6.dot_angles(pred[:, 4], target[:, 4])\n", + "\n", + " rx_loss = self.criterion(dot_rx, torch.ones_like(dot_rx))\n", + " ry_loss = self.criterion(dot_ry, torch.ones_like(dot_ry))\n", + " rz_loss = self.criterion(dot_rz, torch.ones_like(dot_rz))\n", + "\n", + " scale_loss = self.criterion(pred[:, 5], target[:, 5])\n", + "\n", + " total_loss = (\n", + " self.trans_loss_weight * (tx_loss + ty_loss) +\n", + " self.angle_loss_weight * (rx_loss + ry_loss + rz_loss) +\n", + " self.scale_loss_weight * scale_loss\n", + " )\n", + "\n", + " return total_loss\n", + "\n", + " def compute_mse_components(self, decoded, target):\n", + " tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n", + " ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n", + " \n", + " dot_rx = HomographyLoss6.dot_angles(decoded[:, 2], target[:, 2])\n", + " dot_ry = HomographyLoss6.dot_angles(decoded[:, 3], target[:, 3])\n", + " dot_rz = HomographyLoss6.dot_angles(decoded[:, 4], target[:, 4])\n", + "\n", + " rx_mse = self.criterion(dot_rx, torch.ones_like(dot_rx)).item()\n", + " ry_mse = self.criterion(dot_ry, torch.ones_like(dot_ry)).item()\n", + " rz_mse = self.criterion(dot_rz, torch.ones_like(dot_rz)).item()\n", + "\n", + " scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n", + "\n", + " avg_angle_loss = (rx_mse + ry_mse + rz_mse) / 3\n", + "\n", + " return {\n", + " 'trans': (tx_mse + ty_mse) / 2,\n", + " 'angle': avg_angle_loss,\n", + " 'scale': scale_mse\n", + " }\n", + "\n", + "\n", + "HomographyLoss = HomographyLoss6\n", + "\n", + "\n", + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters())\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "d0991e10", + "metadata": {}, + "source": [ + "## Training\n", + "\n", + "`HomographyTrainer` — training loop with validation and checkpointing.\n", + "\n", + "**Features:**\n", + "- Epoch-based training with tqdm progress bar\n", + "- Adam optimizer with configurable LR\n", + "- Validation after each epoch\n", + "- Best model auto-save\n", + "- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n", + "\n", + "**Checkpoint saving:**\n", + "- `best_model.pt` — lowest validation loss\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "83ff7cc6", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "\n", + "\n", + "class HomographyTrainer:\n", + " def __init__(self, model, train_loader, val_loader, device, criterion):\n", + " self.model = model.to(device)\n", + " self.train_loader = train_loader\n", + " self.val_loader = val_loader\n", + " self.device = device\n", + " self.criterion = criterion\n", + " self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"], weight_decay=1e-4)\n", + " self.writer = None\n", + " self.best_val_loss = float(\"inf\")\n", + " self.train_losses = []\n", + " self.val_losses = []\n", + " self.train_mse_trans = []\n", + " self.train_mse_angle = []\n", + " self.train_mse_scale = []\n", + " self.val_mse_trans = []\n", + " self.val_mse_angle = []\n", + " self.val_mse_scale = []\n", + "\n", + " def train_epoch(self, epoch):\n", + " self.model.train()\n", + " total_loss, total_samples = 0, 0\n", + " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", + " pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n", + " for batch_idx, batch in enumerate(pbar):\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target = batch[\"homography_params\"].to(self.device)\n", + "\n", + " self.optimizer.zero_grad()\n", + " output = self.model(google_img, yandex_img)\n", + " loss = self.criterion(output, target)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + "\n", + " total_loss += loss.item() * google_img.size(0)\n", + " total_samples += google_img.size(0)\n", + " \n", + " decoded_output = self.model.decode_output(output)\n", + " mse_components = self.criterion.compute_mse_components(decoded_output, target)\n", + " mse_trans_sum += mse_components['trans'] * google_img.size(0)\n", + " mse_angle_sum += mse_components['angle'] * google_img.size(0)\n", + " mse_scale_sum += mse_components['scale'] * google_img.size(0)\n", + " \n", + " pbar.set_postfix({\"loss\": loss.item()})\n", + "\n", + " self.train_mse_trans.append(mse_trans_sum / total_samples)\n", + " self.train_mse_angle.append(mse_angle_sum / total_samples)\n", + " self.train_mse_scale.append(mse_scale_sum / total_samples)\n", + " \n", + " return {\"loss\": total_loss / total_samples}\n", + "\n", + " def validate(self):\n", + " self.model.eval()\n", + " total_loss, total_samples = 0, 0\n", + " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", + " with torch.no_grad():\n", + " for batch in tqdm(self.val_loader, desc=\"Validation\"):\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target = batch[\"homography_params\"].to(self.device)\n", + " output = self.model(google_img, yandex_img)\n", + " decoded_output = self.model.decode_output(output)\n", + " loss = self.criterion(output, target)\n", + " total_loss += loss.item() * google_img.size(0)\n", + " total_samples += google_img.size(0)\n", + " \n", + " mse_components = self.criterion.compute_mse_components(decoded_output, target)\n", + " mse_trans_sum += mse_components['trans'] * google_img.size(0)\n", + " mse_angle_sum += mse_components['angle'] * google_img.size(0)\n", + " mse_scale_sum += mse_components['scale'] * google_img.size(0)\n", + " \n", + " self.val_mse_trans.append(mse_trans_sum / total_samples)\n", + " self.val_mse_angle.append(mse_angle_sum / total_samples)\n", + " self.val_mse_scale.append(mse_scale_sum / total_samples)\n", + " \n", + " return {\"loss\": total_loss / total_samples}\n", + "\n", + " def train(self, num_epochs):\n", + " log_dir = config[\"output_dir\"]\n", + " os.makedirs(log_dir, exist_ok=True)\n", + " self.writer = SummaryWriter(log_dir)\n", + "\n", + " for epoch in range(1, num_epochs + 1):\n", + " train_metrics = self.train_epoch(epoch)\n", + " val_metrics = self.validate()\n", + " self.train_losses.append(train_metrics[\"loss\"])\n", + " self.val_losses.append(val_metrics[\"loss\"])\n", + " print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n", + " print(f\" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}\")\n", + "\n", + " if val_metrics[\"loss\"] < self.best_val_loss:\n", + " self.best_val_loss = val_metrics[\"loss\"]\n", + " self.save_checkpoint(epoch, is_best=True)\n", + " print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n", + "\n", + " if epoch % config[\"save_every_n_epochs\"] == 0:\n", + " self.save_checkpoint(epoch, is_best=False)\n", + " print(f\"Checkpoint saved at epoch {epoch}\")\n", + "\n", + " self.writer.close()\n", + "\n", + " def save_checkpoint(self, epoch, is_best=False):\n", + " ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n", + " os.makedirs(ckpt_dir, exist_ok=True)\n", + " ckpt = {\"epoch\": epoch, \"model_state_dict\": self.model.state_dict(), \"val_loss\": self.best_val_loss}\n", + " torch.save(ckpt, os.path.join(ckpt_dir, f\"checkpoint_epoch_{epoch}.pt\"))\n", + " if is_best:\n", + " torch.save(ckpt, os.path.join(ckpt_dir, \"best_model.pt\"))\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "d6dbc5ea", + "metadata": {}, + "source": [ + "## Analysis\n", + "\n", + "Visualization and evaluation tools:\n", + "\n", + "- Training metrics plots (loss curves)\n", + "- Prediction visualization on sample images\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "6fac17d5", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "sns.set_theme(style=\"whitegrid\", palette=\"muted\", font_scale=1.2)\n", + "\n", + "IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n", + "os.makedirs(IMG_DIR, exist_ok=True)\n", + "\n", + "\n", + "def analyze_training(trainer):\n", + " print(\"=== Training Analysis ===\\n\")\n", + "\n", + " if trainer.writer:\n", + " print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n", + "\n", + " print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n", + "\n", + " best_model_path = os.path.join(config[\"output_dir\"], \"checkpoints\", \"best_model.pt\")\n", + " if os.path.exists(best_model_path):\n", + " checkpoint = torch.load(best_model_path, map_location=trainer.device)\n", + " trainer.model.load_state_dict(checkpoint[\"model_state_dict\"])\n", + " print(f\"\\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})\")\n", + " \n", + " trainer.model.eval()\n", + " \n", + " n_samples = 50\n", + " names = [\"tx\", \"ty\", \"rx\", \"ry\", \"rz\", \"scale\"]\n", + " \n", + " _, val_loader_for_analysis = create_data_loaders(\n", + " root_dir=config[\"data_dir\"],\n", + " batch_size=config[\"batch_size\"],\n", + " train_split=config[\"train_split\"],\n", + " num_workers=config[\"num_workers\"],\n", + " image_size=config[\"image_size\"],\n", + " augment_train=True,\n", + " cache_level=0,\n", + " )\n", + " \n", + " with torch.no_grad():\n", + " all_errors = [[] for _ in range(6)]\n", + " all_targets = [[] for _ in range(6)]\n", + " all_preds = [[] for _ in range(6)]\n", + " \n", + " sample_count = 0\n", + " for batch in val_loader_for_analysis:\n", + " if sample_count >= n_samples:\n", + " break\n", + " \n", + " google_img = batch[\"google_img\"].to(trainer.device)\n", + " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", + " target_params = batch[\"homography_params\"].to(trainer.device)\n", + " pred_params = trainer.model(google_img, yandex_img)\n", + " decoded_pred = trainer.model.decode_output(pred_params)\n", + " \n", + " batch_size = google_img.size(0)\n", + " for i in range(batch_size):\n", + " if sample_count >= n_samples:\n", + " break\n", + " \n", + " tx_error = torch.abs(decoded_pred[i, 0] - target_params[i, 0]).item()\n", + " ty_error = torch.abs(decoded_pred[i, 1] - target_params[i, 1]).item()\n", + " rx_error = angular_difference(decoded_pred[i, 2], target_params[i, 2]).item()\n", + " ry_error = angular_difference(decoded_pred[i, 3], target_params[i, 3]).item()\n", + " rz_error = angular_difference(decoded_pred[i, 4], target_params[i, 4]).item()\n", + " scale_error = torch.abs(decoded_pred[i, 5] - target_params[i, 5]).item()\n", + " \n", + " errors = [tx_error, ty_error, rx_error, ry_error, rz_error, scale_error]\n", + " target_reordered = target_params[i].cpu().numpy()\n", + " pred_reordered = decoded_pred[i].cpu().numpy()\n", + " \n", + " for j in range(6):\n", + " all_errors[j].append(errors[j])\n", + " all_targets[j].append(target_reordered[j])\n", + " all_preds[j].append(pred_reordered[j])\n", + " \n", + " sample_count += 1\n", + " \n", + " mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n", + " std_errors = [np.std(all_errors[i]) for i in range(6)]\n", + " \n", + " angle_errors_deg = [np.degrees(mean_errors[i]) for i in range(2, 5)]\n", + " \n", + " all_targets_stacked = [np.array(all_targets[i]) for i in range(6)]\n", + " target_ranges = [np.ptp(all_targets_stacked[i]) for i in range(6)]\n", + " relative_errors = [mean_errors[i] / target_ranges[i] if target_ranges[i] > 1e-8 else 0 for i in range(6)]\n", + " \n", + " if len(trainer.train_losses) > 0:\n", + " epochs = range(1, len(trainer.train_losses) + 1)\n", + " fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", + " \n", + " axes[0, 0].plot(epochs, trainer.train_losses, color=\"#2ecc71\", linewidth=2, label=\"Train Loss\")\n", + " axes[0, 0].plot(epochs, trainer.val_losses, color=\"#e74c3c\", linewidth=2, label=\"Val Loss\")\n", + " axes[0, 0].set_xlabel(\"Epoch\")\n", + " axes[0, 0].set_ylabel(\"Loss\")\n", + " axes[0, 0].set_title(\"Training & Validation Loss\", fontweight=\"bold\")\n", + " axes[0, 0].legend(framealpha=0.9)\n", + " axes[0, 0].grid(True, alpha=0.3)\n", + " \n", + " axes[0, 1].plot(epochs, trainer.val_losses, color=\"#e74c3c\", linewidth=2, label=\"Val Loss\")\n", + " axes[0, 1].set_xlabel(\"Epoch\")\n", + " axes[0, 1].set_ylabel(\"Loss\")\n", + " axes[0, 1].set_title(\"Validation Loss\", fontweight=\"bold\")\n", + " axes[0, 1].legend(framealpha=0.9)\n", + " axes[0, 1].grid(True, alpha=0.3)\n", + " \n", + " axes[1, 0].plot(epochs, trainer.val_mse_trans, color=\"#3498db\", linewidth=2, label=\"Translation (tx, ty)\")\n", + " axes[1, 0].plot(epochs, trainer.val_mse_angle, color=\"#9b59b6\", linewidth=2, label=\"Angle (rx, ry, rz)\")\n", + " axes[1, 0].plot(epochs, trainer.val_mse_scale, color=\"#e67e22\", linewidth=2, label=\"Scale\")\n", + " axes[1, 0].set_xlabel(\"Epoch\")\n", + " axes[1, 0].set_ylabel(\"MSE\")\n", + " axes[1, 0].set_title(\"Validation MSE by Category\", fontweight=\"bold\")\n", + " axes[1, 0].legend(framealpha=0.9)\n", + " axes[1, 0].grid(True, alpha=0.3)\n", + " \n", + " x_pos = np.arange(6)\n", + " colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n", + " bars = axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n", + " axes[1, 1].set_xticks(x_pos)\n", + " axes[1, 1].set_xticklabels(names)\n", + " axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n", + " axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontweight=\"bold\")\n", + " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150, bbox_inches=\"tight\")\n", + " print(\"Saved training_loss_plots.png\")\n", + " plt.show()\n", + " \n", + " fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", + " colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n", + " for j in range(6):\n", + " row = j // 3\n", + " col = j % 3\n", + " axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color=colors[j], alpha=0.75)\n", + " axes[row, col].set_xlabel(\"Sample\", fontsize=10)\n", + " axes[row, col].set_ylabel(\"Absolute Error\", fontsize=10)\n", + " axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\", fontweight=\"bold\", fontsize=11)\n", + " axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n", + " plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150, bbox_inches=\"tight\")\n", + " print(\"Saved mae_per_parameter.png\")\n", + " plt.show()\n", + " \n", + " fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n", + " \n", + " x_pos = np.arange(6)\n", + " colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n", + " bars = axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n", + " axes[0].set_xticks(x_pos)\n", + " axes[0].set_xticklabels(names)\n", + " axes[0].set_ylabel(\"Mean Absolute Error\")\n", + " axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\", fontweight=\"bold\")\n", + " axes[0].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n", + " for patch, color in zip(bp[\"boxes\"], colors):\n", + " patch.set_facecolor(color)\n", + " patch.set_alpha(0.8)\n", + " axes[1].set_ylabel(\"Absolute Error\")\n", + " axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\", fontweight=\"bold\")\n", + " axes[1].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " rel_err_pos = np.arange(6)\n", + " bars = axes[2].bar(rel_err_pos, relative_errors, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n", + " axes[2].set_xticks(rel_err_pos)\n", + " axes[2].set_xticklabels(names)\n", + " axes[2].set_ylabel(\"Relative Error (MAE / Range)\")\n", + " axes[2].set_title(\"Relative Error per Parameter\", fontweight=\"bold\")\n", + " axes[2].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150, bbox_inches=\"tight\")\n", + " print(\"Saved mae_boxplot.png\")\n", + " plt.show()\n", + " \n", + " fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n", + " \n", + " angle_names = [\"rx\", \"ry\", \"rz\"]\n", + " x_pos = np.arange(3)\n", + " colors_angle = [\"#9b59b6\", \"#2ecc71\", \"#f39c12\"]\n", + " bars = axes[0].bar(x_pos, angle_errors_deg, color=colors_angle, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n", + " axes[0].set_xticks(x_pos)\n", + " axes[0].set_xticklabels(angle_names)\n", + " axes[0].set_ylabel(\"Mean Absolute Error (degrees)\")\n", + " axes[0].set_title(\"Angle MAE in Degrees\", fontweight=\"bold\")\n", + " axes[0].grid(True, alpha=0.3, axis=\"y\")\n", + " for i, e in enumerate(angle_errors_deg):\n", + " axes[0].text(i, e + 0.5, f\"{e:.1f}°\", ha=\"center\", va=\"bottom\", fontsize=11, fontweight=\"bold\")\n", + " \n", + " trans_scale_errs = [mean_errors[0], mean_errors[1], mean_errors[5]]\n", + " trans_scale_names = [\"tx\", \"ty\", \"scale\"]\n", + " x_pos = np.arange(3)\n", + " colors_trans = [\"#3498db\", \"#e74c3c\", \"#1abc9c\"]\n", + " bars = axes[1].bar(x_pos, trans_scale_errs, color=colors_trans, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n", + " axes[1].set_xticks(x_pos)\n", + " axes[1].set_xticklabels(trans_scale_names)\n", + " axes[1].set_ylabel(\"Mean Absolute Error\")\n", + " axes[1].set_title(\"Translation & Scale MAE\", fontweight=\"bold\")\n", + " axes[1].grid(True, alpha=0.3, axis=\"y\")\n", + " for i, e in enumerate(trans_scale_errs):\n", + " axes[1].text(i, e + 0.01, f\"{e:.4f}\", ha=\"center\", va=\"bottom\", fontsize=11, fontweight=\"bold\")\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, \"mae_by_category.png\"), dpi=150, bbox_inches=\"tight\")\n", + " print(\"Saved mae_by_category.png\")\n", + " plt.show()\n", + " \n", + " print(\"\\n=== Sample Predictions (20 pairs) ===\")\n", + " n_vis_samples = 20\n", + " \n", + " with torch.no_grad():\n", + " vis_count = 0\n", + " for batch in val_loader_for_analysis:\n", + " if vis_count >= n_vis_samples:\n", + " break\n", + " batch_size = batch[\"google_img\"].size(0)\n", + " \n", + " for i in range(batch_size):\n", + " if vis_count >= n_vis_samples:\n", + " break\n", + " \n", + " google_img = batch[\"google_img\"][i:i+1].to(trainer.device)\n", + " yandex_img = batch[\"yandex_img\"][i:i+1].to(trainer.device)\n", + " target_params = batch[\"homography_params\"][i:i+1].to(trainer.device)\n", + " pred_params = trainer.model(google_img, yandex_img)\n", + " decoded_pred = trainer.model.decode_output(pred_params)\n", + " \n", + " tx_error = torch.abs(decoded_pred[0, 0] - target_params[0, 0]).item()\n", + " ty_error = torch.abs(decoded_pred[0, 1] - target_params[0, 1]).item()\n", + " rx_error = angular_difference(decoded_pred[0, 2], target_params[0, 2]).item()\n", + " ry_error = angular_difference(decoded_pred[0, 3], target_params[0, 3]).item()\n", + " rz_error = angular_difference(decoded_pred[0, 4], target_params[0, 4]).item()\n", + " scale_error = torch.abs(decoded_pred[0, 5] - target_params[0, 5]).item()\n", + " \n", + " errors = np.array([tx_error, ty_error, rx_error, ry_error, rz_error, scale_error])\n", + " targets = target_params[0].cpu().numpy()\n", + " preds = decoded_pred[0].cpu().numpy()\n", + " \n", + " fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n", + " \n", + " axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n", + " axes[0, 0].set_title(\"Google Image\", fontweight=\"bold\", fontsize=12)\n", + " axes[0, 0].axis(\"off\")\n", + " \n", + " axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n", + " axes[0, 1].set_title(\"Yandex Image\", fontweight=\"bold\", fontsize=12)\n", + " axes[0, 1].axis(\"off\")\n", + " \n", + " x_pos = np.arange(6)\n", + " width = 0.35\n", + " axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"#3498db\", alpha=0.85)\n", + " axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"#e74c3c\", alpha=0.85)\n", + " axes[1, 0].set_xticks(x_pos)\n", + " axes[1, 0].set_xticklabels(names)\n", + " axes[1, 0].set_ylabel(\"Parameter Value\")\n", + " axes[1, 0].set_title(\"Target vs Predicted\", fontweight=\"bold\", fontsize=12)\n", + " axes[1, 0].legend(framealpha=0.9)\n", + " axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n", + " bars = axes[1, 1].bar(x_pos, errors, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.2)\n", + " axes[1, 1].set_xticks(x_pos)\n", + " axes[1, 1].set_xticklabels(names)\n", + " axes[1, 1].set_ylabel(\"Absolute Error\")\n", + " axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\", fontweight=\"bold\", fontsize=12)\n", + " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", + " for i_e, e in enumerate(errors):\n", + " axes[1, 1].text(i_e, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=9)\n", + " \n", + " plt.suptitle(f\"Sample {vis_count + 1}\", fontsize=14, fontweight=\"bold\")\n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{vis_count + 1:02d}.png\"), dpi=100, bbox_inches=\"tight\")\n", + " plt.show()\n", + " print(f\"Saved prediction_sample_{vis_count + 1:02d}.png\")\n", + " \n", + " vis_count += 1\n", + " \n", + " print(f\"\\nPrediction errors over {n_samples} samples:\")\n", + " print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8} {'Rel Err':>10}\")\n", + " print(\"-\" * 62)\n", + " for i in range(6):\n", + " mean_err = np.mean(all_errors[i])\n", + " std_err = np.std(all_errors[i])\n", + " min_err = np.min(all_errors[i])\n", + " max_err = np.max(all_errors[i])\n", + " rel_err = relative_errors[i]\n", + " print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f} {rel_err:>10.4f}\")\n", + " \n", + " print(f\"\\nAngle errors in degrees:\")\n", + " print(f\"{'Param':<8} {'MAE (deg)':>12} {'MAE (rad)':>12}\")\n", + " print(\"-\" * 35)\n", + " for i, name in enumerate([\"rx\", \"ry\", \"rz\"]):\n", + " print(f\"{name:<8} {angle_errors_deg[i]:>12.2f} {mean_errors[i+2]:>12.4f}\")\n", + "\n", + " return {\n", + " \"best_val_loss\": trainer.best_val_loss,\n", + " \"train_losses\": trainer.train_losses,\n", + " \"val_losses\": trainer.val_losses,\n", + " \"val_mse_trans\": trainer.val_mse_trans,\n", + " \"val_mse_angle\": trainer.val_mse_angle,\n", + " \"val_mse_scale\": trainer.val_mse_scale,\n", + " \"mean_errors\": mean_errors,\n", + " \"std_errors\": std_errors,\n", + " \"angle_errors_deg\": angle_errors_deg,\n", + " \"relative_errors\": relative_errors,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "e36c328c", + "metadata": {}, + "source": [ + "## Main Pipeline\n", + "\n", + "Executes the full training workflow:\n", + "1. Load dataset info\n", + "2. Create data loaders\n", + "3. Initialize model\n", + "4. Train with validation\n", + "5. Analyze and export results\n", + "\n", + "**Outputs:**\n", + "- Model checkpoints in `runs/checkpoints/`\n", + "- TensorBoard logs in `runs/`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a271cba", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-04-06 23:05:04,927 - ==================================================\n", + "2026-04-06 23:05:04,928 - SiaN Training Pipeline\n", + "2026-04-06 23:05:04,929 - ==================================================\n", + "2026-04-06 23:05:06,049 - Dataset: 33 samples, keys=['google_img', 'yandex_img', 'homography_matrix', 'homography_params']\n", + "2026-04-06 23:05:08,308 - Data loaders created: train=26, val=7\n" + ] + }, + { + "ename": "KeyError", + "evalue": "'droupout_rate'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[42]\u001b[39m\u001b[32m, line 29\u001b[39m\n\u001b[32m 18\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mData loaders created: train=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(train_loader.dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m, val=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(val_loader.dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 20\u001b[39m \u001b[38;5;66;03m# model = HomographyCNN6(\u001b[39;00m\n\u001b[32m 21\u001b[39m \u001b[38;5;66;03m# input_channels=3,\u001b[39;00m\n\u001b[32m 22\u001b[39m \u001b[38;5;66;03m# backbone_name=config[\"backbone\"],\u001b[39;00m\n\u001b[32m 23\u001b[39m \u001b[38;5;66;03m# pretrained=True,\u001b[39;00m\n\u001b[32m 24\u001b[39m \u001b[38;5;66;03m# dropout_rate=config[\"dropout_rate\"]\u001b[39;00m\n\u001b[32m 25\u001b[39m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[32m 27\u001b[39m model = HomographyHybridCNN(\n\u001b[32m 28\u001b[39m input_channels=\u001b[32m3\u001b[39m,\n\u001b[32m---> \u001b[39m\u001b[32m29\u001b[39m dropout_rate=\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdroupout_rate\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m,\n\u001b[32m 30\u001b[39m )\n\u001b[32m 32\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mModel created with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcount_parameters(model)\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m parameters\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 34\u001b[39m device = torch.device(\u001b[33m\"\u001b[39m\u001b[33mcuda\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch.cuda.is_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mcpu\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mKeyError\u001b[39m: 'droupout_rate'" + ] + } + ], + "source": [ + "\n", + "\n", + "\n", + "\n", + "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "logger.info(\"=\" * 50)\n", + "logger.info(\"SiaN Training Pipeline\")\n", + "logger.info(\"=\" * 50)\n", + "\n", + "dataset_info = get_dataset_info()\n", + "logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n", + "\n", + "train_loader, val_loader = create_data_loaders(\n", + " root_dir=config[\"data_dir\"],\n", + " batch_size=config[\"batch_size\"],\n", + " train_split=config[\"train_split\"],\n", + " num_workers=config[\"num_workers\"],\n", + " image_size=config[\"image_size\"],\n", + ")\n", + "logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n", + "\n", + "# model = HomographyCNN6(\n", + "# input_channels=3,\n", + "# backbone_name=config[\"backbone\"],\n", + "# pretrained=True,\n", + "# dropout_rate=config[\"dropout_rate\"]\n", + "# )\n", + "\n", + "model = HomographyHybridCNN(\n", + " input_channels=3,\n", + " dropout_rate=config[\"dropout_rate\"],\n", + ")\n", + "\n", + "logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "logger.info(f\"Using device: {device}\")\n", + "\n", + "trainer = HomographyTrainer(model, train_loader, val_loader, device, HomographyLoss())\n", + "logger.info(\"Starting training...\")\n", + "trainer.train(config[\"epochs\"])\n", + "logger.info(\"Training completed\")\n", + "\n", + "logger.info(\"Analyzing model...\")\n", + "results = analyze_training(trainer)\n", + "logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n", + "\n", + "logger.info(\"=\" * 50)\n", + "logger.info(\"Pipeline completed successfully\")\n", + "logger.info(\"=\" * 50)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1ea5b8f", + "metadata": {}, + "outputs": [], + "source": [ + "!zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/models/SiaN/src/main.py b/models/SiaN/src/main.py index 8e62df9..0bdd13c 100644 --- a/models/SiaN/src/main.py +++ b/models/SiaN/src/main.py @@ -4,7 +4,7 @@ import os import torch from .dataloader import create_data_loaders, get_dataset_info -from .model import HomographyCNN6, count_parameters +from .model import HomographyCNN6, HomographyHybridCNN, HomographyLoss, count_parameters from .train import HomographyTrainer from .analyze import analyze_training from .utils import config @@ -29,18 +29,24 @@ train_loader, val_loader = create_data_loaders( ) logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}") -model = HomographyCNN6( +# model = HomographyCNN6( +# input_channels=3, +# backbone_name=config["backbone"], +# pretrained=True, +# dropout_rate=config["dropout_rate"] +# ) + +model = HomographyHybridCNN( input_channels=3, - backbone_name=config["backbone"], - pretrained=True, - dropout_rate=config["dropout_rate"] + dropout_rate=config["droupout_rate"], ) + logger.info(f"Model created with {count_parameters(model):,} parameters") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {device}") -trainer = HomographyTrainer(model, train_loader, val_loader, device) +trainer = HomographyTrainer(model, train_loader, val_loader, device, HomographyLoss()) logger.info("Starting training...") trainer.train(config["epochs"]) logger.info("Training completed") diff --git a/models/SiaN/src/model.py b/models/SiaN/src/model.py index 7ed8f52..bdf3043 100644 --- a/models/SiaN/src/model.py +++ b/models/SiaN/src/model.py @@ -77,6 +77,126 @@ class HomographyCNN6(nn.Module): } +class HomographyHybridCNN(nn.Module): + def __init__(self, input_channels=3, use_resnet_layers=2, dropout_rate=0.3): + super().__init__() + + if use_resnet_layers == 1: + resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) + self.conv1 = resnet.conv1 + self.bn1 = resnet.bn1 + self.relu = resnet.relu + self.maxpool = resnet.maxpool + conv_out_channels = 64 + elif use_resnet_layers == 2: + resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) + self.conv1 = resnet.conv1 + self.bn1 = resnet.bn1 + self.relu = resnet.relu + self.maxpool = resnet.maxpool + self.conv2 = resnet.layer1[0].conv1 + self.bn2 = resnet.layer1[0].bn1 + self.conv2_2 = resnet.layer1[0].conv2 + self.bn2_2 = resnet.layer1[0].bn2 + self.relu2 = resnet.layer1[0].relu + self.maxpool2 = resnet.maxpool + conv_out_channels = 64 + else: + raise ValueError("use_resnet_layers must be 1 or 2") + + self.use_resnet_layers = use_resnet_layers + self.feature_map_size = 64 + + self.conv_head = nn.Sequential( + nn.Conv2d(conv_out_channels, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + ) + + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + + feature_dim = 256 * 4 + + self.head = nn.Sequential( + nn.Linear(feature_dim, 1024), + nn.ReLU(inplace=True), + nn.Dropout(dropout_rate), + nn.Linear(1024, 512), + nn.ReLU(inplace=True), + nn.Dropout(dropout_rate), + nn.Linear(512, 256), + nn.ReLU(inplace=True), + nn.Dropout(dropout_rate), + nn.Linear(256, 6), + ) + self._init_weights() + + def _init_weights(self): + for module in self.head.modules(): + if isinstance(module, nn.Linear): + nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, img1, img2): + x1 = self._extract_features(img1) + x2 = self._extract_features(img2) + + combined = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], dim=1) + output = self.head(combined) + + output = torch.tanh(output) + modified = output.clone() + modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi) + + return modified + + def _extract_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + if self.use_resnet_layers >= 2: + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.conv2_2(x) + x = self.bn2_2(x) + x = self.relu2(x) + x = self.maxpool2(x) + + x = self.conv_head(x) + x = self.global_pool(x) + x = x.view(x.size(0), -1) + + return x + + def decode_output(self, output): + tx = output[:, 0] + ty = output[:, 1] + scale = output[:, 5] + angle1 = output[:, 2] + angle2 = output[:, 3] + angle3 = output[:, 4] + return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1) + + def get_components(self, output): + decoded = self.decode_output(output) + return { + "tx": decoded[:, 0], + "ty": decoded[:, 1], + "rx": decoded[:, 2], + "ry": decoded[:, 3], + "rz": decoded[:, 4], + "scale": decoded[:, 5], + } + + class HomographyLoss6(nn.Module): def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0): super().__init__() @@ -138,5 +258,8 @@ class HomographyLoss6(nn.Module): } +HomographyLoss = HomographyLoss6 + + def count_parameters(model): return sum(p.numel() for p in model.parameters()) diff --git a/models/SiaN/src/train.py b/models/SiaN/src/train.py index a9582ee..2b31fb0 100644 --- a/models/SiaN/src/train.py +++ b/models/SiaN/src/train.py @@ -12,12 +12,12 @@ from .utils import config class HomographyTrainer: - def __init__(self, model, train_loader, val_loader, device): + def __init__(self, model, train_loader, val_loader, device, criterion): self.model = model.to(device) self.train_loader = train_loader self.val_loader = val_loader self.device = device - self.criterion = HomographyLoss6() + self.criterion = criterion self.optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4) self.writer = None self.best_val_loss = float("inf")