diff --git a/models/SiaN/notebook.gen.ipynb b/models/SiaN/notebook.gen.ipynb index cf1923a..efdfc68 100644 --- a/models/SiaN/notebook.gen.ipynb +++ b/models/SiaN/notebook.gen.ipynb @@ -1,853 +1,823 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import random\n", - "import logging\n", - "from typing import Tuple\n", - "import cv2\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "import matplotlib.pyplot as plt\n", - "from PIL import Image\n", - "from torch.utils.data import DataLoader, Dataset, Subset\n", - "from torch.utils.tensorboard import SummaryWriter\n", - "from torchvision import transforms, models\n", - "from tqdm import tqdm\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Configuration\n", - "\n", - "Global settings for:\n", - "- Data paths and image parameters\n", - "- Training hyperparameters\n", - "- Model architecture options\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "config = {\n", - " \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", - " \"image_size\": (256, 256),\n", - " \"batch_size\": 32,\n", - " \"train_split\": 0.8,\n", - " \"num_workers\": 0,\n", - " \"epochs\": 100,\n", - " \"learning_rate\": 2e-4,\n", - " \"dropout_rate\": 0.3,\n", - " \"backbone\": \"resnet18\",\n", - " \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n", - " \"save_every_n_epochs\": 15,\n", - "}\n", - "\n", - "\n", - "def get_camera_matrix(w, h):\n", - " return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]], dtype=np.float32)\n", - "\n", - "\n", - "def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n", - " scale = np.random.uniform(*scale_range)\n", - " tx = np.random.uniform(-translation_range, translation_range)\n", - " ty = np.random.uniform(-translation_range, translation_range)\n", - " rx = np.radians(np.random.uniform(-angle_range, angle_range))\n", - " ry = np.radians(np.random.uniform(-angle_range, angle_range))\n", - " rz = np.radians(np.random.uniform(-angle_range, angle_range))\n", - " return np.array([rx, ry, rz, tx, ty, scale])\n", - "\n", - "\n", - "def homography_params_to_matrix(params, K):\n", - " rx, ry, rz, tx, ty, scale = params\n", - " cy, sy = np.cos(rz), np.sin(rz)\n", - " cp, sp = np.cos(ry), np.sin(ry)\n", - " cr, sr = np.cos(rx), np.sin(rx)\n", - " Rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)\n", - " Ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)\n", - " Rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)\n", - " T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]], dtype=np.float32)\n", - " return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)\n", - "\n", - "\n", - "def matrix_to_homography_params(H, K):\n", - " K_inv = np.linalg.inv(K)\n", - " E = K_inv @ H @ K\n", - " scale = np.sqrt(np.linalg.det(E[:2, :2]))\n", - " R = E[:2, :2] / scale\n", - " tx, ty = E[0, 2], E[1, 2]\n", - " rz = np.arctan2(R[1, 0], R[0, 0])\n", - " r20, r21 = E[2, 0], E[2, 1]\n", - " ry = np.arctan2(r20, r21)\n", - " rx = np.arctan2(-E[1, 2], E[1, 1])\n", - " return np.array([rx, ry, rz, tx, ty, scale], dtype=np.float32)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset\n", - "\n", - "Google/Yandex image pair loader with homography augmentation.\n", - "\n", - "**Features:**\n", - "- Loads paired images from dual camera sources\n", - "- Applies random homography transformations\n", - "- Supports configurable train/val split\n", - "\n", - "**Returns:**\n" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "8740e758", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "\n", - "\n", - "class YaGoDataset(Dataset):\n", - " def __init__(self, root_dir: str, transform=None, augment: bool = True, \n", - " image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):\n", - " self.root_dir = root_dir\n", - " self.transform = transform\n", - " self.augment = augment\n", - " self.image_size = image_size\n", - " self.cache_level = cache_level\n", - " self.K = get_camera_matrix(image_size[1], image_size[0])\n", - " self.image_pairs = self._discover_image_pairs()\n", - " self._load_images_to_memory()\n", - " self._init_cache()\n", - "\n", - " def _discover_image_pairs(self):\n", - " pairs = []\n", - " for f in os.listdir(self.root_dir):\n", - " if f.endswith(\"_google.png\"):\n", - " idx = f.split(\"_\")[0]\n", - " yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n", - " if os.path.exists(yandex_path):\n", - " pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n", - " return sorted(pairs, key=lambda x: x[\"idx\"])\n", - "\n", - " def _load_images_to_memory(self):\n", - " self._google_images = []\n", - " self._yandex_images = []\n", - " for pair in self.image_pairs:\n", - " google_img = cv2.imread(pair[\"google\"])\n", - " google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)\n", - " google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", - " \n", - " yandex_img = cv2.imread(pair[\"yandex\"])\n", - " yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)\n", - " yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", - " \n", - " self._google_images.append(google_img)\n", - " self._yandex_images.append(yandex_img)\n", - "\n", - " def _init_cache(self):\n", - " self._access_counts = [0] * len(self.image_pairs)\n", - " self._cached_google = [None] * len(self.image_pairs)\n", - " self._cached_yandex = [None] * len(self.image_pairs)\n", - " self._cached_homography = [None] * len(self.image_pairs)\n", - "\n", - " def _generate_augmented(self, idx):\n", - " google_img = self._google_images[idx].copy()\n", - " yandex_img = self._yandex_images[idx].copy()\n", - "\n", - " params1 = generate_random_homography_params()\n", - " params2 = generate_random_homography_params()\n", - " H1 = homography_params_to_matrix(params1, self.K)\n", - " H2 = homography_params_to_matrix(params2, self.K)\n", - " H_combined = np.linalg.inv(H1) @ H2\n", - " \n", - " google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))\n", - " yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))\n", - " \n", - " target_params = matrix_to_homography_params(H_combined, self.K)\n", - " \n", - " return google_warped, yandex_warped, H_combined, target_params\n", - "\n", - " def __len__(self):\n", - " return len(self.image_pairs)\n", - "\n", - " def __getitem__(self, idx):\n", - " self._access_counts[idx] += 1\n", - " \n", - " use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0\n", - " \n", - " if use_cache:\n", - " google_img = self._cached_google[idx]\n", - " yandex_img = self._cached_yandex[idx]\n", - " target_matrix = self._cached_homography[idx]\n", - " target_params = matrix_to_homography_params(target_matrix, self.K)\n", - " elif self.augment:\n", - " google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n", - " if self.cache_level > 0:\n", - " self._cached_google[idx] = google_img\n", - " self._cached_yandex[idx] = yandex_img\n", - " self._cached_homography[idx] = target_matrix\n", - " else:\n", - " google_img = self._google_images[idx]\n", - " yandex_img = self._yandex_images[idx]\n", - " target_params = np.zeros(6, dtype=np.float32)\n", - " target_matrix = np.eye(3, dtype=np.float32)\n", - "\n", - " google_img = Image.fromarray(google_img)\n", - " yandex_img = Image.fromarray(yandex_img)\n", - "\n", - " if self.transform:\n", - " google_img = self.transform(google_img)\n", - " yandex_img = self.transform(yandex_img)\n", - "\n", - " return {\n", - " \"google_img\": google_img,\n", - " \"yandex_img\": yandex_img,\n", - " \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n", - " \"homography_params\": torch.from_numpy(target_params).float(),\n", - " }\n", - "\n", - "\n", - "def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0, \n", - " image_size=(256, 256), augment_train=True, cache_level=5):\n", - " transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", - " ])\n", - " \n", - " full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)\n", - " aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)\n", - "\n", - " indices = list(range(len(full_ds)))\n", - " random.shuffle(indices)\n", - " split = int(train_split * len(indices))\n", - " \n", - " train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])\n", - " val_ds = Subset(aug_ds, indices[split:])\n", - "\n", - " return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),\n", - " DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))\n", - "\n", - "\n", - "def get_dataset_info():\n", - " ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n", - " return {\n", - " \"size\": len(ds),\n", - " \"sample_keys\": list(ds[0].keys()),\n", - " \"sample_params\": ds[0][\"homography_params\"].numpy()\n", - " }\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model\n", - "\n", - "`HomographyCNN6` — CNN architecture for homography estimation.\n", - "\n", - "**Output:** 6 parameters\n", - "- `rx, ry, rz` — rotation angles (radians)\n", - "- `tx, ty` — translation offsets\n", - "- `scale` — isotropic scale factor\n", - "\n", - "**Architecture:**\n", - "- Dual-branch CNN (Google + Yandex images)\n", - "- Shared backbone (configurable: resnet18/34/50)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "class HomographyCNN6(nn.Module):\n", - " def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n", - " super().__init__()\n", - " backbone = getattr(models, backbone_name)(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n", - " self.feature_dim = backbone.fc.in_features\n", - " backbone.fc = nn.Identity()\n", - " self.backbone = backbone\n", - "\n", - " self.head = nn.Sequential(\n", - " nn.Linear(self.feature_dim * 4, 512),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(dropout_rate),\n", - " nn.Linear(512, 256),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(dropout_rate),\n", - " nn.Linear(256, 6),\n", - " )\n", - "\n", - " def forward(self, img1, img2):\n", - " f1 = self.backbone(img1)\n", - " f2 = self.backbone(img2)\n", - " combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n", - " return self.head(combined)\n", - "\n", - "\n", - "class HomographyLoss6(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.criterion = nn.MSELoss()\n", - "\n", - " def forward(self, pred, target):\n", - " return self.criterion(pred, target)\n", - "\n", - "\n", - "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters())\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training\n", - "\n", - "`HomographyTrainer` — training loop with validation and checkpointing.\n", - "\n", - "**Features:**\n", - "- Epoch-based training with tqdm progress bar\n", - "- Adam optimizer with configurable LR\n", - "- Validation after each epoch\n", - "- Best model auto-save\n", - "- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n", - "\n", - "**Checkpoint saving:**\n", - "- `best_model.pt` — lowest validation loss\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "\n", - "\n", - "class HomographyTrainer:\n", - " def __init__(self, model, train_loader, val_loader, device):\n", - " self.model = model.to(device)\n", - " self.train_loader = train_loader\n", - " self.val_loader = val_loader\n", - " self.device = device\n", - " self.criterion = HomographyLoss6()\n", - " self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"])\n", - " self.writer = None\n", - " self.best_val_loss = float(\"inf\")\n", - " self.train_losses = []\n", - " self.val_losses = []\n", - " self.train_mse_trans = []\n", - " self.train_mse_angle = []\n", - " self.train_mse_scale = []\n", - " self.val_mse_trans = []\n", - " self.val_mse_angle = []\n", - " self.val_mse_scale = []\n", - "\n", - " def train_epoch(self, epoch):\n", - " self.model.train()\n", - " total_loss, total_samples = 0, 0\n", - " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", - " pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n", - " for batch_idx, batch in enumerate(pbar):\n", - " google_img = batch[\"google_img\"].to(self.device)\n", - " yandex_img = batch[\"yandex_img\"].to(self.device)\n", - " target = batch[\"homography_params\"].to(self.device)\n", - "\n", - " self.optimizer.zero_grad()\n", - " output = self.model(google_img, yandex_img)\n", - " loss = self.criterion(output, target)\n", - " loss.backward()\n", - " self.optimizer.step()\n", - "\n", - " total_loss += loss.item() * google_img.size(0)\n", - " total_samples += google_img.size(0)\n", - " \n", - " mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)\n", - " mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)\n", - " mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)\n", - " \n", - " pbar.set_postfix({\"loss\": loss.item()})\n", - "\n", - " self.train_mse_trans.append(mse_trans_sum / total_samples)\n", - " self.train_mse_angle.append(mse_angle_sum / total_samples)\n", - " self.train_mse_scale.append(mse_scale_sum / total_samples)\n", - " \n", - " return {\"loss\": total_loss / total_samples}\n", - "\n", - " def validate(self):\n", - " self.model.eval()\n", - " total_loss, total_samples = 0, 0\n", - " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", - " with torch.no_grad():\n", - " for batch in tqdm(self.val_loader, desc=\"Validation\"):\n", - " google_img = batch[\"google_img\"].to(self.device)\n", - " yandex_img = batch[\"yandex_img\"].to(self.device)\n", - " target = batch[\"homography_params\"].to(self.device)\n", - " output = self.model(google_img, yandex_img)\n", - " loss = self.criterion(output, target)\n", - " total_loss += loss.item() * google_img.size(0)\n", - " total_samples += google_img.size(0)\n", - " \n", - " mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)\n", - " mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)\n", - " mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)\n", - " \n", - " self.val_mse_trans.append(mse_trans_sum / total_samples)\n", - " self.val_mse_angle.append(mse_angle_sum / total_samples)\n", - " self.val_mse_scale.append(mse_scale_sum / total_samples)\n", - " \n", - " return {\"loss\": total_loss / total_samples}\n", - "\n", - " def train(self, num_epochs):\n", - " log_dir = config[\"output_dir\"]\n", - " os.makedirs(log_dir, exist_ok=True)\n", - " self.writer = SummaryWriter(log_dir)\n", - "\n", - " for epoch in range(1, num_epochs + 1):\n", - " train_metrics = self.train_epoch(epoch)\n", - " val_metrics = self.validate()\n", - " self.train_losses.append(train_metrics[\"loss\"])\n", - " self.val_losses.append(val_metrics[\"loss\"])\n", - " print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n", - " print(f\" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}\")\n", - "\n", - " if val_metrics[\"loss\"] < self.best_val_loss:\n", - " self.best_val_loss = val_metrics[\"loss\"]\n", - " self.save_checkpoint(epoch, is_best=True)\n", - " print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n", - "\n", - " if epoch % config[\"save_every_n_epochs\"] == 0:\n", - " self.save_checkpoint(epoch, is_best=False)\n", - " print(f\"Checkpoint saved at epoch {epoch}\")\n", - "\n", - " self.writer.close()\n", - "\n", - " def save_checkpoint(self, epoch, is_best=False):\n", - " ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n", - " os.makedirs(ckpt_dir, exist_ok=True)\n", - " ckpt = {\"epoch\": epoch, \"model_state_dict\": self.model.state_dict(), \"val_loss\": self.best_val_loss}\n", - " torch.save(ckpt, os.path.join(ckpt_dir, f\"checkpoint_epoch_{epoch}.pt\"))\n", - " if is_best:\n", - " torch.save(ckpt, os.path.join(ckpt_dir, \"best_model.pt\"))\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Analysis\n", - "\n", - "Visualization and evaluation tools:\n", - "\n", - "- Training metrics plots (loss curves)\n", - "- Prediction visualization on sample images\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n", - "os.makedirs(IMG_DIR, exist_ok=True)\n", - "\n", - "\n", - "def analyze_training(trainer):\n", - " print(\"=== Training Analysis ===\\n\")\n", - "\n", - " if trainer.writer:\n", - " print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n", - "\n", - " print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n", - "\n", - " best_model_path = os.path.join(config[\"output_dir\"], \"checkpoints\", \"best_model.pt\")\n", - " if os.path.exists(best_model_path):\n", - " checkpoint = torch.load(best_model_path, map_location=trainer.device)\n", - " trainer.model.load_state_dict(checkpoint[\"model_state_dict\"])\n", - " print(f\"\\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})\")\n", - " \n", - " trainer.model.eval()\n", - " \n", - " n_samples = 50\n", - " names = [\"rx\", \"ry\", \"rz\", \"tx\", \"ty\", \"scale\"]\n", - " \n", - " with torch.no_grad():\n", - " all_errors = [[] for _ in range(6)]\n", - " all_targets = [[] for _ in range(6)]\n", - " all_preds = [[] for _ in range(6)]\n", - " \n", - " for i in range(n_samples):\n", - " try:\n", - " batch = next(iter(trainer.val_loader))\n", - " except StopIteration:\n", - " break\n", - " google_img = batch[\"google_img\"].to(trainer.device)\n", - " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", - " target_params = batch[\"homography_params\"].to(trainer.device)\n", - " pred_params = trainer.model(google_img, yandex_img)\n", - " \n", - " for j in range(6):\n", - " all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item())\n", - " all_targets[j].append(target_params[0, j].item())\n", - " all_preds[j].append(pred_params[0, j].item())\n", - " \n", - " mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n", - " std_errors = [np.std(all_errors[i]) for i in range(6)]\n", - " \n", - " if len(trainer.train_losses) > 0:\n", - " epochs = range(1, len(trainer.train_losses) + 1)\n", - " fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", - " \n", - " axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n", - " axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", - " axes[0, 0].set_xlabel(\"Epoch\")\n", - " axes[0, 0].set_ylabel(\"Loss\")\n", - " axes[0, 0].set_title(\"Training & Validation Loss\")\n", - " axes[0, 0].legend()\n", - " axes[0, 0].grid(True, alpha=0.3)\n", - " \n", - " axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", - " axes[0, 1].set_xlabel(\"Epoch\")\n", - " axes[0, 1].set_ylabel(\"Loss\")\n", - " axes[0, 1].set_title(\"Validation Loss\")\n", - " axes[0, 1].legend()\n", - " axes[0, 1].grid(True, alpha=0.3)\n", - " \n", - " axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n", - " axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n", - " axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n", - " axes[1, 0].set_xlabel(\"Epoch\")\n", - " axes[1, 0].set_ylabel(\"MSE\")\n", - " axes[1, 0].set_title(\"Validation MSE by Category\")\n", - " axes[1, 0].legend()\n", - " axes[1, 0].grid(True, alpha=0.3)\n", - " \n", - " x_pos = np.arange(6)\n", - " axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", - " axes[1, 1].set_xticks(x_pos)\n", - " axes[1, 1].set_xticklabels(names)\n", - " axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n", - " axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\")\n", - " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150)\n", - " print(\"Saved training_loss_plots.png\")\n", - " plt.show()\n", - " \n", - " fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", - " for j in range(6):\n", - " row = j // 3\n", - " col = j % 3\n", - " axes[row, col].bar(range(n_samples), all_errors[j], color=\"steelblue\", alpha=0.7)\n", - " axes[row, col].set_xlabel(\"Sample\")\n", - " axes[row, col].set_ylabel(\"Absolute Error\")\n", - " axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\")\n", - " axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n", - " plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14)\n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150)\n", - " print(\"Saved mae_per_parameter.png\")\n", - " plt.show()\n", - " \n", - " fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n", - " \n", - " x_pos = np.arange(6)\n", - " axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", - " axes[0].set_xticks(x_pos)\n", - " axes[0].set_xticklabels(names)\n", - " axes[0].set_ylabel(\"Mean Absolute Error\")\n", - " axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\")\n", - " axes[0].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n", - " colors = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n", - " for patch, color in zip(bp[\"boxes\"], colors):\n", - " patch.set_facecolor(color)\n", - " patch.set_alpha(0.7)\n", - " axes[1].set_ylabel(\"Absolute Error\")\n", - " axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\")\n", - " axes[1].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150)\n", - " print(\"Saved mae_boxplot.png\")\n", - " plt.show()\n", - " \n", - " print(\"\\n=== Sample Predictions (20 pairs) ===\")\n", - " n_vis_samples = 20\n", - " \n", - " with torch.no_grad():\n", - " for sample_idx in range(n_vis_samples):\n", - " try:\n", - " batch = next(iter(trainer.val_loader))\n", - " except StopIteration:\n", - " break\n", - " google_img = batch[\"google_img\"].to(trainer.device)\n", - " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", - " target_params = batch[\"homography_params\"].to(trainer.device)\n", - " pred_params = trainer.model(google_img, yandex_img)\n", - " \n", - " errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy()\n", - " targets = target_params[0].cpu().numpy()\n", - " preds = pred_params[0].cpu().numpy()\n", - " \n", - " fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n", - " \n", - " axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n", - " axes[0, 0].set_title(f\"Google Image\")\n", - " axes[0, 0].axis(\"off\")\n", - " \n", - " axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n", - " axes[0, 1].set_title(f\"Yandex Image\")\n", - " axes[0, 1].axis(\"off\")\n", - " \n", - " x_pos = np.arange(6)\n", - " width = 0.35\n", - " axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n", - " axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n", - " axes[1, 0].set_xticks(x_pos)\n", - " axes[1, 0].set_xticklabels(names)\n", - " axes[1, 0].set_ylabel(\"Parameter Value\")\n", - " axes[1, 0].set_title(\"Target vs Predicted\")\n", - " axes[1, 0].legend()\n", - " axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n", - " \n", - " axes[1, 1].bar(x_pos, errors, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", - " axes[1, 1].set_xticks(x_pos)\n", - " axes[1, 1].set_xticklabels(names)\n", - " axes[1, 1].set_ylabel(\"Absolute Error\")\n", - " axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\")\n", - " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", - " for i, e in enumerate(errors):\n", - " axes[1, 1].text(i, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n", - " \n", - " plt.suptitle(f\"Sample {sample_idx + 1}\", fontsize=14)\n", - " plt.tight_layout()\n", - " plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{sample_idx + 1:02d}.png\"), dpi=100)\n", - " plt.show()\n", - " print(f\"Saved prediction_sample_{sample_idx + 1:02d}.png\")\n", - " \n", - " print(f\"\\nPrediction errors over {n_samples} samples:\")\n", - " print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n", - " print(\"-\" * 52)\n", - " for i in range(6):\n", - " mean_err = np.mean(all_errors[i])\n", - " std_err = np.std(all_errors[i])\n", - " min_err = np.min(all_errors[i])\n", - " max_err = np.max(all_errors[i])\n", - " print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}\")\n", - "\n", - " return {\n", - " \"best_val_loss\": trainer.best_val_loss,\n", - " \"train_losses\": trainer.train_losses,\n", - " \"val_losses\": trainer.val_losses,\n", - " \"val_mse_trans\": trainer.val_mse_trans,\n", - " \"val_mse_angle\": trainer.val_mse_angle,\n", - " \"val_mse_scale\": trainer.val_mse_scale,\n", - " }\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Main Pipeline\n", - "\n", - "Executes the full training workflow:\n", - "1. Load dataset info\n", - "2. Create data loaders\n", - "3. Initialize model\n", - "4. Train with validation\n", - "5. Analyze and export results\n", - "\n", - "**Outputs:**\n", - "- Model checkpoints in `runs/checkpoints/`\n", - "- TensorBoard logs in `runs/`\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2026-04-05 12:35:20,728 - ==================================================\n", - "2026-04-05 12:35:20,730 - SiaN Training Pipeline\n", - "2026-04-05 12:35:20,731 - ==================================================\n", - "2026-04-05 12:35:32,423 - Dataset: 327 samples, keys=['google_img', 'yandex_img', 'homography_matrix', 'homography_params']\n", - "2026-04-05 12:35:54,074 - Data loaders created: train=261, val=66\n", - "2026-04-05 12:35:54,366 - Model created with 12,358,470 parameters\n", - "2026-04-05 12:35:54,368 - Using device: cpu\n", - "2026-04-05 12:35:54,374 - Starting training...\n", - "Epoch 1: 0%| | 0/9 [00:00 \u001b[39m\u001b[32m33\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mepochs\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 34\u001b[39m logger.info(\u001b[33m\"\u001b[39m\u001b[33mTraining completed\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 36\u001b[39m logger.info(\u001b[33m\"\u001b[39m\u001b[33mAnalyzing model...\u001b[39m\u001b[33m\"\u001b[39m)\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 81\u001b[39m, in \u001b[36mHomographyTrainer.train\u001b[39m\u001b[34m(self, num_epochs)\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28mself\u001b[39m.writer = SummaryWriter(log_dir)\n\u001b[32m 80\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m1\u001b[39m, num_epochs + \u001b[32m1\u001b[39m):\n\u001b[32m---> \u001b[39m\u001b[32m81\u001b[39m train_metrics = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 82\u001b[39m val_metrics = \u001b[38;5;28mself\u001b[39m.validate()\n\u001b[32m 83\u001b[39m \u001b[38;5;28mself\u001b[39m.train_losses.append(train_metrics[\u001b[33m\"\u001b[39m\u001b[33mloss\u001b[39m\u001b[33m\"\u001b[39m])\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 31\u001b[39m, in \u001b[36mHomographyTrainer.train_epoch\u001b[39m\u001b[34m(self, epoch)\u001b[39m\n\u001b[32m 28\u001b[39m target = batch[\u001b[33m\"\u001b[39m\u001b[33mhomography_params\u001b[39m\u001b[33m\"\u001b[39m].to(\u001b[38;5;28mself\u001b[39m.device)\n\u001b[32m 30\u001b[39m \u001b[38;5;28mself\u001b[39m.optimizer.zero_grad()\n\u001b[32m---> \u001b[39m\u001b[32m31\u001b[39m output = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgoogle_img\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43myandex_img\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 32\u001b[39m loss = \u001b[38;5;28mself\u001b[39m.criterion(output, target)\n\u001b[32m 33\u001b[39m loss.backward()\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 20\u001b[39m, in \u001b[36mHomographyCNN6.forward\u001b[39m\u001b[34m(self, img1, img2)\u001b[39m\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, img1, img2):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m f1 = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbackbone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m f2 = \u001b[38;5;28mself\u001b[39m.backbone(img2)\n\u001b[32m 22\u001b[39m combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=\u001b[32m1\u001b[39m)\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torchvision\\models\\resnet.py:285\u001b[39m, in \u001b[36mResNet.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 284\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_forward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torchvision\\models\\resnet.py:274\u001b[39m, in \u001b[36mResNet._forward_impl\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 271\u001b[39m x = \u001b[38;5;28mself\u001b[39m.maxpool(x)\n\u001b[32m 273\u001b[39m x = \u001b[38;5;28mself\u001b[39m.layer1(x)\n\u001b[32m--> \u001b[39m\u001b[32m274\u001b[39m x = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mlayer2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 275\u001b[39m x = \u001b[38;5;28mself\u001b[39m.layer3(x)\n\u001b[32m 276\u001b[39m x = \u001b[38;5;28mself\u001b[39m.layer4(x)\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\container.py:253\u001b[39m, in \u001b[36mSequential.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 249\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 250\u001b[39m \u001b[33;03mRuns the forward pass.\u001b[39;00m\n\u001b[32m 251\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 252\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m253\u001b[39m \u001b[38;5;28minput\u001b[39m = \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torchvision\\models\\resnet.py:96\u001b[39m, in \u001b[36mBasicBlock.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 93\u001b[39m out = \u001b[38;5;28mself\u001b[39m.bn1(out)\n\u001b[32m 94\u001b[39m out = \u001b[38;5;28mself\u001b[39m.relu(out)\n\u001b[32m---> \u001b[39m\u001b[32m96\u001b[39m out = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mconv2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 97\u001b[39m out = \u001b[38;5;28mself\u001b[39m.bn2(out)\n\u001b[32m 99\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.downsample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:553\u001b[39m, in \u001b[36mConv2d.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 552\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m553\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:548\u001b[39m, in \u001b[36mConv2d._conv_forward\u001b[39m\u001b[34m(self, input, weight, bias)\u001b[39m\n\u001b[32m 535\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.padding_mode != \u001b[33m\"\u001b[39m\u001b[33mzeros\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 536\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m F.conv2d(\n\u001b[32m 537\u001b[39m F.pad(\n\u001b[32m 538\u001b[39m \u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m._reversed_padding_repeated_twice, mode=\u001b[38;5;28mself\u001b[39m.padding_mode\n\u001b[32m (...)\u001b[39m\u001b[32m 545\u001b[39m \u001b[38;5;28mself\u001b[39m.groups,\n\u001b[32m 546\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m548\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 549\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mgroups\u001b[49m\n\u001b[32m 550\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[31mKeyboardInterrupt\u001b[39m: " - ] - } - ], - "source": [ - "\n", - "\n", - "\n", - "\n", - "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n", - "logger = logging.getLogger(__name__)\n", - "\n", - "logger.info(\"=\" * 50)\n", - "logger.info(\"SiaN Training Pipeline\")\n", - "logger.info(\"=\" * 50)\n", - "\n", - "dataset_info = get_dataset_info()\n", - "logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n", - "\n", - "train_loader, val_loader = create_data_loaders(\n", - " root_dir=config[\"data_dir\"],\n", - " batch_size=config[\"batch_size\"],\n", - " train_split=config[\"train_split\"],\n", - " num_workers=config[\"num_workers\"],\n", - " image_size=config[\"image_size\"],\n", - ")\n", - "logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n", - "\n", - "model = HomographyCNN6(\n", - " input_channels=3,\n", - " backbone_name=config[\"backbone\"],\n", - " pretrained=True,\n", - " dropout_rate=config[\"dropout_rate\"]\n", - ")\n", - "logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "logger.info(f\"Using device: {device}\")\n", - "\n", - "trainer = HomographyTrainer(model, train_loader, val_loader, device)\n", - "logger.info(\"Starting training...\")\n", - "trainer.train(config[\"epochs\"])\n", - "logger.info(\"Training completed\")\n", - "\n", - "logger.info(\"Analyzing model...\")\n", - "results = analyze_training(trainer)\n", - "logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n", - "\n", - "logger.info(\"=\" * 50)\n", - "logger.info(\"Pipeline completed successfully\")\n", - "logger.info(\"=\" * 50)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*\n", - "\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "import logging\n", + "from typing import Tuple\n", + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "from torch.utils.data import DataLoader, Dataset, Subset\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from torchvision import transforms, models\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "config = {\n", + " \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", + " \"image_size\": (256, 256),\n", + " \"batch_size\": 32,\n", + " \"train_split\": 0.8,\n", + " \"num_workers\": 0,\n", + " \"epochs\": 100,\n", + " \"learning_rate\": 2e-4,\n", + " \"dropout_rate\": 0.3,\n", + " \"backbone\": \"resnet18\",\n", + " \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n", + " \"save_every_n_epochs\": 15,\n", + "}\n", + "\n", + "\n", + "def get_camera_matrix(w, h):\n", + " return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]], dtype=np.float32)\n", + "\n", + "\n", + "def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n", + " scale = np.random.uniform(*scale_range)\n", + " tx = np.random.uniform(-translation_range, translation_range)\n", + " ty = np.random.uniform(-translation_range, translation_range)\n", + " rx = np.radians(np.random.uniform(-angle_range, angle_range))\n", + " ry = np.radians(np.random.uniform(-angle_range, angle_range))\n", + " rz = np.radians(np.random.uniform(-angle_range, angle_range))\n", + " return np.array([rx, ry, rz, tx, ty, scale])\n", + "\n", + "\n", + "def homography_params_to_matrix(params, K):\n", + " rx, ry, rz, tx, ty, scale = params\n", + " cy, sy = np.cos(rz), np.sin(rz)\n", + " cp, sp = np.cos(ry), np.sin(ry)\n", + " cr, sr = np.cos(rx), np.sin(rx)\n", + " Rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)\n", + " Ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)\n", + " Rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)\n", + " T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]], dtype=np.float32)\n", + " return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)\n", + "\n", + "\n", + "def matrix_to_homography_params(H, K):\n", + " K_inv = np.linalg.inv(K)\n", + " E = K_inv @ H @ K\n", + " scale = np.sqrt(np.linalg.det(E[:2, :2]))\n", + " R = E[:2, :2] / scale\n", + " tx, ty = E[0, 2], E[1, 2]\n", + " rz = np.arctan2(R[1, 0], R[0, 0])\n", + " r20, r21 = E[2, 0], E[2, 1]\n", + " ry = np.arctan2(r20, r21)\n", + " rx = np.arctan2(-E[1, 2], E[1, 1])\n", + " return np.array([rx, ry, rz, tx, ty, scale], dtype=np.float32)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "\n", + "\n", + "class YaGoDataset(Dataset):\n", + " def __init__(self, root_dir: str, transform=None, augment: bool = True, \n", + " image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):\n", + " self.root_dir = root_dir\n", + " self.transform = transform\n", + " self.augment = augment\n", + " self.image_size = image_size\n", + " self.cache_level = cache_level\n", + " self.K = get_camera_matrix(image_size[1], image_size[0])\n", + " self.image_pairs = self._discover_image_pairs()\n", + " self._load_images_to_memory()\n", + " self._init_cache()\n", + "\n", + " def _discover_image_pairs(self):\n", + " pairs = []\n", + " for f in os.listdir(self.root_dir):\n", + " if f.endswith(\"_google.png\"):\n", + " idx = f.split(\"_\")[0]\n", + " yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n", + " if os.path.exists(yandex_path):\n", + " pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n", + " return sorted(pairs, key=lambda x: x[\"idx\"])\n", + "\n", + " def _load_images_to_memory(self):\n", + " self._google_images = []\n", + " self._yandex_images = []\n", + " for pair in self.image_pairs:\n", + " google_img = cv2.imread(pair[\"google\"])\n", + " google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)\n", + " google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", + " \n", + " yandex_img = cv2.imread(pair[\"yandex\"])\n", + " yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)\n", + " yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n", + " \n", + " self._google_images.append(google_img)\n", + " self._yandex_images.append(yandex_img)\n", + "\n", + " def _init_cache(self):\n", + " self._access_counts = [0] * len(self.image_pairs)\n", + " self._cached_google = [None] * len(self.image_pairs)\n", + " self._cached_yandex = [None] * len(self.image_pairs)\n", + " self._cached_homography = [None] * len(self.image_pairs)\n", + "\n", + " def _generate_augmented(self, idx):\n", + " google_img = self._google_images[idx].copy()\n", + " yandex_img = self._yandex_images[idx].copy()\n", + "\n", + " params1 = generate_random_homography_params()\n", + " params2 = generate_random_homography_params()\n", + " H1 = homography_params_to_matrix(params1, self.K)\n", + " H2 = homography_params_to_matrix(params2, self.K)\n", + " H_combined = np.linalg.inv(H1) @ H2\n", + " \n", + " google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))\n", + " yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))\n", + " \n", + " target_params = matrix_to_homography_params(H_combined, self.K)\n", + " \n", + " return google_warped, yandex_warped, H_combined, target_params\n", + "\n", + " def __len__(self):\n", + " return len(self.image_pairs)\n", + "\n", + " def __getitem__(self, idx):\n", + " self._access_counts[idx] += 1\n", + " \n", + " use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0\n", + " \n", + " if use_cache:\n", + " google_img = self._cached_google[idx]\n", + " yandex_img = self._cached_yandex[idx]\n", + " target_matrix = self._cached_homography[idx]\n", + " target_params = matrix_to_homography_params(target_matrix, self.K)\n", + " elif self.augment:\n", + " google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n", + " if self.cache_level > 0:\n", + " self._cached_google[idx] = google_img\n", + " self._cached_yandex[idx] = yandex_img\n", + " self._cached_homography[idx] = target_matrix\n", + " else:\n", + " google_img = self._google_images[idx]\n", + " yandex_img = self._yandex_images[idx]\n", + " target_params = np.zeros(6, dtype=np.float32)\n", + " target_matrix = np.eye(3, dtype=np.float32)\n", + "\n", + " google_img = Image.fromarray(google_img)\n", + " yandex_img = Image.fromarray(yandex_img)\n", + "\n", + " if self.transform:\n", + " google_img = self.transform(google_img)\n", + " yandex_img = self.transform(yandex_img)\n", + "\n", + " return {\n", + " \"google_img\": google_img,\n", + " \"yandex_img\": yandex_img,\n", + " \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n", + " \"homography_params\": torch.from_numpy(target_params).float(),\n", + " }\n", + "\n", + "\n", + "def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0, \n", + " image_size=(256, 256), augment_train=True, cache_level=5):\n", + " transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", + " ])\n", + " \n", + " full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)\n", + " aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)\n", + "\n", + " indices = list(range(len(full_ds)))\n", + " random.shuffle(indices)\n", + " split = int(train_split * len(indices))\n", + " \n", + " train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])\n", + " val_ds = Subset(aug_ds, indices[split:])\n", + "\n", + " return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),\n", + " DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))\n", + "\n", + "\n", + "def get_dataset_info():\n", + " ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n", + " return {\n", + " \"size\": len(ds),\n", + " \"sample_keys\": list(ds[0].keys()),\n", + " \"sample_params\": ds[0][\"homography_params\"].numpy()\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "class HomographyCNN6(nn.Module):\n", + " def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n", + " super().__init__()\n", + " backbone = getattr(models, backbone_name)(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n", + " self.feature_dim = backbone.fc.in_features\n", + " backbone.fc = nn.Identity()\n", + " self.backbone = backbone\n", + "\n", + " self.head = nn.Sequential(\n", + " nn.Linear(self.feature_dim * 4, 512),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(512, 256),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + " nn.Linear(256, 9),\n", + " )\n", + "\n", + " def forward(self, img1, img2):\n", + " f1 = self.backbone(img1)\n", + " f2 = self.backbone(img2)\n", + " combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n", + " return self.head(combined)\n", + "\n", + " def decode_output(self, output):\n", + " tx, ty = output[:, 0], output[:, 1]\n", + " sin1, cos1 = torch.tanh(output[:, 2]), torch.tanh(output[:, 3])\n", + " sin2, cos2 = torch.tanh(output[:, 4]), torch.tanh(output[:, 5])\n", + " sin3, cos3 = torch.tanh(output[:, 6]), torch.tanh(output[:, 7])\n", + " scale = output[:, 8]\n", + "\n", + " angle1 = torch.atan2(sin1, cos1)\n", + " angle2 = torch.atan2(sin2, cos2)\n", + " angle3 = torch.atan2(sin3, cos3)\n", + "\n", + " return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n", + "\n", + "\n", + "class HomographyLoss6(nn.Module):\n", + " def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):\n", + " super().__init__()\n", + " self.criterion = nn.MSELoss()\n", + " self.angle_loss_weight = angle_loss_weight\n", + " self.trans_loss_weight = trans_loss_weight\n", + " self.scale_loss_weight = scale_loss_weight\n", + "\n", + " def forward(self, pred, target):\n", + " tx_loss = self.criterion(pred[:, 0], target[:, 0])\n", + " ty_loss = self.criterion(pred[:, 1], target[:, 1])\n", + "\n", + " sin1_pred, cos1_pred = pred[:, 2], pred[:, 3]\n", + " sin2_pred, cos2_pred = pred[:, 4], pred[:, 5]\n", + " sin3_pred, cos3_pred = pred[:, 6], pred[:, 7]\n", + "\n", + " sin1_target = torch.sin(target[:, 2])\n", + " cos1_target = torch.cos(target[:, 2])\n", + " sin2_target = torch.sin(target[:, 3])\n", + " cos2_target = torch.cos(target[:, 3])\n", + " sin3_target = torch.sin(target[:, 4])\n", + " cos3_target = torch.cos(target[:, 4])\n", + "\n", + " sin1_pred_t = torch.tanh(sin1_pred)\n", + " cos1_pred_t = torch.tanh(cos1_pred)\n", + " sin2_pred_t = torch.tanh(sin2_pred)\n", + " cos2_pred_t = torch.tanh(cos2_pred)\n", + " sin3_pred_t = torch.tanh(sin3_pred)\n", + " cos3_pred_t = torch.tanh(cos3_pred)\n", + " \n", + " angle1_loss = (1 - (sin1_pred_t * sin1_target + cos1_pred_t * cos1_target)).mean()\n", + " angle2_loss = (1 - (sin2_pred_t * sin2_target + cos2_pred_t * cos2_target)).mean()\n", + " angle3_loss = (1 - (sin3_pred_t * sin3_target + cos3_pred_t * cos3_target)).mean()\n", + "\n", + " scale_loss = self.criterion(pred[:, 8], target[:, 5])\n", + "\n", + " total_loss = (\n", + " self.trans_loss_weight * (tx_loss + ty_loss) +\n", + " self.angle_loss_weight * (angle1_loss + angle2_loss + angle3_loss) +\n", + " self.scale_loss_weight * scale_loss\n", + " )\n", + "\n", + " return total_loss\n", + "\n", + " def compute_mse_components(self, pred, target):\n", + " tx_mse = self.criterion(pred[:, 0], target[:, 0]).item()\n", + " ty_mse = self.criterion(pred[:, 1], target[:, 1]).item()\n", + "\n", + " sin1_target = torch.sin(target[:, 2])\n", + " cos1_target = torch.cos(target[:, 2])\n", + " sin2_target = torch.sin(target[:, 3])\n", + " cos2_target = torch.cos(target[:, 3])\n", + " sin3_target = torch.sin(target[:, 4])\n", + " cos3_target = torch.cos(target[:, 4])\n", + "\n", + " sin1_pred_t = torch.tanh(pred[:, 2])\n", + " cos1_pred_t = torch.tanh(pred[:, 3])\n", + " sin2_pred_t = torch.tanh(pred[:, 4])\n", + " cos2_pred_t = torch.tanh(pred[:, 5])\n", + " sin3_pred_t = torch.tanh(pred[:, 6])\n", + " cos3_pred_t = torch.tanh(pred[:, 7])\n", + " \n", + " angle1_loss = (1 - (sin1_pred_t * sin1_target + cos1_pred_t * cos1_target)).mean().item()\n", + " angle2_loss = (1 - (sin2_pred_t * sin2_target + cos2_pred_t * cos2_target)).mean().item()\n", + " angle3_loss = (1 - (sin3_pred_t * sin3_target + cos3_pred_t * cos3_target)).mean().item()\n", + "\n", + " scale_mse = self.criterion(pred[:, 8], target[:, 5]).item()\n", + "\n", + " avg_angle_loss = (angle1_loss + angle2_loss + angle3_loss) / 3\n", + "\n", + " return {\n", + " 'trans': (tx_mse + ty_mse) / 2,\n", + " 'angle': avg_angle_loss,\n", + " 'scale': scale_mse\n", + " }\n", + "\n", + "\n", + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters())\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "\n", + "\n", + "class HomographyTrainer:\n", + " def __init__(self, model, train_loader, val_loader, device):\n", + " self.model = model.to(device)\n", + " self.train_loader = train_loader\n", + " self.val_loader = val_loader\n", + " self.device = device\n", + " self.criterion = HomographyLoss6()\n", + " self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"])\n", + " self.writer = None\n", + " self.best_val_loss = float(\"inf\")\n", + " self.train_losses = []\n", + " self.val_losses = []\n", + " self.train_mse_trans = []\n", + " self.train_mse_angle = []\n", + " self.train_mse_scale = []\n", + " self.val_mse_trans = []\n", + " self.val_mse_angle = []\n", + " self.val_mse_scale = []\n", + "\n", + " def train_epoch(self, epoch):\n", + " self.model.train()\n", + " total_loss, total_samples = 0, 0\n", + " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", + " pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n", + " for batch_idx, batch in enumerate(pbar):\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target = batch[\"homography_params\"].to(self.device)\n", + "\n", + " self.optimizer.zero_grad()\n", + " output = self.model(google_img, yandex_img)\n", + " loss = self.criterion(output, target)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + "\n", + " total_loss += loss.item() * google_img.size(0)\n", + " total_samples += google_img.size(0)\n", + " \n", + " mse_components = self.criterion.compute_mse_components(output, target)\n", + " mse_trans_sum += mse_components['trans'] * google_img.size(0)\n", + " mse_angle_sum += mse_components['angle'] * google_img.size(0)\n", + " mse_scale_sum += mse_components['scale'] * google_img.size(0)\n", + " \n", + " pbar.set_postfix({\"loss\": loss.item()})\n", + "\n", + " self.train_mse_trans.append(mse_trans_sum / total_samples)\n", + " self.train_mse_angle.append(mse_angle_sum / total_samples)\n", + " self.train_mse_scale.append(mse_scale_sum / total_samples)\n", + " \n", + " return {\"loss\": total_loss / total_samples}\n", + "\n", + " def validate(self):\n", + " self.model.eval()\n", + " total_loss, total_samples = 0, 0\n", + " mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n", + " with torch.no_grad():\n", + " for batch in tqdm(self.val_loader, desc=\"Validation\"):\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target = batch[\"homography_params\"].to(self.device)\n", + " output = self.model(google_img, yandex_img)\n", + " loss = self.criterion(output, target)\n", + " total_loss += loss.item() * google_img.size(0)\n", + " total_samples += google_img.size(0)\n", + " \n", + " mse_components = self.criterion.compute_mse_components(output, target)\n", + " mse_trans_sum += mse_components['trans'] * google_img.size(0)\n", + " mse_angle_sum += mse_components['angle'] * google_img.size(0)\n", + " mse_scale_sum += mse_components['scale'] * google_img.size(0)\n", + " \n", + " self.val_mse_trans.append(mse_trans_sum / total_samples)\n", + " self.val_mse_angle.append(mse_angle_sum / total_samples)\n", + " self.val_mse_scale.append(mse_scale_sum / total_samples)\n", + " \n", + " return {\"loss\": total_loss / total_samples}\n", + "\n", + " def train(self, num_epochs):\n", + " log_dir = config[\"output_dir\"]\n", + " os.makedirs(log_dir, exist_ok=True)\n", + " self.writer = SummaryWriter(log_dir)\n", + "\n", + " for epoch in range(1, num_epochs + 1):\n", + " train_metrics = self.train_epoch(epoch)\n", + " val_metrics = self.validate()\n", + " self.train_losses.append(train_metrics[\"loss\"])\n", + " self.val_losses.append(val_metrics[\"loss\"])\n", + " print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n", + " print(f\" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}\")\n", + "\n", + " if val_metrics[\"loss\"] < self.best_val_loss:\n", + " self.best_val_loss = val_metrics[\"loss\"]\n", + " self.save_checkpoint(epoch, is_best=True)\n", + " print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n", + "\n", + " if epoch % config[\"save_every_n_epochs\"] == 0:\n", + " self.save_checkpoint(epoch, is_best=False)\n", + " print(f\"Checkpoint saved at epoch {epoch}\")\n", + "\n", + " self.writer.close()\n", + "\n", + " def save_checkpoint(self, epoch, is_best=False):\n", + " ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n", + " os.makedirs(ckpt_dir, exist_ok=True)\n", + " ckpt = {\"epoch\": epoch, \"model_state_dict\": self.model.state_dict(), \"val_loss\": self.best_val_loss}\n", + " torch.save(ckpt, os.path.join(ckpt_dir, f\"checkpoint_epoch_{epoch}.pt\"))\n", + " if is_best:\n", + " torch.save(ckpt, os.path.join(ckpt_dir, \"best_model.pt\"))\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n", + "os.makedirs(IMG_DIR, exist_ok=True)\n", + "\n", + "\n", + "def angular_difference(pred_angles, target_angles):\n", + " diff = pred_angles - target_angles\n", + " diff = torch.atan2(torch.sin(diff), torch.cos(diff))\n", + " return torch.abs(diff)\n", + "\n", + "\n", + "def analyze_training(trainer):\n", + " print(\"=== Training Analysis ===\\n\")\n", + "\n", + " if trainer.writer:\n", + " print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n", + "\n", + " print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n", + "\n", + " best_model_path = os.path.join(config[\"output_dir\"], \"checkpoints\", \"best_model.pt\")\n", + " if os.path.exists(best_model_path):\n", + " checkpoint = torch.load(best_model_path, map_location=trainer.device)\n", + " trainer.model.load_state_dict(checkpoint[\"model_state_dict\"])\n", + " print(f\"\\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})\")\n", + " \n", + " trainer.model.eval()\n", + " \n", + " n_samples = 50\n", + " names = [\"tx\", \"ty\", \"rx\", \"ry\", \"rz\", \"scale\"]\n", + " \n", + " with torch.no_grad():\n", + " all_errors = [[] for _ in range(6)]\n", + " all_targets = [[] for _ in range(6)]\n", + " all_preds = [[] for _ in range(6)]\n", + " \n", + " for i in range(n_samples):\n", + " try:\n", + " batch = next(iter(trainer.val_loader))\n", + " except StopIteration:\n", + " break\n", + " google_img = batch[\"google_img\"].to(trainer.device)\n", + " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", + " target_params = batch[\"homography_params\"].to(trainer.device)\n", + " pred_params = trainer.model(google_img, yandex_img)\n", + " decoded_pred = trainer.model.decode_output(pred_params)\n", + " \n", + " tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).item()\n", + " ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).item()\n", + " rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).item()\n", + " ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).item()\n", + " rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).item()\n", + " scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).item()\n", + " \n", + " errors = [tx_error, ty_error, rx_error, ry_error, rz_error, scale_error]\n", + " targets = target_params[0].cpu().numpy()\n", + " preds = decoded_pred[0].cpu().numpy()\n", + " \n", + " for j in range(6):\n", + " all_errors[j].append(errors[j])\n", + " all_targets[j].append(targets[j])\n", + " all_preds[j].append(preds[j])\n", + " \n", + " mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n", + " std_errors = [np.std(all_errors[i]) for i in range(6)]\n", + " \n", + " if len(trainer.train_losses) > 0:\n", + " epochs = range(1, len(trainer.train_losses) + 1)\n", + " fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", + " \n", + " axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n", + " axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", + " axes[0, 0].set_xlabel(\"Epoch\")\n", + " axes[0, 0].set_ylabel(\"Loss\")\n", + " axes[0, 0].set_title(\"Training & Validation Loss\")\n", + " axes[0, 0].legend()\n", + " axes[0, 0].grid(True, alpha=0.3)\n", + " \n", + " axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", + " axes[0, 1].set_xlabel(\"Epoch\")\n", + " axes[0, 1].set_ylabel(\"Loss\")\n", + " axes[0, 1].set_title(\"Validation Loss\")\n", + " axes[0, 1].legend()\n", + " axes[0, 1].grid(True, alpha=0.3)\n", + " \n", + " axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n", + " axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n", + " axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n", + " axes[1, 0].set_xlabel(\"Epoch\")\n", + " axes[1, 0].set_ylabel(\"MSE\")\n", + " axes[1, 0].set_title(\"Validation MSE by Category\")\n", + " axes[1, 0].legend()\n", + " axes[1, 0].grid(True, alpha=0.3)\n", + " \n", + " x_pos = np.arange(6)\n", + " axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", + " axes[1, 1].set_xticks(x_pos)\n", + " axes[1, 1].set_xticklabels(names)\n", + " axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n", + " axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\")\n", + " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150)\n", + " print(\"Saved training_loss_plots.png\")\n", + " plt.show()\n", + " \n", + " fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", + " for j in range(6):\n", + " row = j // 3\n", + " col = j % 3\n", + " axes[row, col].bar(range(n_samples), all_errors[j], color=\"steelblue\", alpha=0.7)\n", + " axes[row, col].set_xlabel(\"Sample\")\n", + " axes[row, col].set_ylabel(\"Absolute Error\")\n", + " axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\")\n", + " axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n", + " plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14)\n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150)\n", + " print(\"Saved mae_per_parameter.png\")\n", + " plt.show()\n", + " \n", + " fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n", + " \n", + " x_pos = np.arange(6)\n", + " axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", + " axes[0].set_xticks(x_pos)\n", + " axes[0].set_xticklabels(names)\n", + " axes[0].set_ylabel(\"Mean Absolute Error\")\n", + " axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\")\n", + " axes[0].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n", + " colors = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n", + " for patch, color in zip(bp[\"boxes\"], colors):\n", + " patch.set_facecolor(color)\n", + " patch.set_alpha(0.7)\n", + " axes[1].set_ylabel(\"Absolute Error\")\n", + " axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\")\n", + " axes[1].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150)\n", + " print(\"Saved mae_boxplot.png\")\n", + " plt.show()\n", + " \n", + " print(\"\\n=== Sample Predictions (20 pairs) ===\")\n", + " n_vis_samples = 20\n", + " \n", + " with torch.no_grad():\n", + " for sample_idx in range(n_vis_samples):\n", + " try:\n", + " batch = next(iter(trainer.val_loader))\n", + " except StopIteration:\n", + " break\n", + " google_img = batch[\"google_img\"].to(trainer.device)\n", + " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", + " target_params = batch[\"homography_params\"].to(trainer.device)\n", + " pred_params = trainer.model(google_img, yandex_img)\n", + " decoded_pred = trainer.model.decode_output(pred_params)\n", + " \n", + " tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).cpu().numpy()\n", + " ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).cpu().numpy()\n", + " rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).cpu().numpy()\n", + " ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).cpu().numpy()\n", + " rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).cpu().numpy()\n", + " scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).cpu().numpy()\n", + " \n", + " errors = np.array([tx_error[0], ty_error[0], rx_error[0], ry_error[0], rz_error[0], scale_error[0]])\n", + " targets = target_params[0].cpu().numpy()\n", + " preds = decoded_pred[0].cpu().numpy()\n", + " \n", + " fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n", + " \n", + " axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n", + " axes[0, 0].set_title(f\"Google Image\")\n", + " axes[0, 0].axis(\"off\")\n", + " \n", + " axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n", + " axes[0, 1].set_title(f\"Yandex Image\")\n", + " axes[0, 1].axis(\"off\")\n", + " \n", + " x_pos = np.arange(6)\n", + " width = 0.35\n", + " axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n", + " axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n", + " axes[1, 0].set_xticks(x_pos)\n", + " axes[1, 0].set_xticklabels(names)\n", + " axes[1, 0].set_ylabel(\"Parameter Value\")\n", + " axes[1, 0].set_title(\"Target vs Predicted\")\n", + " axes[1, 0].legend()\n", + " axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n", + " \n", + " axes[1, 1].bar(x_pos, errors, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", + " axes[1, 1].set_xticks(x_pos)\n", + " axes[1, 1].set_xticklabels(names)\n", + " axes[1, 1].set_ylabel(\"Absolute Error\")\n", + " axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\")\n", + " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", + " for i, e in enumerate(errors):\n", + " axes[1, 1].text(i, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n", + " \n", + " plt.suptitle(f\"Sample {sample_idx + 1}\", fontsize=14)\n", + " plt.tight_layout()\n", + " plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{sample_idx + 1:02d}.png\"), dpi=100)\n", + " plt.show()\n", + " print(f\"Saved prediction_sample_{sample_idx + 1:02d}.png\")\n", + " \n", + " print(f\"\\nPrediction errors over {n_samples} samples:\")\n", + " print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n", + " print(\"-\" * 52)\n", + " for i in range(6):\n", + " mean_err = np.mean(all_errors[i])\n", + " std_err = np.std(all_errors[i])\n", + " min_err = np.min(all_errors[i])\n", + " max_err = np.max(all_errors[i])\n", + " print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}\")\n", + "\n", + " return {\n", + " \"best_val_loss\": trainer.best_val_loss,\n", + " \"train_losses\": trainer.train_losses,\n", + " \"val_losses\": trainer.val_losses,\n", + " \"val_mse_trans\": trainer.val_mse_trans,\n", + " \"val_mse_angle\": trainer.val_mse_angle,\n", + " \"val_mse_scale\": trainer.val_mse_scale,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "\n", + "\n", + "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "logger.info(\"=\" * 50)\n", + "logger.info(\"SiaN Training Pipeline\")\n", + "logger.info(\"=\" * 50)\n", + "\n", + "dataset_info = get_dataset_info()\n", + "logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n", + "\n", + "train_loader, val_loader = create_data_loaders(\n", + " root_dir=config[\"data_dir\"],\n", + " batch_size=config[\"batch_size\"],\n", + " train_split=config[\"train_split\"],\n", + " num_workers=config[\"num_workers\"],\n", + " image_size=config[\"image_size\"],\n", + ")\n", + "logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n", + "\n", + "model = HomographyCNN6(\n", + " input_channels=3,\n", + " backbone_name=config[\"backbone\"],\n", + " pretrained=True,\n", + " dropout_rate=config[\"dropout_rate\"]\n", + ")\n", + "logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "logger.info(f\"Using device: {device}\")\n", + "\n", + "trainer = HomographyTrainer(model, train_loader, val_loader, device)\n", + "logger.info(\"Starting training...\")\n", + "trainer.train(config[\"epochs\"])\n", + "logger.info(\"Training completed\")\n", + "\n", + "logger.info(\"Analyzing model...\")\n", + "results = analyze_training(trainer)\n", + "logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n", + "\n", + "logger.info(\"=\" * 50)\n", + "logger.info(\"Pipeline completed successfully\")\n", + "logger.info(\"=\" * 50)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/models/SiaN/src/analyze.py b/models/SiaN/src/analyze.py index 75816d3..3dedd62 100644 --- a/models/SiaN/src/analyze.py +++ b/models/SiaN/src/analyze.py @@ -9,6 +9,12 @@ IMG_DIR = os.path.join(config["output_dir"], "images") os.makedirs(IMG_DIR, exist_ok=True) +def angular_difference(pred_angles, target_angles): + diff = pred_angles - target_angles + diff = torch.atan2(torch.sin(diff), torch.cos(diff)) + return torch.abs(diff) + + def analyze_training(trainer): print("=== Training Analysis ===\n") @@ -26,7 +32,7 @@ def analyze_training(trainer): trainer.model.eval() n_samples = 50 - names = ["rx", "ry", "rz", "tx", "ty", "scale"] + names = ["tx", "ty", "rx", "ry", "rz", "scale"] with torch.no_grad(): all_errors = [[] for _ in range(6)] @@ -42,11 +48,23 @@ def analyze_training(trainer): yandex_img = batch["yandex_img"].to(trainer.device) target_params = batch["homography_params"].to(trainer.device) pred_params = trainer.model(google_img, yandex_img) + decoded_pred = trainer.model.decode_output(pred_params) + + tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).item() + ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).item() + rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).item() + ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).item() + rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).item() + scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).item() + + errors = [tx_error, ty_error, rx_error, ry_error, rz_error, scale_error] + targets = target_params[0].cpu().numpy() + preds = decoded_pred[0].cpu().numpy() for j in range(6): - all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item()) - all_targets[j].append(target_params[0, j].item()) - all_preds[j].append(pred_params[0, j].item()) + all_errors[j].append(errors[j]) + all_targets[j].append(targets[j]) + all_preds[j].append(preds[j]) mean_errors = [np.mean(all_errors[i]) for i in range(6)] std_errors = [np.std(all_errors[i]) for i in range(6)] @@ -144,10 +162,18 @@ def analyze_training(trainer): yandex_img = batch["yandex_img"].to(trainer.device) target_params = batch["homography_params"].to(trainer.device) pred_params = trainer.model(google_img, yandex_img) + decoded_pred = trainer.model.decode_output(pred_params) - errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy() + tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).cpu().numpy() + ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).cpu().numpy() + rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).cpu().numpy() + ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).cpu().numpy() + rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).cpu().numpy() + scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).cpu().numpy() + + errors = np.array([tx_error[0], ty_error[0], rx_error[0], ry_error[0], rz_error[0], scale_error[0]]) targets = target_params[0].cpu().numpy() - preds = pred_params[0].cpu().numpy() + preds = decoded_pred[0].cpu().numpy() fig, axes = plt.subplots(2, 2, figsize=(12, 10)) diff --git a/models/SiaN/src/model.py b/models/SiaN/src/model.py index b2ee78a..3711a92 100644 --- a/models/SiaN/src/model.py +++ b/models/SiaN/src/model.py @@ -18,7 +18,7 @@ class HomographyCNN6(nn.Module): nn.Linear(512, 256), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), - nn.Linear(256, 6), + nn.Linear(256, 9), ) def forward(self, img1, img2): @@ -27,14 +27,95 @@ class HomographyCNN6(nn.Module): combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1) return self.head(combined) + def decode_output(self, output): + tx, ty = output[:, 0], output[:, 1] + sin1, cos1 = torch.tanh(output[:, 2]), torch.tanh(output[:, 3]) + sin2, cos2 = torch.tanh(output[:, 4]), torch.tanh(output[:, 5]) + sin3, cos3 = torch.tanh(output[:, 6]), torch.tanh(output[:, 7]) + scale = output[:, 8] + + angle1 = torch.atan2(sin1, cos1) + angle2 = torch.atan2(sin2, cos2) + angle3 = torch.atan2(sin3, cos3) + + return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1) + class HomographyLoss6(nn.Module): - def __init__(self): + def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0): super().__init__() self.criterion = nn.MSELoss() + self.angle_loss_weight = angle_loss_weight + self.trans_loss_weight = trans_loss_weight + self.scale_loss_weight = scale_loss_weight def forward(self, pred, target): - return self.criterion(pred, target) + tx_loss = self.criterion(pred[:, 0], target[:, 0]) + ty_loss = self.criterion(pred[:, 1], target[:, 1]) + + sin1_pred, cos1_pred = pred[:, 2], pred[:, 3] + sin2_pred, cos2_pred = pred[:, 4], pred[:, 5] + sin3_pred, cos3_pred = pred[:, 6], pred[:, 7] + + sin1_target = torch.sin(target[:, 2]) + cos1_target = torch.cos(target[:, 2]) + sin2_target = torch.sin(target[:, 3]) + cos2_target = torch.cos(target[:, 3]) + sin3_target = torch.sin(target[:, 4]) + cos3_target = torch.cos(target[:, 4]) + + sin1_pred_t = torch.tanh(sin1_pred) + cos1_pred_t = torch.tanh(cos1_pred) + sin2_pred_t = torch.tanh(sin2_pred) + cos2_pred_t = torch.tanh(cos2_pred) + sin3_pred_t = torch.tanh(sin3_pred) + cos3_pred_t = torch.tanh(cos3_pred) + + angle1_loss = (1 - (sin1_pred_t * sin1_target + cos1_pred_t * cos1_target)).mean() + angle2_loss = (1 - (sin2_pred_t * sin2_target + cos2_pred_t * cos2_target)).mean() + angle3_loss = (1 - (sin3_pred_t * sin3_target + cos3_pred_t * cos3_target)).mean() + + scale_loss = self.criterion(pred[:, 8], target[:, 5]) + + total_loss = ( + self.trans_loss_weight * (tx_loss + ty_loss) + + self.angle_loss_weight * (angle1_loss + angle2_loss + angle3_loss) + + self.scale_loss_weight * scale_loss + ) + + return total_loss + + def compute_mse_components(self, pred, target): + tx_mse = self.criterion(pred[:, 0], target[:, 0]).item() + ty_mse = self.criterion(pred[:, 1], target[:, 1]).item() + + sin1_target = torch.sin(target[:, 2]) + cos1_target = torch.cos(target[:, 2]) + sin2_target = torch.sin(target[:, 3]) + cos2_target = torch.cos(target[:, 3]) + sin3_target = torch.sin(target[:, 4]) + cos3_target = torch.cos(target[:, 4]) + + sin1_pred_t = torch.tanh(pred[:, 2]) + cos1_pred_t = torch.tanh(pred[:, 3]) + sin2_pred_t = torch.tanh(pred[:, 4]) + cos2_pred_t = torch.tanh(pred[:, 5]) + sin3_pred_t = torch.tanh(pred[:, 6]) + cos3_pred_t = torch.tanh(pred[:, 7]) + + angle1_loss = (1 - (sin1_pred_t * sin1_target + cos1_pred_t * cos1_target)).mean().item() + angle2_loss = (1 - (sin2_pred_t * sin2_target + cos2_pred_t * cos2_target)).mean().item() + angle3_loss = (1 - (sin3_pred_t * sin3_target + cos3_pred_t * cos3_target)).mean().item() + + scale_mse = self.criterion(pred[:, 8], target[:, 5]).item() + + avg_angle_loss = (angle1_loss + angle2_loss + angle3_loss) / 3 + + return { + 'trans': (tx_mse + ty_mse) / 2, + 'angle': avg_angle_loss, + 'scale': scale_mse + } def count_parameters(model): diff --git a/models/SiaN/src/train.py b/models/SiaN/src/train.py index 94929f7..290b3b0 100644 --- a/models/SiaN/src/train.py +++ b/models/SiaN/src/train.py @@ -49,9 +49,10 @@ class HomographyTrainer: total_loss += loss.item() * google_img.size(0) total_samples += google_img.size(0) - mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0) - mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0) - mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0) + mse_components = self.criterion.compute_mse_components(output, target) + mse_trans_sum += mse_components['trans'] * google_img.size(0) + mse_angle_sum += mse_components['angle'] * google_img.size(0) + mse_scale_sum += mse_components['scale'] * google_img.size(0) pbar.set_postfix({"loss": loss.item()}) @@ -75,9 +76,10 @@ class HomographyTrainer: total_loss += loss.item() * google_img.size(0) total_samples += google_img.size(0) - mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0) - mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0) - mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0) + mse_components = self.criterion.compute_mse_components(output, target) + mse_trans_sum += mse_components['trans'] * google_img.size(0) + mse_angle_sum += mse_components['angle'] * google_img.size(0) + mse_scale_sum += mse_components['scale'] * google_img.size(0) self.val_mse_trans.append(mse_trans_sum / total_samples) self.val_mse_angle.append(mse_angle_sum / total_samples)