{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import logging\n",
    "from typing import Tuple\n",
    "import cv2\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from torch.utils.data import DataLoader, Dataset, Subset\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from torchvision import transforms, models\n",
    "from tqdm import tqdm\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n- Random seed (applied at the start of the main pipeline)\n\nThe helpers below define the homography parameterisation `[rx, ry, rz, tx, ty, scale]` used both for augmentation and as the regression target.\n"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n",
    "    \"image_size\": (256, 256),  # (height, width)\n",
    "    \"batch_size\": 32,\n",
    "    \"train_split\": 0.8,\n",
    "    \"num_workers\": 0,\n",
    "    \"epochs\": 100,\n",
    "    \"learning_rate\": 2e-4,\n",
    "    \"dropout_rate\": 0.3,\n",
    "    \"backbone\": \"resnet18\",\n",
    "    \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n",
    "    \"save_every_n_epochs\": 15,\n",
    "    \"seed\": 42,\n",
    "}\n",
    "\n",
    "\n",
    "def get_camera_matrix(w, h):\n",
    "    \"\"\"Pinhole intrinsics K for a w x h image: focal lengths (w/2, h/2), principal point at the centre.\"\"\"\n",
    "    return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]], dtype=np.float32)\n",
    "\n",
    "\n",
    "def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n",
    "    \"\"\"Sample [rx, ry, rz, tx, ty, scale] for homography_params_to_matrix.\n",
    "\n",
    "    Angles are uniform in +/-angle_range degrees (returned in radians), tx and ty uniform in\n",
    "    +/-translation_range, scale uniform in scale_range. Draws come from NumPy's global RNG,\n",
    "    so np.random.seed makes them reproducible.\n",
    "    \"\"\"\n",
    "    scale = np.random.uniform(*scale_range)\n",
    "    tx = np.random.uniform(-translation_range, translation_range)\n",
    "    ty = np.random.uniform(-translation_range, translation_range)\n",
    "    rx = np.radians(np.random.uniform(-angle_range, angle_range))\n",
    "    ry = np.radians(np.random.uniform(-angle_range, angle_range))\n",
    "    rz = np.radians(np.random.uniform(-angle_range, angle_range))\n",
    "    return np.array([rx, ry, rz, tx, ty, scale])\n",
    "\n",
    "\n",
    "def homography_params_to_matrix(params, K):\n",
    "    \"\"\"Pixel-space homography H = K @ Rx(rx) @ Ry(ry) @ Rz(rz) @ T @ K^-1.\n",
    "\n",
    "    params is [rx, ry, rz, tx, ty, scale] with angles in radians;\n",
    "    T = [[1, 0, tx], [0, 1, ty], [0, 0, scale]] acts in normalised camera coordinates.\n",
    "    \"\"\"\n",
    "    rx, ry, rz, tx, ty, scale = params\n",
    "    cz, sz = np.cos(rz), np.sin(rz)\n",
    "    cy, sy = np.cos(ry), np.sin(ry)\n",
    "    cx, sx = np.cos(rx), np.sin(rx)\n",
    "    Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=np.float32)\n",
    "    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float32)\n",
    "    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=np.float32)\n",
    "    T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]], dtype=np.float32)\n",
    "    return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)\n",
    "\n",
    "\n",
    "def matrix_to_homography_params(H, K):\n",
    "    \"\"\"Inverse of homography_params_to_matrix: recover [rx, ry, rz, tx, ty, scale] from H.\n",
    "\n",
    "    E = K^-1 @ H @ K is modelled as lam * R @ T with R = Rx @ Ry @ Rz. T leaves the first two\n",
    "    basis vectors unchanged, so E[:, :2] = lam * R[:, :2]: the closest matrix with orthonormal\n",
    "    columns (SVD polar factor) gives R[:, :2], their cross product gives R[:, 2], and\n",
    "    T = R^T @ E / lam. For H built by homography_params_to_matrix (lam = 1) this is an exact\n",
    "    inverse; for other homographies (e.g. inv(H1) @ H2 of two sampled ones) it is a\n",
    "    least-squares projection onto the parameter family.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the first two columns of E are (numerically) zero.\n",
    "    \"\"\"\n",
    "    K64 = np.asarray(K, dtype=np.float64)\n",
    "    E = np.linalg.inv(K64) @ np.asarray(H, dtype=np.float64) @ K64\n",
    "    U, S, Vt = np.linalg.svd(E[:, :2], full_matrices=False)\n",
    "    lam = S.mean()\n",
    "    if lam < 1e-12:\n",
    "        raise ValueError(\"Degenerate homography: cannot recover the rotation\")\n",
    "    cols = U @ Vt\n",
    "    R = np.column_stack([cols[:, 0], cols[:, 1], np.cross(cols[:, 0], cols[:, 1])])\n",
    "    T = R.T @ E / lam\n",
    "    tx, ty, scale = T[0, 2], T[1, 2], T[2, 2]\n",
    "    # For R = Rx @ Ry @ Rz: R[0, 2] = sin(ry), R[1, 2] = -sin(rx)cos(ry), R[2, 2] = cos(rx)cos(ry),\n",
    "    # R[0, 0] = cos(ry)cos(rz), R[0, 1] = -cos(ry)sin(rz).\n",
    "    rx = np.arctan2(-R[1, 2], R[2, 2])\n",
    "    ry = np.arctan2(R[0, 2], np.hypot(R[0, 0], R[0, 1]))\n",
    "    rz = np.arctan2(-R[0, 1], R[0, 0])\n",
    "    return np.array([rx, ry, rz, tx, ty, scale], dtype=np.float32)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired `<idx>_google.png` / `<idx>_yandex.png` images into memory\n- Warps each image of a pair with its own random homography; the target is the relative homography between the two warps\n- Supports a configurable train/val split; validation warps are sampled once and then kept fixed\n\n**Returns** (per sample, as a dict):\n- `google_img`, `yandex_img` — the two images (tensors when a `transform` is set)\n- `homography_matrix` — 3x3 relative homography `inv(H1) @ H2`\n- `homography_params` — its parameters `[rx, ry, rz, tx, ty, scale]` (identity `[0, 0, 0, 0, 0, 1]` without augmentation)\n"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class YaGoDataset(Dataset):\n",
    "    \"\"\"Paired Google/Yandex map images with on-the-fly homography augmentation.\n",
    "\n",
    "    Pairs are discovered as ``<idx>_google.png`` / ``<idx>_yandex.png`` in ``root_dir``, converted to RGB,\n",
    "    resized to ``image_size`` (height, width) and kept in memory.\n",
    "\n",
    "    With ``augment=True`` the two images of a pair are warped by independent random homographies H2\n",
    "    (Google) and H1 (Yandex); the target is ``inv(H1) @ H2`` and its parameters. A generated sample is\n",
    "    reused until ``cache_level`` accesses have passed (``cache_level <= 0`` disables caching). With\n",
    "    ``fixed_augmentation=True`` the first augmentation of each index is reused forever, which keeps a\n",
    "    validation set identical across epochs. With ``augment=False`` the unwarped images and an identity\n",
    "    target are returned.\n",
    "\n",
    "    NOTE(review): access counts and caches live on this instance; with DataLoader num_workers > 0 each\n",
    "    worker presumably gets its own copy, so caching may not carry across epochs there -- confirm.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root_dir: str, transform=None, augment: bool = True,\n",
    "                 image_size: Tuple[int, int] = (256, 256), cache_level: int = 5,\n",
    "                 fixed_augmentation: bool = False):\n",
    "        self.root_dir = root_dir\n",
    "        self.transform = transform\n",
    "        self.augment = augment\n",
    "        self.image_size = image_size\n",
    "        self.cache_level = cache_level\n",
    "        self.fixed_augmentation = fixed_augmentation\n",
    "        self.K = get_camera_matrix(image_size[1], image_size[0])\n",
    "        self.image_pairs = self._discover_image_pairs()\n",
    "        self._load_images_to_memory()\n",
    "        self._init_cache()\n",
    "\n",
    "    def _discover_image_pairs(self):\n",
    "        \"\"\"Return {idx, google, yandex} dicts sorted by idx; Google files without a Yandex partner are skipped.\"\"\"\n",
    "        pairs = []\n",
    "        for name in os.listdir(self.root_dir):\n",
    "            if not name.endswith(\"_google.png\"):\n",
    "                continue\n",
    "            idx = name.split(\"_\")[0]\n",
    "            yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n",
    "            if os.path.exists(yandex_path):\n",
    "                pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, name), \"yandex\": yandex_path})\n",
    "        return sorted(pairs, key=lambda pair: pair[\"idx\"])\n",
    "\n",
    "    def _read_rgb(self, path):\n",
    "        \"\"\"Read ``path`` as an RGB uint8 array resized to ``image_size``.\n",
    "\n",
    "        Raises:\n",
    "            FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None).\n",
    "        \"\"\"\n",
    "        img = cv2.imread(path)\n",
    "        if img is None:\n",
    "            raise FileNotFoundError(f\"Could not read image: {path}\")\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "        # cv2.resize takes (width, height) while image_size is (height, width).\n",
    "        return cv2.resize(img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n",
    "\n",
    "    def _load_images_to_memory(self):\n",
    "        self._google_images = [self._read_rgb(pair[\"google\"]) for pair in self.image_pairs]\n",
    "        self._yandex_images = [self._read_rgb(pair[\"yandex\"]) for pair in self.image_pairs]\n",
    "\n",
    "    def _init_cache(self):\n",
    "        n = len(self.image_pairs)\n",
    "        self._access_counts = [0] * n\n",
    "        self._cached_google = [None] * n\n",
    "        self._cached_yandex = [None] * n\n",
    "        self._cached_homography = [None] * n\n",
    "        self._cached_params = [None] * n\n",
    "\n",
    "    def _generate_augmented(self, idx):\n",
    "        \"\"\"Warp pair ``idx``: Google by H2, Yandex by H1; return (google, yandex, inv(H1) @ H2, its params).\"\"\"\n",
    "        H1 = homography_params_to_matrix(generate_random_homography_params(), self.K)\n",
    "        H2 = homography_params_to_matrix(generate_random_homography_params(), self.K)\n",
    "        H_combined = np.linalg.inv(H1) @ H2\n",
    "        dsize = (self.image_size[1], self.image_size[0])\n",
    "        google_warped = cv2.warpPerspective(self._google_images[idx], H2, dsize)\n",
    "        yandex_warped = cv2.warpPerspective(self._yandex_images[idx], H1, dsize)\n",
    "        return google_warped, yandex_warped, H_combined, matrix_to_homography_params(H_combined, self.K)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_pairs)\n",
    "\n",
    "    def _use_cached(self, idx):\n",
    "        \"\"\"Whether the current access to ``idx`` should reuse its cached augmentation.\"\"\"\n",
    "        if not self.augment or self._cached_homography[idx] is None:\n",
    "            return False\n",
    "        if self.fixed_augmentation:\n",
    "            return True\n",
    "        # Regenerate on accesses 1, 1 + cache_level, 1 + 2 * cache_level, ...\n",
    "        return self.cache_level > 0 and (self._access_counts[idx] - 1) % self.cache_level != 0\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        self._access_counts[idx] += 1\n",
    "\n",
    "        if self._use_cached(idx):\n",
    "            google_img = self._cached_google[idx]\n",
    "            yandex_img = self._cached_yandex[idx]\n",
    "            target_matrix = self._cached_homography[idx]\n",
    "            target_params = self._cached_params[idx]\n",
    "        elif self.augment:\n",
    "            google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n",
    "            if self.cache_level > 0 or self.fixed_augmentation:\n",
    "                self._cached_google[idx] = google_img\n",
    "                self._cached_yandex[idx] = yandex_img\n",
    "                self._cached_homography[idx] = target_matrix\n",
    "                self._cached_params[idx] = target_params\n",
    "        else:\n",
    "            google_img = self._google_images[idx]\n",
    "            yandex_img = self._yandex_images[idx]\n",
    "            # Identity homography in [rx, ry, rz, tx, ty, scale] form: scale is 1, not 0.\n",
    "            target_params = np.array([0, 0, 0, 0, 0, 1], dtype=np.float32)\n",
    "            target_matrix = np.eye(3, dtype=np.float32)\n",
    "\n",
    "        google_img = Image.fromarray(google_img)\n",
    "        yandex_img = Image.fromarray(yandex_img)\n",
    "        if self.transform:\n",
    "            google_img = self.transform(google_img)\n",
    "            yandex_img = self.transform(yandex_img)\n",
    "\n",
    "        return {\n",
    "            \"google_img\": google_img,\n",
    "            \"yandex_img\": yandex_img,\n",
    "            \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n",
    "            \"homography_params\": torch.from_numpy(target_params).float(),\n",
    "        }\n",
    "\n",
    "\n",
    "def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,\n",
    "                        image_size=(256, 256), augment_train=True, cache_level=5):\n",
    "    \"\"\"Build (train_loader, val_loader) over a shuffled index split of the pairs in ``root_dir``.\n",
    "\n",
    "    Training samples are augmented when ``augment_train`` is True. Validation samples are always\n",
    "    augmented, but with ``fixed_augmentation=True`` so every epoch is scored on the same warps.\n",
    "    The split uses Python's global ``random`` module; seed it for a reproducible split.\n",
    "    \"\"\"\n",
    "    # Images stay in [0, 1] (no ImageNet Normalize); the analysis plots show them directly with imshow.\n",
    "    transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "    train_source = YaGoDataset(root_dir, transform=transform, augment=augment_train,\n",
    "                               image_size=image_size, cache_level=cache_level)\n",
    "    val_source = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size,\n",
    "                             cache_level=cache_level, fixed_augmentation=True)\n",
    "\n",
    "    indices = list(range(len(train_source)))\n",
    "    random.shuffle(indices)\n",
    "    split = int(train_split * len(indices))\n",
    "\n",
    "    pin_memory = torch.cuda.is_available()  # pinning only helps host-to-GPU copies\n",
    "    train_loader = DataLoader(Subset(train_source, indices[:split]), batch_size=batch_size, shuffle=True,\n",
    "                              num_workers=num_workers, pin_memory=pin_memory)\n",
    "    val_loader = DataLoader(Subset(val_source, indices[split:]), batch_size=batch_size, shuffle=False,\n",
    "                            num_workers=num_workers, pin_memory=pin_memory)\n",
    "    return train_loader, val_loader\n",
    "\n",
    "\n",
    "def get_dataset_info():\n",
    "    \"\"\"Summarise the configured dataset: number of pairs, sample keys and one sample's target params.\"\"\"\n",
    "    ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n",
    "    sample = ds[0]\n",
    "    return {\n",
    "        \"size\": len(ds),\n",
    "        \"sample_keys\": list(sample.keys()),\n",
    "        \"sample_params\": sample[\"homography_params\"].numpy(),\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## Model\n\n`HomographyCNN6` — Siamese CNN for homography estimation.\n\n**Raw output:** 9 values `[tx, ty, sin_rx, cos_rx, sin_ry, cos_ry, sin_rz, cos_rz, scale]`.\n\n**Decoded output** (`decode_output`): 6 parameters, in this order\n- `tx, ty` — translation offsets\n- `rx, ry, rz` — rotation angles (radians), from `atan2` of the `tanh`-squashed sin/cos pairs\n- `scale` — isotropic scale factor\n\nDataset targets use the order `[rx, ry, rz, tx, ty, scale]`; `HomographyLoss6` maps between the two layouts.\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images) with a shared backbone (torchvision ResNet: resnet18/34/50)\n- Features fused as `[f1, f2, |f1 - f2|, f1 * f2]` and regressed by an MLP head\n"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HomographyCNN6(nn.Module):\n",
    "    \"\"\"Siamese CNN that regresses the homography between a (google, yandex) image pair.\n",
    "\n",
    "    forward returns 9 raw columns [tx, ty, sin_rx, cos_rx, sin_ry, cos_ry, sin_rz, cos_rz, scale];\n",
    "    decode_output converts them to [tx, ty, rx, ry, rz, scale].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n",
    "        super().__init__()\n",
    "        # input_channels is accepted for API compatibility only; the torchvision ResNet stem takes RGB.\n",
    "        # Weights are requested by name so every ResNet builder (resnet18/34/50, ...) accepts them;\n",
    "        # a ResNet18_Weights enum is rejected by the other builders.\n",
    "        backbone = getattr(models, backbone_name)(weights=\"IMAGENET1K_V1\" if pretrained else None)\n",
    "        self.feature_dim = backbone.fc.in_features\n",
    "        backbone.fc = nn.Identity()\n",
    "        self.backbone = backbone\n",
    "\n",
    "        # Fused features are [f1, f2, |f1 - f2|, f1 * f2], hence 4 * feature_dim inputs.\n",
    "        self.head = nn.Sequential(\n",
    "            nn.Linear(self.feature_dim * 4, 512),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Dropout(dropout_rate),\n",
    "            nn.Linear(512, 256),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Dropout(dropout_rate),\n",
    "            nn.Linear(256, 9),\n",
    "        )\n",
    "\n",
    "    def forward(self, img1, img2):\n",
    "        \"\"\"Encode both images with the shared backbone and regress the 9 raw outputs.\"\"\"\n",
    "        f1 = self.backbone(img1)\n",
    "        f2 = self.backbone(img2)\n",
    "        combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
    "        return self.head(combined)\n",
    "\n",
    "    def decode_output(self, output):\n",
    "        \"\"\"Map raw (N, 9) output to (N, 6) [tx, ty, rx, ry, rz, scale] with angles in radians.\n",
    "\n",
    "        Note the order differs from dataset targets ([rx, ry, rz, tx, ty, scale]).\n",
    "        \"\"\"\n",
    "        tx, ty = output[:, 0], output[:, 1]\n",
    "        angles = [\n",
    "            torch.atan2(torch.tanh(output[:, 2 + 2 * k]), torch.tanh(output[:, 3 + 2 * k]))\n",
    "            for k in range(3)\n",
    "        ]\n",
    "        scale = output[:, 8]\n",
    "        return torch.stack([tx, ty, *angles, scale], dim=1)\n",
    "\n",
    "\n",
    "class HomographyLoss6(nn.Module):\n",
    "    \"\"\"Loss between the raw 9-column HomographyCNN6 output and 6-column dataset targets.\n",
    "\n",
    "    pred:   [tx, ty, sin_rx, cos_rx, sin_ry, cos_ry, sin_rz, cos_rz, scale] (raw network output)\n",
    "    target: [rx, ry, rz, tx, ty, scale] (order produced by matrix_to_homography_params)\n",
    "\n",
    "    Translation and scale use MSE. Each angle uses 1 - cos(pred_angle - target_angle): the predicted\n",
    "    (sin, cos) pair is squashed with tanh (as in decode_output) and normalised to unit length, so the\n",
    "    minimum is at the target angle.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):\n",
    "        super().__init__()\n",
    "        self.criterion = nn.MSELoss()\n",
    "        self.angle_loss_weight = angle_loss_weight\n",
    "        self.trans_loss_weight = trans_loss_weight\n",
    "        self.scale_loss_weight = scale_loss_weight\n",
    "\n",
    "    @staticmethod\n",
    "    def _angle_loss(sin_raw, cos_raw, target_angle, eps=1e-8):\n",
    "        \"\"\"Mean of 1 - cos(pred - target) for raw (sin, cos) outputs and target angles in radians.\"\"\"\n",
    "        sin_p, cos_p = torch.tanh(sin_raw), torch.tanh(cos_raw)\n",
    "        # Without this normalisation the dot product is maximised at the (+-1, +-1) corners,\n",
    "        # i.e. at +-45 / +-135 degrees, instead of at the target angle. eps keeps the gradient finite.\n",
    "        norm = torch.sqrt(sin_p ** 2 + cos_p ** 2 + eps)\n",
    "        return (1 - (sin_p * torch.sin(target_angle) + cos_p * torch.cos(target_angle)) / norm).mean()\n",
    "\n",
    "    def _components(self, pred, target):\n",
    "        \"\"\"Return (tx_mse, ty_mse, [rx, ry, rz angle losses], scale_mse) as tensors.\"\"\"\n",
    "        tx_loss = self.criterion(pred[:, 0], target[:, 3])\n",
    "        ty_loss = self.criterion(pred[:, 1], target[:, 4])\n",
    "        # pred columns (2, 3), (4, 5), (6, 7) hold (sin, cos) for rx, ry, rz = target columns 0, 1, 2.\n",
    "        angle_losses = [\n",
    "            self._angle_loss(pred[:, 2 + 2 * k], pred[:, 3 + 2 * k], target[:, k]) for k in range(3)\n",
    "        ]\n",
    "        scale_loss = self.criterion(pred[:, 8], target[:, 5])\n",
    "        return tx_loss, ty_loss, angle_losses, scale_loss\n",
    "\n",
    "    def forward(self, pred, target):\n",
    "        \"\"\"Weighted sum of translation MSE, the three angle losses and scale MSE.\"\"\"\n",
    "        tx_loss, ty_loss, angle_losses, scale_loss = self._components(pred, target)\n",
    "        return (\n",
    "            self.trans_loss_weight * (tx_loss + ty_loss)\n",
    "            + self.angle_loss_weight * sum(angle_losses)\n",
    "            + self.scale_loss_weight * scale_loss\n",
    "        )\n",
    "\n",
    "    def compute_mse_components(self, pred, target):\n",
    "        \"\"\"Reporting metrics as floats: mean translation MSE, mean angle loss and scale MSE.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            tx_loss, ty_loss, angle_losses, scale_loss = self._components(pred, target)\n",
    "        return {\n",
    "            'trans': (tx_loss.item() + ty_loss.item()) / 2,\n",
    "            'angle': sum(loss.item() for loss in angle_losses) / 3,\n",
    "            'scale': scale_loss.item(),\n",
    "        }\n",
    "\n",
    "\n",
    "def count_parameters(model):\n",
    "    \"\"\"Total number of parameters (trainable or not) in ``model``.\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Per-epoch losses and loss components logged to TensorBoard\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving** (in `<output_dir>/checkpoints/`):\n- `best_model.pt` — lowest validation loss\n- `checkpoint_epoch_<N>.pt` — periodic snapshots\n\nEach checkpoint stores the epoch, the model and optimizer state, and that epoch's validation loss.\n"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HomographyTrainer:\n",
    "    \"\"\"Training loop with HomographyLoss6, Adam, TensorBoard logging and checkpointing.\n",
    "\n",
    "    Reads learning_rate, output_dir and save_every_n_epochs from the global ``config``.\n",
    "    Per-epoch history is kept in train_losses / val_losses and the train_/val_mse_* lists\n",
    "    (translation MSE, mean angle loss and scale MSE from HomographyLoss6.compute_mse_components).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, train_loader, val_loader, device):\n",
    "        self.model = model.to(device)\n",
    "        self.train_loader = train_loader\n",
    "        self.val_loader = val_loader\n",
    "        self.device = device\n",
    "        self.criterion = HomographyLoss6()\n",
    "        self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"])\n",
    "        self.writer = None\n",
    "        self.best_val_loss = float(\"inf\")\n",
    "        self.train_losses = []\n",
    "        self.val_losses = []\n",
    "        self.train_mse_trans = []\n",
    "        self.train_mse_angle = []\n",
    "        self.train_mse_scale = []\n",
    "        self.val_mse_trans = []\n",
    "        self.val_mse_angle = []\n",
    "        self.val_mse_scale = []\n",
    "\n",
    "    def _run_epoch(self, loader, training, desc):\n",
    "        \"\"\"Run one pass over ``loader``, optimising when ``training`` is True.\n",
    "\n",
    "        Returns:\n",
    "            (mean loss, dict of mean 'trans' / 'angle' / 'scale' components), weighted by batch size.\n",
    "        Raises:\n",
    "            ValueError: if ``loader`` yields no samples.\n",
    "        \"\"\"\n",
    "        self.model.train(training)\n",
    "        total_loss, total_samples = 0.0, 0\n",
    "        component_sums = {\"trans\": 0.0, \"angle\": 0.0, \"scale\": 0.0}\n",
    "        pbar = tqdm(loader, desc=desc)\n",
    "        with torch.set_grad_enabled(training):\n",
    "            for batch in pbar:\n",
    "                google_img = batch[\"google_img\"].to(self.device)\n",
    "                yandex_img = batch[\"yandex_img\"].to(self.device)\n",
    "                target = batch[\"homography_params\"].to(self.device)\n",
    "\n",
    "                output = self.model(google_img, yandex_img)\n",
    "                loss = self.criterion(output, target)\n",
    "                if training:\n",
    "                    self.optimizer.zero_grad()\n",
    "                    loss.backward()\n",
    "                    self.optimizer.step()\n",
    "                    pbar.set_postfix({\"loss\": loss.item()})\n",
    "\n",
    "                batch_size = google_img.size(0)\n",
    "                total_loss += loss.item() * batch_size\n",
    "                total_samples += batch_size\n",
    "                # Components are for reporting only, so keep them out of the autograd graph.\n",
    "                for key, value in self.criterion.compute_mse_components(output.detach(), target).items():\n",
    "                    component_sums[key] += value * batch_size\n",
    "\n",
    "        if total_samples == 0:\n",
    "            raise ValueError(f\"{desc}: data loader yielded no samples\")\n",
    "        return total_loss / total_samples, {k: v / total_samples for k, v in component_sums.items()}\n",
    "\n",
    "    def train_epoch(self, epoch):\n",
    "        \"\"\"Train for one epoch and record its components; returns {\"loss\": mean training loss}.\"\"\"\n",
    "        loss, components = self._run_epoch(self.train_loader, training=True, desc=f\"Epoch {epoch}\")\n",
    "        self.train_mse_trans.append(components[\"trans\"])\n",
    "        self.train_mse_angle.append(components[\"angle\"])\n",
    "        self.train_mse_scale.append(components[\"scale\"])\n",
    "        return {\"loss\": loss}\n",
    "\n",
    "    def validate(self):\n",
    "        \"\"\"Evaluate on the validation loader and record its components; returns {\"loss\": mean loss}.\"\"\"\n",
    "        loss, components = self._run_epoch(self.val_loader, training=False, desc=\"Validation\")\n",
    "        self.val_mse_trans.append(components[\"trans\"])\n",
    "        self.val_mse_angle.append(components[\"angle\"])\n",
    "        self.val_mse_scale.append(components[\"scale\"])\n",
    "        return {\"loss\": loss}\n",
    "\n",
    "    def _log_epoch(self, epoch):\n",
    "        \"\"\"Write the latest losses and loss components to TensorBoard.\"\"\"\n",
    "        self.writer.add_scalar(\"loss/train\", self.train_losses[-1], epoch)\n",
    "        self.writer.add_scalar(\"loss/val\", self.val_losses[-1], epoch)\n",
    "        history = {\n",
    "            \"trans\": (self.train_mse_trans, self.val_mse_trans),\n",
    "            \"angle\": (self.train_mse_angle, self.val_mse_angle),\n",
    "            \"scale\": (self.train_mse_scale, self.val_mse_scale),\n",
    "        }\n",
    "        for name, (train_values, val_values) in history.items():\n",
    "            self.writer.add_scalar(f\"{name}/train\", train_values[-1], epoch)\n",
    "            self.writer.add_scalar(f\"{name}/val\", val_values[-1], epoch)\n",
    "\n",
    "    def train(self, num_epochs):\n",
    "        \"\"\"Run ``num_epochs`` epochs, saving the best model and periodic checkpoints.\"\"\"\n",
    "        log_dir = config[\"output_dir\"]\n",
    "        os.makedirs(log_dir, exist_ok=True)\n",
    "        self.writer = SummaryWriter(log_dir)\n",
    "        try:\n",
    "            for epoch in range(1, num_epochs + 1):\n",
    "                train_metrics = self.train_epoch(epoch)\n",
    "                val_metrics = self.validate()\n",
    "                self.train_losses.append(train_metrics[\"loss\"])\n",
    "                self.val_losses.append(val_metrics[\"loss\"])\n",
    "                self._log_epoch(epoch)\n",
    "                print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n",
    "                print(f\" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}\")\n",
    "\n",
    "                if val_metrics[\"loss\"] < self.best_val_loss:\n",
    "                    self.best_val_loss = val_metrics[\"loss\"]\n",
    "                    self.save_checkpoint(epoch, is_best=True, val_loss=val_metrics[\"loss\"])\n",
    "                    print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n",
    "\n",
    "                if epoch % config[\"save_every_n_epochs\"] == 0:\n",
    "                    self.save_checkpoint(epoch, is_best=False, val_loss=val_metrics[\"loss\"])\n",
    "                    print(f\"Checkpoint saved at epoch {epoch}\")\n",
    "        finally:\n",
    "            # Close even if an epoch raises so the event files are flushed.\n",
    "            self.writer.close()\n",
    "\n",
    "    def save_checkpoint(self, epoch, is_best=False, val_loss=None):\n",
    "        \"\"\"Save to best_model.pt (``is_best``) or checkpoint_epoch_<epoch>.pt under output_dir/checkpoints.\n",
    "\n",
    "        ``val_loss`` is stored in the checkpoint; it defaults to the best validation loss so far.\n",
    "        \"\"\"\n",
    "        ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n",
    "        os.makedirs(ckpt_dir, exist_ok=True)\n",
    "        ckpt = {\n",
    "            \"epoch\": epoch,\n",
    "            \"model_state_dict\": self.model.state_dict(),\n",
    "            \"optimizer_state_dict\": self.optimizer.state_dict(),\n",
    "            \"val_loss\": self.best_val_loss if val_loss is None else val_loss,\n",
    "        }\n",
    "        filename = \"best_model.pt\" if is_best else f\"checkpoint_epoch_{epoch}.pt\"\n",
    "        torch.save(ckpt, os.path.join(ckpt_dir, filename))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Per-parameter absolute-error statistics on the validation set\n- Prediction visualization on sample images\n"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n",
    "os.makedirs(IMG_DIR, exist_ok=True)\n",
    "\n",
    "\n",
    "def angular_difference(pred_angles, target_angles):\n",
    "    \"\"\"Absolute angle difference wrapped into [0, pi] (e.g. 359 deg vs 1 deg gives 2 deg).\"\"\"\n",
    "    diff = pred_angles - target_angles\n",
    "    return torch.abs(torch.atan2(torch.sin(diff), torch.cos(diff)))\n",
    "\n",
    "\n",
