try custom cnn

This commit is contained in:
2026-04-14 12:37:09 +03:00
parent 2ec0763e6d
commit fc072d798e
4 changed files with 1349 additions and 888 deletions

View File

@@ -2,7 +2,8 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 35,
"id": "4230e9cd",
"metadata": {},
"outputs": [],
"source": [
@@ -17,6 +18,7 @@
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import seaborn as sns\n",
"from torch.utils.data import DataLoader, Dataset, Subset\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from torchvision import transforms, models\n",
@@ -25,12 +27,21 @@
},
{
"cell_type": "markdown",
"id": "d41ef314",
"metadata": {},
"source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n"
"source": [
"# Configuration\n",
"\n",
"Global settings for:\n",
"- Data paths and image parameters\n",
"- Training hyperparameters\n",
"- Model architecture options\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 36,
"id": "4463d9d3",
"metadata": {},
"outputs": [],
"source": [
@@ -96,12 +107,25 @@
},
{
"cell_type": "markdown",
"id": "e5a40be4",
"metadata": {},
"source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n"
"source": [
"## Dataset\n",
"\n",
"Google/Yandex image pair loader with homography augmentation.\n",
"\n",
"**Features:**\n",
"- Loads paired images from dual camera sources\n",
"- Applies random homography transformations\n",
"- Supports configurable train/val split\n",
"\n",
"**Returns:**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 37,
"id": "37358bfe",
"metadata": {},
"outputs": [],
"source": [
@@ -242,9 +266,39 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 38,
"id": "9fee48b8",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 1.0661e+00, -4.5570e-02, -2.1649e+01],\n",
" [ 3.7228e-02, 9.1285e-01, 1.3237e+01],\n",
" [ 5.3873e-04, -6.4843e-04, 9.6602e-01]])\n",
"tensor([-0.0379, 0.0183, -0.0856, -0.0661, -0.0375, 0.9617])\n",
"[[ 1.0660727e+00 -4.5569740e-02 -2.1648926e+01]\n",
" [ 3.7227791e-02 9.1284692e-01 1.3237175e+01]\n",
" [ 5.3873385e-04 -6.4843398e-04 9.6602309e-01]]\n",
"[ 0. 0. -0.08696619 -0.0720376 -0.03181122 0.9519815 ]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
" super().__init__(loader)\n",
"C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:32: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n",
" cy, sy = np.cos(rz), np.sin(rz)\n",
"C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:33: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n",
" cp, sp = np.cos(ry), np.sin(ry)\n",
"C:\\Users\\admin\\AppData\\Local\\Temp\\ipykernel_19016\\897759767.py:34: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)\n",
" cr, sr = np.cos(rx), np.sin(rx)\n"
]
}
],
"source": [
"\n",
"train_loader, val_loader = create_data_loaders(config['data_dir'])\n",
@@ -265,12 +319,27 @@
},
{
"cell_type": "markdown",
"id": "e8072ee6",
"metadata": {},
"source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n"
"source": [
"## Model\n",
"\n",
"`HomographyCNN6` — CNN architecture for homography estimation.\n",
"\n",
"**Output:** 6 parameters\n",
"- `rx, ry, rz` — rotation angles (radians)\n",
"- `tx, ty` — translation offsets\n",
"- `scale` — isotropic scale factor\n",
"\n",
"**Architecture:**\n",
"- Dual-branch CNN (Google + Yandex images)\n",
"- Shared backbone (configurable: resnet18/34/50)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 39,
"id": "c246531d",
"metadata": {},
"outputs": [],
"source": [
@@ -302,11 +371,19 @@
" nn.Dropout(dropout_rate),\n",
" nn.Linear(256, 6),\n",
" )\n",
" self._init_weights()\n",
"\n",
" def _normalize_sin_cos(self, _sin, _cos):\n",
" _len = torch.sqrt(_sin ** 2 + _cos ** 2)\n",
" return _sin / _len, _cos / _len\n",
"\n",
" def _init_weights(self):\n",
" for module in self.head.modules():\n",
" if isinstance(module, nn.Linear):\n",
" nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n",
" if module.bias is not None:\n",
" nn.init.zeros_(module.bias)\n",
"\n",
" def forward(self, img1, img2):\n",
" f1 = self.backbone(img1)\n",
" f2 = self.backbone(img2)\n",
@@ -342,6 +419,126 @@
" }\n",
"\n",
"\n",
"class HomographyHybridCNN(nn.Module):\n",
" def __init__(self, input_channels=3, use_resnet_layers=2, dropout_rate=0.3):\n",
" super().__init__()\n",
" \n",
" if use_resnet_layers == 1:\n",
" resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
" self.conv1 = resnet.conv1\n",
" self.bn1 = resnet.bn1\n",
" self.relu = resnet.relu\n",
" self.maxpool = resnet.maxpool\n",
" conv_out_channels = 64\n",
" elif use_resnet_layers == 2:\n",
" resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
" self.conv1 = resnet.conv1\n",
" self.bn1 = resnet.bn1\n",
" self.relu = resnet.relu\n",
" self.maxpool = resnet.maxpool\n",
" self.conv2 = resnet.layer1[0].conv1\n",
" self.bn2 = resnet.layer1[0].bn1\n",
" self.conv2_2 = resnet.layer1[0].conv2\n",
" self.bn2_2 = resnet.layer1[0].bn2\n",
" self.relu2 = resnet.layer1[0].relu\n",
" self.maxpool2 = resnet.maxpool\n",
" conv_out_channels = 64\n",
" else:\n",
" raise ValueError(\"use_resnet_layers must be 1 or 2\")\n",
" \n",
" self.use_resnet_layers = use_resnet_layers\n",
" self.feature_map_size = 64\n",
" \n",
" self.conv_head = nn.Sequential(\n",
" nn.Conv2d(conv_out_channels, 128, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(128),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(128, 256, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(256),\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(2),\n",
" )\n",
" \n",
" self.global_pool = nn.AdaptiveAvgPool2d((1, 1))\n",
" \n",
" feature_dim = 256 * 4\n",
" \n",
" self.head = nn.Sequential(\n",
" nn.Linear(feature_dim, 1024),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
" nn.Linear(1024, 512),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
" nn.Linear(512, 256),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
" nn.Linear(256, 6),\n",
" )\n",
" self._init_weights()\n",
"\n",
" def _init_weights(self):\n",
" for module in self.head.modules():\n",
" if isinstance(module, nn.Linear):\n",
" nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')\n",
" if module.bias is not None:\n",
" nn.init.zeros_(module.bias)\n",
"\n",
" def forward(self, img1, img2):\n",
" x1 = self._extract_features(img1)\n",
" x2 = self._extract_features(img2)\n",
" \n",
" combined = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], dim=1)\n",
" output = self.head(combined)\n",
" \n",
" output = torch.tanh(output)\n",
" modified = output.clone()\n",
" modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi)\n",
" \n",
" return modified\n",
"\n",
" def _extract_features(self, x):\n",
" x = self.conv1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu(x)\n",
" x = self.maxpool(x)\n",
" \n",
" if self.use_resnet_layers >= 2:\n",
" x = self.conv2(x)\n",
" x = self.bn2(x)\n",
" x = self.relu(x)\n",
" x = self.conv2_2(x)\n",
" x = self.bn2_2(x)\n",
" x = self.relu2(x)\n",
" x = self.maxpool2(x)\n",
" \n",
" x = self.conv_head(x)\n",
" x = self.global_pool(x)\n",
" x = x.view(x.size(0), -1)\n",
" \n",
" return x\n",
"\n",
" def decode_output(self, output):\n",
" tx = output[:, 0]\n",
" ty = output[:, 1]\n",
" scale = output[:, 5]\n",
" angle1 = output[:, 2]\n",
" angle2 = output[:, 3]\n",
" angle3 = output[:, 4]\n",
" return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n",
"\n",
" def get_components(self, output):\n",
" decoded = self.decode_output(output)\n",
" return {\n",
" \"tx\": decoded[:, 0],\n",
" \"ty\": decoded[:, 1],\n",
" \"rx\": decoded[:, 2],\n",
" \"ry\": decoded[:, 3],\n",
" \"rz\": decoded[:, 4],\n",
" \"scale\": decoded[:, 5],\n",
" }\n",
"\n",
"\n",
"class HomographyLoss6(nn.Module):\n",
" def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):\n",
" super().__init__()\n",
@@ -403,6 +600,9 @@
" }\n",
"\n",
"\n",
"HomographyLoss = HomographyLoss6\n",
"\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters())\n",
"\n"
@@ -410,12 +610,28 @@
},
{
"cell_type": "markdown",
"id": "d0991e10",
"metadata": {},
"source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n"
"source": [
"## Training\n",
"\n",
"`HomographyTrainer` — training loop with validation and checkpointing.\n",
"\n",
"**Features:**\n",
"- Epoch-based training with tqdm progress bar\n",
"- Adam optimizer with configurable LR\n",
"- Validation after each epoch\n",
"- Best model auto-save\n",
"- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n",
"\n",
"**Checkpoint saving:**\n",
"- `best_model.pt` — lowest validation loss\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 40,
"id": "83ff7cc6",
"metadata": {},
"outputs": [],
"source": [
@@ -424,12 +640,12 @@
"\n",
"\n",
"class HomographyTrainer:\n",
" def __init__(self, model, train_loader, val_loader, device):\n",
" def __init__(self, model, train_loader, val_loader, device, criterion):\n",
" self.model = model.to(device)\n",
" self.train_loader = train_loader\n",
" self.val_loader = val_loader\n",
" self.device = device\n",
" self.criterion = HomographyLoss6()\n",
" self.criterion = criterion\n",
" self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"], weight_decay=1e-4)\n",
" self.writer = None\n",
" self.best_val_loss = float(\"inf\")\n",
@@ -537,17 +753,28 @@
},
{
"cell_type": "markdown",
"id": "d6dbc5ea",
"metadata": {},
"source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n"
"source": [
"## Analysis\n",
"\n",
"Visualization and evaluation tools:\n",
"\n",
"- Training metrics plots (loss curves)\n",
"- Prediction visualization on sample images\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 41,
"id": "6fac17d5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"sns.set_theme(style=\"whitegrid\", palette=\"muted\", font_scale=1.2)\n",
"\n",
"IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n",
"os.makedirs(IMG_DIR, exist_ok=True)\n",
"\n",
@@ -623,84 +850,132 @@
" mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n",
" std_errors = [np.std(all_errors[i]) for i in range(6)]\n",
" \n",
" angle_errors_deg = [np.degrees(mean_errors[i]) for i in range(2, 5)]\n",
" \n",
" all_targets_stacked = [np.array(all_targets[i]) for i in range(6)]\n",
" target_ranges = [np.ptp(all_targets_stacked[i]) for i in range(6)]\n",
" relative_errors = [mean_errors[i] / target_ranges[i] if target_ranges[i] > 1e-8 else 0 for i in range(6)]\n",
" \n",
" if len(trainer.train_losses) > 0:\n",
" epochs = range(1, len(trainer.train_losses) + 1)\n",
" fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
" \n",
" axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n",
" axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
" axes[0, 0].plot(epochs, trainer.train_losses, color=\"#2ecc71\", linewidth=2, label=\"Train Loss\")\n",
" axes[0, 0].plot(epochs, trainer.val_losses, color=\"#e74c3c\", linewidth=2, label=\"Val Loss\")\n",
" axes[0, 0].set_xlabel(\"Epoch\")\n",
" axes[0, 0].set_ylabel(\"Loss\")\n",
" axes[0, 0].set_title(\"Training & Validation Loss\")\n",
" axes[0, 0].legend()\n",
" axes[0, 0].set_title(\"Training & Validation Loss\", fontweight=\"bold\")\n",
" axes[0, 0].legend(framealpha=0.9)\n",
" axes[0, 0].grid(True, alpha=0.3)\n",
" \n",
" axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
" axes[0, 1].plot(epochs, trainer.val_losses, color=\"#e74c3c\", linewidth=2, label=\"Val Loss\")\n",
" axes[0, 1].set_xlabel(\"Epoch\")\n",
" axes[0, 1].set_ylabel(\"Loss\")\n",
" axes[0, 1].set_title(\"Validation Loss\")\n",
" axes[0, 1].legend()\n",
" axes[0, 1].set_title(\"Validation Loss\", fontweight=\"bold\")\n",
" axes[0, 1].legend(framealpha=0.9)\n",
" axes[0, 1].grid(True, alpha=0.3)\n",
" \n",
" axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n",
" axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n",
" axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n",
" axes[1, 0].plot(epochs, trainer.val_mse_trans, color=\"#3498db\", linewidth=2, label=\"Translation (tx, ty)\")\n",
" axes[1, 0].plot(epochs, trainer.val_mse_angle, color=\"#9b59b6\", linewidth=2, label=\"Angle (rx, ry, rz)\")\n",
" axes[1, 0].plot(epochs, trainer.val_mse_scale, color=\"#e67e22\", linewidth=2, label=\"Scale\")\n",
" axes[1, 0].set_xlabel(\"Epoch\")\n",
" axes[1, 0].set_ylabel(\"MSE\")\n",
" axes[1, 0].set_title(\"Validation MSE by Category\")\n",
" axes[1, 0].legend()\n",
" axes[1, 0].set_title(\"Validation MSE by Category\", fontweight=\"bold\")\n",
" axes[1, 0].legend(framealpha=0.9)\n",
" axes[1, 0].grid(True, alpha=0.3)\n",
" \n",
" x_pos = np.arange(6)\n",
" axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
" bars = axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
" axes[1, 1].set_xticks(x_pos)\n",
" axes[1, 1].set_xticklabels(names)\n",
" axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n",
" axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\")\n",
" axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontweight=\"bold\")\n",
" axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150)\n",
" plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150, bbox_inches=\"tight\")\n",
" print(\"Saved training_loss_plots.png\")\n",
" plt.show()\n",
" \n",
" fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n",
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
" for j in range(6):\n",
" row = j // 3\n",
" col = j % 3\n",
" axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color=\"steelblue\", alpha=0.7)\n",
" axes[row, col].set_xlabel(\"Sample\")\n",
" axes[row, col].set_ylabel(\"Absolute Error\")\n",
" axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\")\n",
" axes[row, col].bar(range(len(all_errors[j])), all_errors[j], color=colors[j], alpha=0.75)\n",
" axes[row, col].set_xlabel(\"Sample\", fontsize=10)\n",
" axes[row, col].set_ylabel(\"Absolute Error\", fontsize=10)\n",
" axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\", fontweight=\"bold\", fontsize=11)\n",
" axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n",
" plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14)\n",
" plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150)\n",
" plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150, bbox_inches=\"tight\")\n",
" print(\"Saved mae_per_parameter.png\")\n",
" plt.show()\n",
" \n",
" fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
" \n",
" x_pos = np.arange(6)\n",
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
" bars = axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=6, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
" axes[0].set_xticks(x_pos)\n",
" axes[0].set_xticklabels(names)\n",
" axes[0].set_ylabel(\"Mean Absolute Error\")\n",
" axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\", fontweight=\"bold\")\n",
" axes[0].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n",
" for patch, color in zip(bp[\"boxes\"], colors):\n",
" patch.set_facecolor(color)\n",
" patch.set_alpha(0.8)\n",
" axes[1].set_ylabel(\"Absolute Error\")\n",
" axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\", fontweight=\"bold\")\n",
" axes[1].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" rel_err_pos = np.arange(6)\n",
" bars = axes[2].bar(rel_err_pos, relative_errors, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
" axes[2].set_xticks(rel_err_pos)\n",
" axes[2].set_xticklabels(names)\n",
" axes[2].set_ylabel(\"Relative Error (MAE / Range)\")\n",
" axes[2].set_title(\"Relative Error per Parameter\", fontweight=\"bold\")\n",
" axes[2].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150, bbox_inches=\"tight\")\n",
" print(\"Saved mae_boxplot.png\")\n",
" plt.show()\n",
" \n",
" fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
" \n",
" x_pos = np.arange(6)\n",
" axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
" angle_names = [\"rx\", \"ry\", \"rz\"]\n",
" x_pos = np.arange(3)\n",
" colors_angle = [\"#9b59b6\", \"#2ecc71\", \"#f39c12\"]\n",
" bars = axes[0].bar(x_pos, angle_errors_deg, color=colors_angle, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
" axes[0].set_xticks(x_pos)\n",
" axes[0].set_xticklabels(names)\n",
" axes[0].set_ylabel(\"Mean Absolute Error\")\n",
" axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\")\n",
" axes[0].set_xticklabels(angle_names)\n",
" axes[0].set_ylabel(\"Mean Absolute Error (degrees)\")\n",
" axes[0].set_title(\"Angle MAE in Degrees\", fontweight=\"bold\")\n",
" axes[0].grid(True, alpha=0.3, axis=\"y\")\n",
" for i, e in enumerate(angle_errors_deg):\n",
" axes[0].text(i, e + 0.5, f\"{e:.1f}°\", ha=\"center\", va=\"bottom\", fontsize=11, fontweight=\"bold\")\n",
" \n",
" bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n",
" colors = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n",
" for patch, color in zip(bp[\"boxes\"], colors):\n",
" patch.set_facecolor(color)\n",
" patch.set_alpha(0.7)\n",
" axes[1].set_ylabel(\"Absolute Error\")\n",
" axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\")\n",
" trans_scale_errs = [mean_errors[0], mean_errors[1], mean_errors[5]]\n",
" trans_scale_names = [\"tx\", \"ty\", \"scale\"]\n",
" x_pos = np.arange(3)\n",
" colors_trans = [\"#3498db\", \"#e74c3c\", \"#1abc9c\"]\n",
" bars = axes[1].bar(x_pos, trans_scale_errs, color=colors_trans, alpha=0.85, edgecolor=\"white\", linewidth=1.5)\n",
" axes[1].set_xticks(x_pos)\n",
" axes[1].set_xticklabels(trans_scale_names)\n",
" axes[1].set_ylabel(\"Mean Absolute Error\")\n",
" axes[1].set_title(\"Translation & Scale MAE\", fontweight=\"bold\")\n",
" axes[1].grid(True, alpha=0.3, axis=\"y\")\n",
" for i, e in enumerate(trans_scale_errs):\n",
" axes[1].text(i, e + 0.01, f\"{e:.4f}\", ha=\"center\", va=\"bottom\", fontsize=11, fontweight=\"bold\")\n",
" \n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150)\n",
" print(\"Saved mae_boxplot.png\")\n",
" plt.savefig(os.path.join(IMG_DIR, \"mae_by_category.png\"), dpi=150, bbox_inches=\"tight\")\n",
" print(\"Saved mae_by_category.png\")\n",
" plt.show()\n",
" \n",
" print(\"\\n=== Sample Predictions (20 pairs) ===\")\n",
@@ -737,50 +1012,58 @@
" fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
" \n",
" axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n",
" axes[0, 0].set_title(f\"Google Image\")\n",
" axes[0, 0].set_title(\"Google Image\", fontweight=\"bold\", fontsize=12)\n",
" axes[0, 0].axis(\"off\")\n",
" \n",
" axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n",
" axes[0, 1].set_title(f\"Yandex Image\")\n",
" axes[0, 1].set_title(\"Yandex Image\", fontweight=\"bold\", fontsize=12)\n",
" axes[0, 1].axis(\"off\")\n",
" \n",
" x_pos = np.arange(6)\n",
" width = 0.35\n",
" axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n",
" axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n",
" axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"#3498db\", alpha=0.85)\n",
" axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"#e74c3c\", alpha=0.85)\n",
" axes[1, 0].set_xticks(x_pos)\n",
" axes[1, 0].set_xticklabels(names)\n",
" axes[1, 0].set_ylabel(\"Parameter Value\")\n",
" axes[1, 0].set_title(\"Target vs Predicted\")\n",
" axes[1, 0].legend()\n",
" axes[1, 0].set_title(\"Target vs Predicted\", fontweight=\"bold\", fontsize=12)\n",
" axes[1, 0].legend(framealpha=0.9)\n",
" axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" axes[1, 1].bar(x_pos, errors, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
" colors = [\"#3498db\", \"#e74c3c\", \"#9b59b6\", \"#2ecc71\", \"#f39c12\", \"#1abc9c\"]\n",
" bars = axes[1, 1].bar(x_pos, errors, color=colors, alpha=0.85, edgecolor=\"white\", linewidth=1.2)\n",
" axes[1, 1].set_xticks(x_pos)\n",
" axes[1, 1].set_xticklabels(names)\n",
" axes[1, 1].set_ylabel(\"Absolute Error\")\n",
" axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\")\n",
" axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\", fontweight=\"bold\", fontsize=12)\n",
" axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n",
" for i_e, e in enumerate(errors):\n",
" axes[1, 1].text(i_e, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
" axes[1, 1].text(i_e, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=9)\n",
" \n",
" plt.suptitle(f\"Sample {vis_count + 1}\", fontsize=14)\n",
" plt.suptitle(f\"Sample {vis_count + 1}\", fontsize=14, fontweight=\"bold\")\n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{vis_count + 1:02d}.png\"), dpi=100)\n",
" plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{vis_count + 1:02d}.png\"), dpi=100, bbox_inches=\"tight\")\n",
" plt.show()\n",
" print(f\"Saved prediction_sample_{vis_count + 1:02d}.png\")\n",
" \n",
" vis_count += 1\n",
" \n",
" print(f\"\\nPrediction errors over {n_samples} samples:\")\n",
" print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n",
" print(\"-\" * 52)\n",
" print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8} {'Rel Err':>10}\")\n",
" print(\"-\" * 62)\n",
" for i in range(6):\n",
" mean_err = np.mean(all_errors[i])\n",
" std_err = np.std(all_errors[i])\n",
" min_err = np.min(all_errors[i])\n",
" max_err = np.max(all_errors[i])\n",
" print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}\")\n",
" rel_err = relative_errors[i]\n",
" print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f} {rel_err:>10.4f}\")\n",
" \n",
" print(f\"\\nAngle errors in degrees:\")\n",
" print(f\"{'Param':<8} {'MAE (deg)':>12} {'MAE (rad)':>12}\")\n",
" print(\"-\" * 35)\n",
" for i, name in enumerate([\"rx\", \"ry\", \"rz\"]):\n",
" print(f\"{name:<8} {angle_errors_deg[i]:>12.2f} {mean_errors[i+2]:>12.4f}\")\n",
"\n",
" return {\n",
" \"best_val_loss\": trainer.best_val_loss,\n",
@@ -789,20 +1072,62 @@
" \"val_mse_trans\": trainer.val_mse_trans,\n",
" \"val_mse_angle\": trainer.val_mse_angle,\n",
" \"val_mse_scale\": trainer.val_mse_scale,\n",
" \"mean_errors\": mean_errors,\n",
" \"std_errors\": std_errors,\n",
" \"angle_errors_deg\": angle_errors_deg,\n",
" \"relative_errors\": relative_errors,\n",
" }\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "e36c328c",
"metadata": {},
"source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n"
"source": [
"## Main Pipeline\n",
"\n",
"Executes the full training workflow:\n",
"1. Load dataset info\n",
"2. Create data loaders\n",
"3. Initialize model\n",
"4. Train with validation\n",
"5. Analyze and export results\n",
"\n",
"**Outputs:**\n",
"- Model checkpoints in `runs/checkpoints/`\n",
"- TensorBoard logs in `runs/`\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a271cba",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2026-04-06 23:05:04,927 - ==================================================\n",
"2026-04-06 23:05:04,928 - SiaN Training Pipeline\n",
"2026-04-06 23:05:04,929 - ==================================================\n",
"2026-04-06 23:05:06,049 - Dataset: 33 samples, keys=['google_img', 'yandex_img', 'homography_matrix', 'homography_params']\n",
"2026-04-06 23:05:08,308 - Data loaders created: train=26, val=7\n"
]
},
{
"ename": "KeyError",
"evalue": "'droupout_rate'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[42]\u001b[39m\u001b[32m, line 29\u001b[39m\n\u001b[32m 18\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mData loaders created: train=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(train_loader.dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m, val=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(val_loader.dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 20\u001b[39m \u001b[38;5;66;03m# model = HomographyCNN6(\u001b[39;00m\n\u001b[32m 21\u001b[39m \u001b[38;5;66;03m# input_channels=3,\u001b[39;00m\n\u001b[32m 22\u001b[39m \u001b[38;5;66;03m# backbone_name=config[\"backbone\"],\u001b[39;00m\n\u001b[32m 23\u001b[39m \u001b[38;5;66;03m# pretrained=True,\u001b[39;00m\n\u001b[32m 24\u001b[39m \u001b[38;5;66;03m# dropout_rate=config[\"dropout_rate\"]\u001b[39;00m\n\u001b[32m 25\u001b[39m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[32m 27\u001b[39m model = HomographyHybridCNN(\n\u001b[32m 28\u001b[39m input_channels=\u001b[32m3\u001b[39m,\n\u001b[32m---> \u001b[39m\u001b[32m29\u001b[39m dropout_rate=\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdroupout_rate\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m,\n\u001b[32m 30\u001b[39m )\n\u001b[32m 32\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mModel created with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcount_parameters(model)\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m parameters\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 34\u001b[39m device = torch.device(\u001b[33m\"\u001b[39m\u001b[33mcuda\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch.cuda.is_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mcpu\u001b[39m\u001b[33m\"\u001b[39m)\n",
"\u001b[31mKeyError\u001b[39m: 'droupout_rate'"
]
}
],
"source": [
"\n",
"\n",
@@ -827,18 +1152,24 @@
")\n",
"logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
"\n",
"model = HomographyCNN6(\n",
"# model = HomographyCNN6(\n",
"# input_channels=3,\n",
"# backbone_name=config[\"backbone\"],\n",
"# pretrained=True,\n",
"# dropout_rate=config[\"dropout_rate\"]\n",
"# )\n",
"\n",
"model = HomographyHybridCNN(\n",
" input_channels=3,\n",
" backbone_name=config[\"backbone\"],\n",
" pretrained=True,\n",
" dropout_rate=config[\"dropout_rate\"]\n",
" dropout_rate=config[\"dropout_rate\"],\n",
")\n",
"\n",
"logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"logger.info(f\"Using device: {device}\")\n",
"\n",
"trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
"trainer = HomographyTrainer(model, train_loader, val_loader, device, HomographyLoss())\n",
"logger.info(\"Starting training...\")\n",
"trainer.train(config[\"epochs\"])\n",
"logger.info(\"Training completed\")\n",
@@ -856,6 +1187,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c1ea5b8f",
"metadata": {},
"outputs": [],
"source": [

View File

@@ -4,7 +4,7 @@ import os
import torch
from .dataloader import create_data_loaders, get_dataset_info
from .model import HomographyCNN6, count_parameters
from .model import HomographyCNN6, HomographyHybridCNN, HomographyLoss, count_parameters
from .train import HomographyTrainer
from .analyze import analyze_training
from .utils import config
@@ -29,18 +29,24 @@ train_loader, val_loader = create_data_loaders(
)
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
model = HomographyCNN6(
# model = HomographyCNN6(
# input_channels=3,
# backbone_name=config["backbone"],
# pretrained=True,
# dropout_rate=config["dropout_rate"]
# )
model = HomographyHybridCNN(
input_channels=3,
backbone_name=config["backbone"],
pretrained=True,
dropout_rate=config["dropout_rate"]
dropout_rate=config["droupout_rate"],
)
logger.info(f"Model created with {count_parameters(model):,} parameters")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
trainer = HomographyTrainer(model, train_loader, val_loader, device)
trainer = HomographyTrainer(model, train_loader, val_loader, device, HomographyLoss())
logger.info("Starting training...")
trainer.train(config["epochs"])
logger.info("Training completed")

View File

@@ -77,6 +77,126 @@ class HomographyCNN6(nn.Module):
}
class HomographyHybridCNN(nn.Module):
def __init__(self, input_channels=3, use_resnet_layers=2, dropout_rate=0.3):
super().__init__()
if use_resnet_layers == 1:
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
self.conv1 = resnet.conv1
self.bn1 = resnet.bn1
self.relu = resnet.relu
self.maxpool = resnet.maxpool
conv_out_channels = 64
elif use_resnet_layers == 2:
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
self.conv1 = resnet.conv1
self.bn1 = resnet.bn1
self.relu = resnet.relu
self.maxpool = resnet.maxpool
self.conv2 = resnet.layer1[0].conv1
self.bn2 = resnet.layer1[0].bn1
self.conv2_2 = resnet.layer1[0].conv2
self.bn2_2 = resnet.layer1[0].bn2
self.relu2 = resnet.layer1[0].relu
self.maxpool2 = resnet.maxpool
conv_out_channels = 64
else:
raise ValueError("use_resnet_layers must be 1 or 2")
self.use_resnet_layers = use_resnet_layers
self.feature_map_size = 64
self.conv_head = nn.Sequential(
nn.Conv2d(conv_out_channels, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
feature_dim = 256 * 4
self.head = nn.Sequential(
nn.Linear(feature_dim, 1024),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(1024, 512),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(256, 6),
)
self._init_weights()
def _init_weights(self):
for module in self.head.modules():
if isinstance(module, nn.Linear):
nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self, img1, img2):
x1 = self._extract_features(img1)
x2 = self._extract_features(img2)
combined = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], dim=1)
output = self.head(combined)
output = torch.tanh(output)
modified = output.clone()
modified[:, 2:6] = torch.mul(output[:, 2:6], torch.pi)
return modified
def _extract_features(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
if self.use_resnet_layers >= 2:
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv2_2(x)
x = self.bn2_2(x)
x = self.relu2(x)
x = self.maxpool2(x)
x = self.conv_head(x)
x = self.global_pool(x)
x = x.view(x.size(0), -1)
return x
def decode_output(self, output):
tx = output[:, 0]
ty = output[:, 1]
scale = output[:, 5]
angle1 = output[:, 2]
angle2 = output[:, 3]
angle3 = output[:, 4]
return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)
def get_components(self, output):
decoded = self.decode_output(output)
return {
"tx": decoded[:, 0],
"ty": decoded[:, 1],
"rx": decoded[:, 2],
"ry": decoded[:, 3],
"rz": decoded[:, 4],
"scale": decoded[:, 5],
}
class HomographyLoss6(nn.Module):
def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):
super().__init__()
@@ -138,5 +258,8 @@ class HomographyLoss6(nn.Module):
}
HomographyLoss = HomographyLoss6
def count_parameters(model):
return sum(p.numel() for p in model.parameters())

View File

@@ -12,12 +12,12 @@ from .utils import config
class HomographyTrainer:
def __init__(self, model, train_loader, val_loader, device):
def __init__(self, model, train_loader, val_loader, device, criterion):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.criterion = HomographyLoss6()
self.criterion = criterion
self.optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
self.writer = None
self.best_val_loss = float("inf")