|
|
|
|
@@ -2,7 +2,8 @@
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"execution_count": 35,
|
|
|
|
|
"id": "4230e9cd",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
@@ -17,6 +18,7 @@
|
|
|
|
|
"import torch.optim as optim\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"from PIL import Image\n",
|
|
|
|
|
"import seaborn as sns\n",
|
|
|
|
|
"from torch.utils.data import DataLoader, Dataset, Subset\n",
|
|
|
|
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
|
|
|
|
"from torchvision import transforms, models\n",
|
|
|
|
|
@@ -25,12 +27,21 @@
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "d41ef314",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n"
|
|
|
|
|
"source": [
|
|
|
|
|
"# Configuration\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"Global settings for:\n",
|
|
|
|
|
"- Data paths and image parameters\n",
|
|
|
|
|
"- Training hyperparameters\n",
|
|
|
|
|
"- Model architecture options\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"execution_count": 36,
|
|
|
|
|
"id": "4463d9d3",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
@@ -96,12 +107,25 @@
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "e5a40be4",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n"
|
|
|
|
|
"source": [
|
|
|
|
|
"## Dataset\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"Google/Yandex image pair loader with homography augmentation.\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"**Features:**\n",
|
|
|
|
|
"- Loads paired images from dual camera sources\n",
|
|
|
|
|
"- Applies random homography transformations\n",
|
|
|
|
|
"- Supports configurable train/val split\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"**Returns:**\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"execution_count": 37,
|
|
|
|
|
"id": "37358bfe",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
@@ -242,9 +266,39 @@
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"execution_count": 38,
|
|
|
|
|
"id": "9fee48b8",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"tensor([[ 1.0661e+00, -4.5570e-02, -2.1649e+01],\n",
|
|
|
|
|
" [ 3.7228e-02, 9.1285e-01, 1.3237e+01],\n",
|
|
|
|
|
" [ 5.3873e-04, -6.4843e-04, 9.6602e-01]])\n",
|
|
|
|
|
"tensor([-0.0379, 0.0183, -0.0856, -0.0661, -0.0375, 0.9617])\n",
|
|
|
|
|
"[[ 1.0660727e+00 -4.5569740e-02 -2.1648926e+01]\n",
|
|
|
|
|
" [ 3.7227791e-02 9.1284692e-01 1.3237175e+01]\n",
|
|
|
|
|
" [ 5.3873385e-04 -6.4843398e-04 9.6602309e-01]]\n",
|
|
|
|
|
"[ 0. 0. -0.08696619 -0.0720376 -0.03181122 0.9519815 ]\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"c:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
|
|
|
|
|
" super().__init__(loader)\n",
|
|
|
|
|
"C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:32: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n",
|
|
|
|
|
" cy, sy = np.cos(rz), np.sin(rz)\n",
|
|
|
|
|
"C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:33: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n",
|
|
|
|
|
" cp, sp = np.cos(ry), np.sin(ry)\n",
|
|
|
|
|
"C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:34: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n",
|
|
|
|
|
" cr, sr = np.cos(rx), np.sin(rx)\n"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"\n",
|
|
|
|
|
"train_loader, val_loader = create_data_loaders(config['data_dir'])\n",
|
|
|
|
|
@@ -265,12 +319,27 @@
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "e8072ee6",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n"
|
|
|
|
|
"source": [
|
|
|
|
|
"## Model\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"`HomographyCNN6` — CNN architecture for homography estimation.\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"**Output:** 6 parameters\n",
|
|
|
|
|
"- `rx, ry, rz` — rotation angles (radians)\n",
|
|
|
|
|
"- `tx, ty` — translation offsets\n",
|
|
|
|
|
"- `scale` — isotropic scale factor\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"**Architecture:**\n",
|
|
|
|
|
"- Dual-branch CNN (Google + Yandex images)\n",
|
|
|
|
|
"- Shared backbone (configurable: resnet18/34/50)\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"execution_count": 39,
|
|
|
|
|
"id": "c246531d",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
@@ -302,11 +371,19 @@
|
|
|
|
|
" nn.Dropout(dropout_rate),\n",
|
|
|
|
|
" nn.Linear(256, 6),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" self._init_weights()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def _normalize_sin_cos(self, _sin, _cos):\n",
|
|
|
|
|
" _len = torch.sqrt(_sin ** 2 + _cos ** 2)\n",
|
|
|
|
|
" return _sin / _len, _cos / _len\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def _init_weights(self):\n",
|
|
|
|
|
" for module in self.head.modules():\n",
|
|
|
|
|
" if isinstance(module, nn.Linear):\n",
|
|
|
|
|
" nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n",
|
|
|
|
|
" if module.bias is not None:\n",
|
|
|
|
|
" nn.init.zeros_(module.bias)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def forward(self, img1, img2):\n",
|
|
|
|
|
" f1 = self.backbone(img1)\n",
|
|
|
|
|
" f2 = self.backbone(img2)\n",
|
|
|
|
|
@@ -342,6 +419,126 @@
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"class HomographyHybridCNN(nn.Module):\n",
|
|
|
|
|
" def __init__(self, input_channels=3, use_resnet_layers=2, dropout_rate=0.3):\n",
|
|
|
|
|
" super().__init__()\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if use_resnet_layers == 1:\n",
|
|
|
|
|
" resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
|
|
|
|
|
" self.conv1 = resnet.conv1\n",
|
|
|
|
|
" self.bn1 = resnet.bn1\n",
|
|
|
|
|
" self.relu = resnet.relu\n",
|
|
|
|
|
" self.maxpool = resnet.maxpool\n",
|
|
|
|
|
" conv_out_channels = 64\n",
|
|
|
|
|
" elif use_resnet_layers == 2:\n",
|
|
|
|
|
" resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
|
|
|
|
|
" self.conv1 = resnet.conv1\n",
|
|
|
|
|
" self.bn1 = resnet.bn1\n",
|
|
|
|
|
" self.relu = resnet.relu\n",
|
|
|
|
|
" self.maxpool = resnet.maxpool\n",
|
|
|
|
|
" self.conv2 = resnet.layer1[0].conv1\n",
|
|
|
|
|
" self.bn2 = resnet.layer1[0].bn1\n",
|
|
|
|
|
" self.conv2_2 = resnet.layer1[0].conv2\n",
|
|
|
|
|
" self.bn2_2 = resnet.layer1[0].bn2\n",
|
|
|
|
|
" self.relu2 = resnet.layer1[0].relu\n",
|
|
|
|
|
" self.maxpool2 = resnet.maxpool\n",
|
|
|
|
|
" conv_out_channels = 64\n",
|
|
|
|
|
" else:\n",
|
|
|
|
|
" raise ValueError(\"use_resnet_layers must be 1 or 2\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.use_resnet_layers = use_resnet_layers\n",
|
|
|
|
|
" self.feature_map_size = 64\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.conv_head = nn.Sequential(\n",
|
|
|
|
|
" nn.Conv2d(conv_out_channels, 128, kernel_size=3, padding=1),\n",
|
|
|
|
|
" nn.BatchNorm2d(128),\n",
|
|
|
|
|
" nn.ReLU(inplace=True),\n",
|
|
|
|
|
" nn.Conv2d(128, 256, kernel_size=3, padding=1),\n",
|
|
|
|
|
" nn.BatchNorm2d(256),\n",
|
|
|
|
|
" nn.ReLU(inplace=True),\n",
|
|
|
|
|
" nn.MaxPool2d(2),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.global_pool = nn.AdaptiveAvgPool2d((1, 1))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" feature_dim = 256 * 4\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.head = nn.Sequential(\n",
|
|
|
|
|
" nn.Linear(feature_dim, 1024),\n",
|
|
|
|
|
" nn.ReLU(inplace=True),\n",
|
|
|
|
|
" nn.Dropout(dropout_rate),\n",
|
|
|
|
|
" nn.Linear(1024, 512),\n",
|
|
|
|
|
" nn.ReLU(inplace=True),\n",
|
|
|
|
|
" nn.Dropout(dropout_rate),\n",
|
|
|
|
|
" nn.Linear(512, 256),\n",
|
|
|
|
|
" nn.ReLU(inplace=True),\n",
|
|
|
|
|
" nn.Dropout(dropout_rate),\n",
|
|
|
|
|
" nn.Linear(256, 6),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" self._init_weights()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def _init_weights(self):\n",
|
|
|
|
|
" for module in self.head.modules():\n",
|
|
|
|
|
" if isinstance(module, nn.Linear):\n",
|
|
|
|
|
" nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n",
|
|
|
|
|
" if module.bias is not None:\n",
|
|
|
|
|
" nn.init.zeros_(module.bias)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def forward(self, img1, img2):\n",
|
|
|
|
|
" x1 = self._extract_features(img1)\n",
|
|
|
|
|
" x2 = self._extract_features(img2)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" combined = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], dim=1)\n",
|
|
|
|
|
" output = self.head(combined)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" output = torch.tanh(output)\n",
|
|
|
|
|
" modified = output.clone()\n",
|
|
|
|
|
" modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" return modified\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def _extract_features(self, x):\n",
|
|
|
|
|
" x = self.conv1(x)\n",
|
|
|
|
|
" x = self.bn1(x)\n",
|
|
|
|
|
" x = self.relu(x)\n",
|
|
|
|
|
" x = self.maxpool(x)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if self.use_resnet_layers >= 2:\n",
|
|
|
|
|
" x = self.conv2(x)\n",
|
|
|
|
|
" x = self.bn2(x)\n",
|
|
|
|
|
" x = self.relu(x)\n",
|
|
|
|
|
" x = self.conv2_2(x)\n",
|
|
|
|
|
" x = self.bn2_2(x)\n",
|
|
|
|
|
" x = self.relu2(x)\n",
|
|
|
|
|
" x = self.maxpool2(x)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" x = self.conv_head(x)\n",
|
|
|
|
|
" x = self.global_pool(x)\n",
|
|
|
|
|
" x = x.view(x.size(0), -1)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" return x\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def decode_output(self, output):\n",
|
|
|
|
|
" tx = output[:, 0]\n",
|
|
|
|
|
" ty = output[:, 1]\n",
|
|
|
|
|
" scale = output[:, 5]\n",
|
|
|
|
|
" angle1 = output[:, 2]\n",
|
|
|
|
|
" angle2 = output[:, 3]\n",
|
|
|
|
|
" angle3 = output[:, 4]\n",
|
|
|
|
|
" return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def get_components(self, output):\n",
|
|
|
|
|
" decoded = self.decode_output(output)\n",
|
|
|
|
|
" return {\n",
|
|
|
|
|
" \"tx\": decoded[:, 0],\n",
|
|
|
|
|
" \"ty\": decoded[:, 1],\n",
|
|
|
|
|
" \"rx\": decoded[:, 2],\n",
|
|
|
|
|
" \"ry\": decoded[:, 3],\n",
|
|
|
|
|
" \"rz\": decoded[:, 4],\n",
|
|
|
|
|
" \"scale\": decoded[:, 5],\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"class HomographyLoss6(nn.Module):\n",
|
|
|
|
|
" def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):\n",
|
|
|
|
|
" super().__init__()\n",
|
|
|
|
|
@@ -403,6 +600,9 @@
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"HomographyLoss = HomographyLoss6\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"def count_parameters(model):\n",
|
|
|
|
|
" return sum(p.numel() for p in model.parameters())\n",
|
|
|
|
|
"\n"
|
|
|
|
|
@@ -410,12 +610,28 @@
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "d0991e10",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n"
|
|
|
|
|
"source": [
|
|
|
|
|
"## Training\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"`HomographyTrainer` — training loop with validation and checkpointing.\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"**Features:**\n",
|
|
|
|
|
"- Epoch-based training with tqdm progress bar\n",
|
|
|
|
|
"- Adam optimizer with configurable LR\n",
|
|
|
|
|
"- Validation after each epoch\n",
|
|
|
|
|
"- Best model auto-save\n",
|
|
|
|
|
"- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"**Checkpoint saving:**\n",
|
|
|
|
|
"- `best_model.pt` — lowest validation loss\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"execution_count": 40,
|
|
|
|
|
"id": "83ff7cc6",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
@@ -424,12 +640,12 @@
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"class HomographyTrainer:\n",
|
|
|
|
|
" def __init__(self, model, train_loader, val_loader, device):\n",
|
|
|
|
|
" def __init__(self, model, train_loader, val_loader, device, criterion):\n",
|
|
|
|
|
" self.model = model.to(device)\n",
|
|
|
|
|
" self.train_loader = train_loader\n",
|
|
|
|
|
" self.val_loader = val_loader\n",
|
|
|
|
|
" self.device = device\n",
|
|
|
|
|
" self.criterion = HomographyLoss6()\n",
|
|
|
|
|
" self.criterion = criterion\n",
|
|
|
|
|
" self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"], weight_decay=1e-4)\n",
|
|
|
|
|
" self.writer = None\n",
|
|
|
|
|
" self.best_val_loss = float(\"inf\")\n",
|
|
|
|
|
@@ -537,17 +753,28 @@
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "d6dbc5ea",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n"
|
|
|
|
|
"source": [
|
|
|
|
|
"## Analysis\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"Visualization and evaluation tools:\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"- Training metrics plots (loss curves)\n",
|
|
|
|
|
"- Prediction visualization on sample images\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"execution_count": 41,
|
|
|
|
|
"id": "6fac17d5",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"sns.set_theme(style=\"whitegrid\", palette=\"muted\", font_scale=1.2)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n",
|
|
|
|
|
"os.makedirs(IMG_DIR, exist_ok=True)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
@@ -623,84 +850,132 @@
|
|
|
|
|
" mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n",
|
|
|
|
|
" std_errors = [np.std(all_errors[i]) for i in range(6)]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" angle_errors_deg = [np.degrees(mean_errors[i]) for i in range(2, 5)]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" all_targets_stacked = [np.array(all_targets[i]) for i in range(6)]\n",
|
|
|
|
|
" target_ranges = [np.ptp(all_targets_stacked[i]) for i in range(6)]\n",
|
|
|
|
|
" relative_errors = [mean_errors[i] / target_ranges[i] if target_ranges[i] > 1e-8 else 0 for i in range(6)]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" if len(trainer.train_losses) > 0:\n",
|
|
|
|
|
" epochs = range(1, len(trainer.train_losses) + 1)\n",
|
|
|
|
|
" fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n",
|
|
|
|
|
" axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
|
|
|
|
|
" axes[0, 0].plot(epochs, trainer.train_losses, color=\"#2ecc71\", linewidth=2, label=\"Train Loss\")\n",
|
|
|
|
|
" axes[0, 0].plot(epochs, trainer.val_losses, color=\"#e74c3c\", linewidth=2, label=\"Val Loss\")\n",
|
|
|
|
|
" axes[0, 0].set_xlabel(\"Epoch\")\n",
|
|
|
|
|
" axes[0, 0].set_ylabel(\"Loss\")\n",
|
|
|
|
|
" axes[0, 0].set_title(\"Training & Validation Loss\")\n",
|
|
|
|
|
" axes[0, 0].legend()\n",
|
|
|
|
|
" axes[0, 0].set_title(\"Training & Validation Loss\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[0, 0].legend(framealpha=0.9)\n",
|
|
|
|
|
" axes[0, 0].grid(True, alpha=0.3)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
|
|
|
|
|
" axes[0, 1].plot(epochs, trainer.val_losses, color=\"#e74c3c\", linewidth=2, label=\"Val Loss\")\n",
|
|
|
|
|
" axes[0, 1].set_xlabel(\"Epoch\")\n",
|
|
|
|
|
" axes[0, 1].set_ylabel(\"Loss\")\n",
|
|
|
|
|
" axes[0, 1].set_title(\"Validation Loss\")\n",
|
|
|
|
|
" axes[0, 1].legend()\n",
|
|
|
|
|
" axes[0, 1].set_title(\"Validation Loss\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[0, 1].legend(framealpha=0.9)\n",
|
|
|
|
|
" axes[0, 1].grid(True, alpha=0.3)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n",
|
|
|
|
|
" axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n",
|
|
|
|
|
" axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n",
|
|
|
|
|
" axes[1, 0].plot(epochs, trainer.val_mse_trans, color=\"#3498db\", linewidth=2, label=\"Translation (tx, ty)\")\n",
|
|
|
|
|
" axes[1, 0].plot(epochs, trainer.val_mse_angle, color=\"#9b59b6\", linewidth=2, label=\"Angle (rx, ry, rz)\")\n",
|
|
|
|
|
" axes[1, 0].plot(epochs, trainer.val_mse_scale, color=\"#e67e22\", linewidth=2, label=\"Scale\")\n",
|
|
|
|
|
" axes[1, 0].set_xlabel(\"Epoch\")\n",
|
|
|
|
|
" axes[1, 0].set_ylabel(\"MSE\")\n",
|
|
|
|
|
" axes[1, 0].set_title(\"Validation MSE by Category\")\n",
|
|
|
|
|
" axes[1, 0].legend()\n",
|
|
|
|
|
" axes[1, 0].set_title(\"Validation MSE by Category\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[1, 0].legend(framealpha=0.9)\n",
|
|
|
|
|
" axes[1, 0].grid(True, alpha=0.3)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" x_pos = np.arange(6)\n",
|
|
|
|
|
" axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
|
|
|
|
|
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
|
|
|
|
|
" bars = axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
|
|
|
|
|
" axes[1, 1].set_xticks(x_pos)\n",
|
|
|
|
|
" axes[1, 1].set_xticklabels(names)\n",
|
|
|
|
|
" axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n",
|
|
|
|
|
" axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\")\n",
|
|
|
|
|
" axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" plt.tight_layout()\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150)\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150, bbox_inches=\"tight\")\n",
|
|
|
|
|
" print(\"Saved training_loss_plots.png\")\n",
|
|
|
|
|
" plt.show()\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n",
|
|
|
|
|
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
|
|
|
|
|
" for j in range(6):\n",
|
|
|
|
|
" row = j // 3\n",
|
|
|
|
|
" col = j % 3\n",
|
|
|
|
|
" axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color=\"steelblue\", alpha=0.7)\n",
|
|
|
|
|
" axes[row, col].set_xlabel(\"Sample\")\n",
|
|
|
|
|
" axes[row, col].set_ylabel(\"Absolute Error\")\n",
|
|
|
|
|
" axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\")\n",
|
|
|
|
|
" axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color=colors[j], alpha=0.75)\n",
|
|
|
|
|
" axes[row, col].set_xlabel(\"Sample\", fontsize=10)\n",
|
|
|
|
|
" axes[row, col].set_ylabel(\"Absolute Error\", fontsize=10)\n",
|
|
|
|
|
" axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\", fontweight=\"bold\", fontsize=11)\n",
|
|
|
|
|
" axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14)\n",
|
|
|
|
|
" plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14, fontweight=\"bold\")\n",
|
|
|
|
|
" plt.tight_layout()\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150)\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150, bbox_inches=\"tight\")\n",
|
|
|
|
|
" print(\"Saved mae_per_parameter.png\")\n",
|
|
|
|
|
" plt.show()\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" x_pos = np.arange(6)\n",
|
|
|
|
|
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
|
|
|
|
|
" bars = axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
|
|
|
|
|
" axes[0].set_xticks(x_pos)\n",
|
|
|
|
|
" axes[0].set_xticklabels(names)\n",
|
|
|
|
|
" axes[0].set_ylabel(\"Mean Absolute Error\")\n",
|
|
|
|
|
" axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[0].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n",
|
|
|
|
|
" for patch, color in zip(bp[\"boxes\"], colors):\n",
|
|
|
|
|
" patch.set_facecolor(color)\n",
|
|
|
|
|
" patch.set_alpha(0.8)\n",
|
|
|
|
|
" axes[1].set_ylabel(\"Absolute Error\")\n",
|
|
|
|
|
" axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[1].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" rel_err_pos = np.arange(6)\n",
|
|
|
|
|
" bars = axes[2].bar(rel_err_pos, relative_errors, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
|
|
|
|
|
" axes[2].set_xticks(rel_err_pos)\n",
|
|
|
|
|
" axes[2].set_xticklabels(names)\n",
|
|
|
|
|
" axes[2].set_ylabel(\"Relative Error (MAE / Range)\")\n",
|
|
|
|
|
" axes[2].set_title(\"Relative Error per Parameter\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[2].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" plt.tight_layout()\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150, bbox_inches=\"tight\")\n",
|
|
|
|
|
" print(\"Saved mae_boxplot.png\")\n",
|
|
|
|
|
" plt.show()\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" x_pos = np.arange(6)\n",
|
|
|
|
|
" axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
|
|
|
|
|
" angle_names = [\"rx\", \"ry\", \"rz\"]\n",
|
|
|
|
|
" x_pos = np.arange(3)\n",
|
|
|
|
|
" colors_angle = [\"#9b59b6\", \"#2ecc71\", \"#f39c12\"]\n",
|
|
|
|
|
" bars = axes[0].bar(x_pos, angle_errors_deg, color=colors_angle, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
|
|
|
|
|
" axes[0].set_xticks(x_pos)\n",
|
|
|
|
|
" axes[0].set_xticklabels(names)\n",
|
|
|
|
|
" axes[0].set_ylabel(\"Mean Absolute Error\")\n",
|
|
|
|
|
" axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\")\n",
|
|
|
|
|
" axes[0].set_xticklabels(angle_names)\n",
|
|
|
|
|
" axes[0].set_ylabel(\"Mean Absolute Error (degrees)\")\n",
|
|
|
|
|
" axes[0].set_title(\"Angle MAE in Degrees\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[0].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" for i, e in enumerate(angle_errors_deg):\n",
|
|
|
|
|
" axes[0].text(i, e + 0.5, f\"{e:.1f}°\", ha=\"center\", va=\"bottom\", fontsize=11, fontweight=\"bold\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n",
|
|
|
|
|
" colors = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n",
|
|
|
|
|
" for patch, color in zip(bp[\"boxes\"], colors):\n",
|
|
|
|
|
" patch.set_facecolor(color)\n",
|
|
|
|
|
" patch.set_alpha(0.7)\n",
|
|
|
|
|
" axes[1].set_ylabel(\"Absolute Error\")\n",
|
|
|
|
|
" axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\")\n",
|
|
|
|
|
" trans_scale_errs = [mean_errors[0], mean_errors[1], mean_errors[5]]\n",
|
|
|
|
|
" trans_scale_names = [\"tx\", \"ty\", \"scale\"]\n",
|
|
|
|
|
" x_pos = np.arange(3)\n",
|
|
|
|
|
" colors_trans = [\"#3498db\", \"#e74c3c\", \"#1abc9c\"]\n",
|
|
|
|
|
" bars = axes[1].bar(x_pos, trans_scale_errs, color=colors_trans, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
|
|
|
|
|
" axes[1].set_xticks(x_pos)\n",
|
|
|
|
|
" axes[1].set_xticklabels(trans_scale_names)\n",
|
|
|
|
|
" axes[1].set_ylabel(\"Mean Absolute Error\")\n",
|
|
|
|
|
" axes[1].set_title(\"Translation & Scale MAE\", fontweight=\"bold\")\n",
|
|
|
|
|
" axes[1].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" for i, e in enumerate(trans_scale_errs):\n",
|
|
|
|
|
" axes[1].text(i, e + 0.01, f\"{e:.4f}\", ha=\"center\", va=\"bottom\", fontsize=11, fontweight=\"bold\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" plt.tight_layout()\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150)\n",
|
|
|
|
|
" print(\"Saved mae_boxplot.png\")\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, \"mae_by_category.png\"), dpi=150, bbox_inches=\"tight\")\n",
|
|
|
|
|
" print(\"Saved mae_by_category.png\")\n",
|
|
|
|
|
" plt.show()\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" print(\"\\n=== Sample Predictions (20 pairs) ===\")\n",
|
|
|
|
|
@@ -737,50 +1012,58 @@
|
|
|
|
|
" fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n",
|
|
|
|
|
" axes[0, 0].set_title(f\"Google Image\")\n",
|
|
|
|
|
" axes[0, 0].set_title(\"Google Image\", fontweight=\"bold\", fontsize=12)\n",
|
|
|
|
|
" axes[0, 0].axis(\"off\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n",
|
|
|
|
|
" axes[0, 1].set_title(f\"Yandex Image\")\n",
|
|
|
|
|
" axes[0, 1].set_title(\"Yandex Image\", fontweight=\"bold\", fontsize=12)\n",
|
|
|
|
|
" axes[0, 1].axis(\"off\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" x_pos = np.arange(6)\n",
|
|
|
|
|
" width = 0.35\n",
|
|
|
|
|
" axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n",
|
|
|
|
|
" axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n",
|
|
|
|
|
" axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"#3498db\", alpha=0.85)\n",
|
|
|
|
|
" axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"#e74c3c\", alpha=0.85)\n",
|
|
|
|
|
" axes[1, 0].set_xticks(x_pos)\n",
|
|
|
|
|
" axes[1, 0].set_xticklabels(names)\n",
|
|
|
|
|
" axes[1, 0].set_ylabel(\"Parameter Value\")\n",
|
|
|
|
|
" axes[1, 0].set_title(\"Target vs Predicted\")\n",
|
|
|
|
|
" axes[1, 0].legend()\n",
|
|
|
|
|
" axes[1, 0].set_title(\"Target vs Predicted\", fontweight=\"bold\", fontsize=12)\n",
|
|
|
|
|
" axes[1, 0].legend(framealpha=0.9)\n",
|
|
|
|
|
" axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" axes[1, 1].bar(x_pos, errors, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
|
|
|
|
|
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
|
|
|
|
|
" bars = axes[1, 1].bar(x_pos, errors, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.2)\n",
|
|
|
|
|
" axes[1, 1].set_xticks(x_pos)\n",
|
|
|
|
|
" axes[1, 1].set_xticklabels(names)\n",
|
|
|
|
|
" axes[1, 1].set_ylabel(\"Absolute Error\")\n",
|
|
|
|
|
" axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\")\n",
|
|
|
|
|
" axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\", fontweight=\"bold\", fontsize=12)\n",
|
|
|
|
|
" axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n",
|
|
|
|
|
" for i_e, e in enumerate(errors):\n",
|
|
|
|
|
" axes[1, 1].text(i_e, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
|
|
|
|
|
" axes[1, 1].text(i_e, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=9)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" plt.suptitle(f\"Sample {vis_count + 1}\", fontsize=14)\n",
|
|
|
|
|
" plt.suptitle(f\"Sample {vis_count + 1}\", fontsize=14, fontweight=\"bold\")\n",
|
|
|
|
|
" plt.tight_layout()\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{vis_count + 1:02d}.png\"), dpi=100)\n",
|
|
|
|
|
" plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{vis_count + 1:02d}.png\"), dpi=100, bbox_inches=\"tight\")\n",
|
|
|
|
|
" plt.show()\n",
|
|
|
|
|
" print(f\"Saved prediction_sample_{vis_count + 1:02d}.png\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" vis_count += 1\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" print(f\"\\nPrediction errors over {n_samples} samples:\")\n",
|
|
|
|
|
" print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n",
|
|
|
|
|
" print(\"-\" * 52)\n",
|
|
|
|
|
" print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8} {'Rel Err':>10}\")\n",
|
|
|
|
|
" print(\"-\" * 62)\n",
|
|
|
|
|
" for i in range(6):\n",
|
|
|
|
|
" mean_err = np.mean(all_errors[i])\n",
|
|
|
|
|
" std_err = np.std(all_errors[i])\n",
|
|
|
|
|
" min_err = np.min(all_errors[i])\n",
|
|
|
|
|
" max_err = np.max(all_errors[i])\n",
|
|
|
|
|
" print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}\")\n",
|
|
|
|
|
" rel_err = relative_errors[i]\n",
|
|
|
|
|
" print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f} {rel_err:>10.4f}\")\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" print(f\"\\nAngle errors in degrees:\")\n",
|
|
|
|
|
" print(f\"{'Param':<8} {'MAE (deg)':>12} {'MAE (rad)':>12}\")\n",
|
|
|
|
|
" print(\"-\" * 35)\n",
|
|
|
|
|
" for i, name in enumerate([\"rx\", \"ry\", \"rz\"]):\n",
|
|
|
|
|
" print(f\"{name:<8} {angle_errors_deg[i]:>12.2f} {mean_errors[i+2]:>12.4f}\")\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" return {\n",
|
|
|
|
|
" \"best_val_loss\": trainer.best_val_loss,\n",
|
|
|
|
|
@@ -789,20 +1072,62 @@
|
|
|
|
|
" \"val_mse_trans\": trainer.val_mse_trans,\n",
|
|
|
|
|
" \"val_mse_angle\": trainer.val_mse_angle,\n",
|
|
|
|
|
" \"val_mse_scale\": trainer.val_mse_scale,\n",
|
|
|
|
|
" \"mean_errors\": mean_errors,\n",
|
|
|
|
|
" \"std_errors\": std_errors,\n",
|
|
|
|
|
" \"angle_errors_deg\": angle_errors_deg,\n",
|
|
|
|
|
" \"relative_errors\": relative_errors,\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "e36c328c",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n"
|
|
|
|
|
"source": [
|
|
|
|
|
"## Main Pipeline\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"Executes the full training workflow:\n",
|
|
|
|
|
"1. Load dataset info\n",
|
|
|
|
|
"2. Create data loaders\n",
|
|
|
|
|
"3. Initialize model\n",
|
|
|
|
|
"4. Train with validation\n",
|
|
|
|
|
"5. Analyze and export results\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"**Outputs:**\n",
|
|
|
|
|
"- Model checkpoints in `runs/checkpoints/`\n",
|
|
|
|
|
"- TensorBoard logs in `runs/`\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "0a271cba",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"2026-04-06 23:05:04,927 - ==================================================\n",
|
|
|
|
|
"2026-04-06 23:05:04,928 - SiaN Training Pipeline\n",
|
|
|
|
|
"2026-04-06 23:05:04,929 - ==================================================\n",
|
|
|
|
|
"2026-04-06 23:05:06,049 - Dataset: 33 samples, keys=['google_img', 'yandex_img', 'homography_matrix', 'homography_params']\n",
|
|
|
|
|
"2026-04-06 23:05:08,308 - Data loaders created: train=26, val=7\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"ename": "KeyError",
|
|
|
|
|
"evalue": "'droupout_rate'",
|
|
|
|
|
"output_type": "error",
|
|
|
|
|
"traceback": [
|
|
|
|
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
|
|
|
"\u001b[31mKeyError\u001b[39m Traceback (most recent call last)",
|
|
|
|
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[42]\u001b[39m\u001b[32m, line 29\u001b[39m\n\u001b[32m 18\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mData loaders created: train=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(train_loader.dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m, val=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(val_loader.dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 20\u001b[39m \u001b[38;5;66;03m# model = HomographyCNN6(\u001b[39;00m\n\u001b[32m 21\u001b[39m \u001b[38;5;66;03m# input_channels=3,\u001b[39;00m\n\u001b[32m 22\u001b[39m \u001b[38;5;66;03m# backbone_name=config[\"backbone\"],\u001b[39;00m\n\u001b[32m 23\u001b[39m \u001b[38;5;66;03m# pretrained=True,\u001b[39;00m\n\u001b[32m 24\u001b[39m \u001b[38;5;66;03m# dropout_rate=config[\"dropout_rate\"]\u001b[39;00m\n\u001b[32m 25\u001b[39m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[32m 27\u001b[39m model = HomographyHybridCNN(\n\u001b[32m 28\u001b[39m input_channels=\u001b[32m3\u001b[39m,\n\u001b[32m---> \u001b[39m\u001b[32m29\u001b[39m dropout_rate=\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdroupout_rate\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m,\n\u001b[32m 30\u001b[39m )\n\u001b[32m 32\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mModel created with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcount_parameters(model)\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m parameters\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 34\u001b[39m device = torch.device(\u001b[33m\"\u001b[39m\u001b[33mcuda\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch.cuda.is_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mcpu\u001b[39m\u001b[33m\"\u001b[39m)\n",
|
|
|
|
|
"\u001b[31mKeyError\u001b[39m: 'droupout_rate'"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
@@ -827,18 +1152,24 @@
|
|
|
|
|
")\n",
|
|
|
|
|
"logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"model = HomographyCNN6(\n",
|
|
|
|
|
"# model = HomographyCNN6(\n",
|
|
|
|
|
"# input_channels=3,\n",
|
|
|
|
|
"# backbone_name=config[\"backbone\"],\n",
|
|
|
|
|
"# pretrained=True,\n",
|
|
|
|
|
"# dropout_rate=config[\"dropout_rate\"]\n",
|
|
|
|
|
"# )\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"model = HomographyHybridCNN(\n",
|
|
|
|
|
" input_channels=3,\n",
|
|
|
|
|
" backbone_name=config[\"backbone\"],\n",
|
|
|
|
|
" pretrained=True,\n",
|
|
|
|
|
" dropout_rate=config[\"dropout_rate\"]\n",
|
|
|
|
|
" dropout_rate=config[\"dropout_rate\"],\n",
|
|
|
|
|
")\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
|
|
|
"logger.info(f\"Using device: {device}\")\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
|
|
|
|
|
"trainer = HomographyTrainer(model, train_loader, val_loader, device, HomographyLoss())\n",
|
|
|
|
|
"logger.info(\"Starting training...\")\n",
|
|
|
|
|
"trainer.train(config[\"epochs\"])\n",
|
|
|
|
|
"logger.info(\"Training completed\")\n",
|
|
|
|
|
@@ -856,6 +1187,7 @@
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "c1ea5b8f",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
|