"def analyze_training(trainer):\n", " print(\"=== Training Analysis ===\\n\")\n", "\n", " if trainer.writer:\n", " print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n", "\n", " print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n", "\n", " best_model_path = os.path.join(config[\"output_dir\"], \"checkpoints\", \"best_model.pt\")\n", " if os.path.exists(best_model_path):\n", " checkpoint = torch.load(best_model_path, map_location=trainer.device)\n", " trainer.model.load_state_dict(checkpoint[\"model_state_dict\"])\n", " print(f\"\\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})\")\n", " \n", " trainer.model.eval()\n", " \n", " n_samples = 50\n", " names = [\"tx\", \"ty\", \"rx\", \"ry\", \"rz\", \"scale\"]\n", " \n", " with torch.no_grad():\n", " all_errors = [[] for _ in range(6)]\n", " all_targets = [[] for _ in range(6)]\n", " all_preds = [[] for _ in range(6)]\n", " \n", " for i in range(n_samples):\n", " try:\n", " batch = next(iter(trainer.val_loader))\n", " except StopIteration:\n", " break\n", " google_img = batch[\"google_img\"].to(trainer.device)\n", " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", " target_params = batch[\"homography_params\"].to(trainer.device)\n", " pred_params = trainer.model(google_img, yandex_img)\n", " decoded_pred = trainer.model.decode_output(pred_params)\n", " \n", " tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).item()\n", " ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).item()\n", " rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).item()\n", " ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).item()\n", " rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).item()\n", " scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).item()\n", " \n", " errors = [tx_error, ty_error, rx_error, ry_error, rz_error, scale_error]\n", " targets = 
target_params[0].cpu().numpy()\n", " preds = decoded_pred[0].cpu().numpy()\n", " \n", " for j in range(6):\n", " all_errors[j].append(errors[j])\n", " all_targets[j].append(targets[j])\n", " all_preds[j].append(preds[j])\n", " \n", " mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n", " std_errors = [np.std(all_errors[i]) for i in range(6)]\n", " \n", " if len(trainer.train_losses) > 0:\n", " epochs = range(1, len(trainer.train_losses) + 1)\n", " fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", " \n", " axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n", " axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", " axes[0, 0].set_xlabel(\"Epoch\")\n", " axes[0, 0].set_ylabel(\"Loss\")\n", " axes[0, 0].set_title(\"Training & Validation Loss\")\n", " axes[0, 0].legend()\n", " axes[0, 0].grid(True, alpha=0.3)\n", " \n", " axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n", " axes[0, 1].set_xlabel(\"Epoch\")\n", " axes[0, 1].set_ylabel(\"Loss\")\n", " axes[0, 1].set_title(\"Validation Loss\")\n", " axes[0, 1].legend()\n", " axes[0, 1].grid(True, alpha=0.3)\n", " \n", " axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n", " axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n", " axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n", " axes[1, 0].set_xlabel(\"Epoch\")\n", " axes[1, 0].set_ylabel(\"MSE\")\n", " axes[1, 0].set_title(\"Validation MSE by Category\")\n", " axes[1, 0].legend()\n", " axes[1, 0].grid(True, alpha=0.3)\n", " \n", " x_pos = np.arange(6)\n", " axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", " axes[1, 1].set_xticks(x_pos)\n", " axes[1, 1].set_xticklabels(names)\n", " axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n", " axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} 
samples)\")\n", " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", " \n", " plt.tight_layout()\n", " plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150)\n", " print(\"Saved training_loss_plots.png\")\n", " plt.show()\n", " \n", " fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", " for j in range(6):\n", " row = j // 3\n", " col = j % 3\n", " axes[row, col].bar(range(n_samples), all_errors[j], color=\"steelblue\", alpha=0.7)\n", " axes[row, col].set_xlabel(\"Sample\")\n", " axes[row, col].set_ylabel(\"Absolute Error\")\n", " axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\")\n", " axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n", " plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14)\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150)\n", " print(\"Saved mae_per_parameter.png\")\n", " plt.show()\n", " \n", " fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n", " \n", " x_pos = np.arange(6)\n", " axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", " axes[0].set_xticks(x_pos)\n", " axes[0].set_xticklabels(names)\n", " axes[0].set_ylabel(\"Mean Absolute Error\")\n", " axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\")\n", " axes[0].grid(True, alpha=0.3, axis=\"y\")\n", " \n", " bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n", " colors = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n", " for patch, color in zip(bp[\"boxes\"], colors):\n", " patch.set_facecolor(color)\n", " patch.set_alpha(0.7)\n", " axes[1].set_ylabel(\"Absolute Error\")\n", " axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\")\n", " axes[1].grid(True, alpha=0.3, axis=\"y\")\n", " \n", " plt.tight_layout()\n", " plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150)\n", 
" print(\"Saved mae_boxplot.png\")\n", " plt.show()\n", " \n", " print(\"\\n=== Sample Predictions (20 pairs) ===\")\n", " n_vis_samples = 20\n", " \n", " with torch.no_grad():\n", " for sample_idx in range(n_vis_samples):\n", " try:\n", " batch = next(iter(trainer.val_loader))\n", " except StopIteration:\n", " break\n", " google_img = batch[\"google_img\"].to(trainer.device)\n", " yandex_img = batch[\"yandex_img\"].to(trainer.device)\n", " target_params = batch[\"homography_params\"].to(trainer.device)\n", " pred_params = trainer.model(google_img, yandex_img)\n", " decoded_pred = trainer.model.decode_output(pred_params)\n", " \n", " tx_error = torch.abs(decoded_pred[:, 0] - target_params[:, 0]).cpu().numpy()\n", " ty_error = torch.abs(decoded_pred[:, 1] - target_params[:, 1]).cpu().numpy()\n", " rx_error = angular_difference(decoded_pred[:, 2], target_params[:, 2]).cpu().numpy()\n", " ry_error = angular_difference(decoded_pred[:, 3], target_params[:, 3]).cpu().numpy()\n", " rz_error = angular_difference(decoded_pred[:, 4], target_params[:, 4]).cpu().numpy()\n", " scale_error = torch.abs(decoded_pred[:, 5] - target_params[:, 5]).cpu().numpy()\n", " \n", " errors = np.array([tx_error[0], ty_error[0], rx_error[0], ry_error[0], rz_error[0], scale_error[0]])\n", " targets = target_params[0].cpu().numpy()\n", " preds = decoded_pred[0].cpu().numpy()\n", " \n", " fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n", " \n", " axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n", " axes[0, 0].set_title(f\"Google Image\")\n", " axes[0, 0].axis(\"off\")\n", " \n", " axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n", " axes[0, 1].set_title(f\"Yandex Image\")\n", " axes[0, 1].axis(\"off\")\n", " \n", " x_pos = np.arange(6)\n", " width = 0.35\n", " axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n", " axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n", " axes[1, 
0].set_xticks(x_pos)\n", " axes[1, 0].set_xticklabels(names)\n", " axes[1, 0].set_ylabel(\"Parameter Value\")\n", " axes[1, 0].set_title(\"Target vs Predicted\")\n", " axes[1, 0].legend()\n", " axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n", " \n", " axes[1, 1].bar(x_pos, errors, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n", " axes[1, 1].set_xticks(x_pos)\n", " axes[1, 1].set_xticklabels(names)\n", " axes[1, 1].set_ylabel(\"Absolute Error\")\n", " axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\")\n", " axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n", " for i, e in enumerate(errors):\n", " axes[1, 1].text(i, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n", " \n", " plt.suptitle(f\"Sample {sample_idx + 1}\", fontsize=14)\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{sample_idx + 1:02d}.png\"), dpi=100)\n", " plt.show()\n", " print(f\"Saved prediction_sample_{sample_idx + 1:02d}.png\")\n", " \n", " print(f\"\\nPrediction errors over {n_samples} samples:\")\n", " print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n", " print(\"-\" * 52)\n", " for i in range(6):\n", " mean_err = np.mean(all_errors[i])\n", " std_err = np.std(all_errors[i])\n", " min_err = np.min(all_errors[i])\n", " max_err = np.max(all_errors[i])\n", " print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}\")\n", "\n", " return {\n", " \"best_val_loss\": trainer.best_val_loss,\n", " \"train_losses\": trainer.train_losses,\n", " \"val_losses\": trainer.val_losses,\n", " \"val_mse_trans\": trainer.val_mse_trans,\n", " \"val_mse_angle\": trainer.val_mse_angle,\n", " \"val_mse_scale\": trainer.val_mse_scale,\n", " }\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. 
Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "\n", "\n", "\n", "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n", "logger = logging.getLogger(__name__)\n", "\n", "logger.info(\"=\" * 50)\n", "logger.info(\"SiaN Training Pipeline\")\n", "logger.info(\"=\" * 50)\n", "\n", "dataset_info = get_dataset_info()\n", "logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n", "\n", "train_loader, val_loader = create_data_loaders(\n", " root_dir=config[\"data_dir\"],\n", " batch_size=config[\"batch_size\"],\n", " train_split=config[\"train_split\"],\n", " num_workers=config[\"num_workers\"],\n", " image_size=config[\"image_size\"],\n", ")\n", "logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n", "\n", "model = HomographyCNN6(\n", " input_channels=3,\n", " backbone_name=config[\"backbone\"],\n", " pretrained=True,\n", " dropout_rate=config[\"dropout_rate\"]\n", ")\n", "logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "logger.info(f\"Using device: {device}\")\n", "\n", "trainer = HomographyTrainer(model, train_loader, val_loader, device)\n", "logger.info(\"Starting training...\")\n", "trainer.train(config[\"epochs\"])\n", "logger.info(\"Training completed\")\n", "\n", "logger.info(\"Analyzing model...\")\n", "results = analyze_training(trainer)\n", "logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n", "\n", "logger.info(\"=\" * 50)\n", "logger.info(\"Pipeline completed successfully\")\n", "logger.info(\"=\" * 50)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!zip 
artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }