feat(SiaN): improve analyzing

This commit is contained in:
2026-04-05 12:50:28 +03:00
parent ec8b3ae20e
commit fa4c4b83ae
6 changed files with 1148 additions and 543 deletions

View File

@@ -1,3 +1,3 @@
runs
*.gen.py
*.img
*.png

View File

@@ -18,7 +18,6 @@ from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from tqdm import tqdm
# code: ./src/utils.py
# markdown
"""# Configuration
@@ -28,8 +27,8 @@ Global settings for:
- Model architecture options
Contains the `config` dictionary used across all modules."""
# code: ./src/utils.py
# code: ./src/dataloader.py
# markdown
"""## Dataset
@@ -42,8 +41,8 @@ Google/Yandex image pair loader with homography augmentation.
**Returns:**
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
# code: ./src/dataloader.py
# code: ./src/model.py
# markdown
"""## Model
@@ -58,8 +57,8 @@ Google/Yandex image pair loader with homography augmentation.
- Dual-branch CNN (Google + Yandex images)
- Shared backbone (configurable: resnet18/34/50)
- Fusion head with dropout regularization"""
# code: ./src/model.py
# code: ./src/train.py
# markdown
"""## Training
@@ -75,8 +74,8 @@ Google/Yandex image pair loader with homography augmentation.
**Checkpoint saving:**
- `best_model.pt` — lowest validation loss
- `checkpoint_epoch_N.pt` — periodic saves"""
# code: ./src/train.py
# code: ./src/analyze.py
# markdown
"""## Analysis
@@ -85,8 +84,8 @@ Visualization and evaluation tools:
- Training metrics plots (loss curves)
- Prediction visualization on sample images
- Error analysis and statistics"""
# code: ./src/analyze.py
# code: ./src/main.py
# markdown
"""## Main Pipeline
@@ -101,6 +100,7 @@ Executes the full training workflow:
- Model checkpoints in `runs/checkpoints/`
- TensorBoard logs in `runs/`
- Analysis plots"""
# code: ./src/main.py
# # shell:
# !zip artefacts.zip runs/gan_training/checkpoints/best_model.pt
# !zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -23,9 +23,21 @@
"from tqdm import tqdm\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Configuration\n",
"\n",
"Global settings for:\n",
"- Data paths and image parameters\n",
"- Training hyperparameters\n",
"- Model architecture options\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -89,11 +101,23 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n"
"source": [
"## Dataset\n",
"\n",
"Google/Yandex image pair loader with homography augmentation.\n",
"\n",
"**Features:**\n",
"- Loads paired images from dual camera sources\n",
"- Applies random homography transformations\n",
"- Supports configurable train/val split\n",
"\n",
"**Returns:**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "8740e758",
"metadata": {},
"outputs": [],
"source": [
@@ -103,13 +127,16 @@
"\n",
"class YaGoDataset(Dataset):\n",
" def __init__(self, root_dir: str, transform=None, augment: bool = True, \n",
" image_size: Tuple[int, int] = (256, 256)):\n",
" image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):\n",
" self.root_dir = root_dir\n",
" self.transform = transform\n",
" self.augment = augment\n",
" self.image_size = image_size\n",
" self.cache_level = cache_level\n",
" self.K = get_camera_matrix(image_size[1], image_size[0])\n",
" self.image_pairs = self._discover_image_pairs()\n",
" self._load_images_to_memory()\n",
" self._init_cache()\n",
"\n",
" def _discover_image_pairs(self):\n",
" pairs = []\n",
@@ -121,28 +148,72 @@
" pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n",
" return sorted(pairs, key=lambda x: x[\"idx\"])\n",
"\n",
" def _load_images_to_memory(self):\n",
" self._google_images = []\n",
" self._yandex_images = []\n",
" for pair in self.image_pairs:\n",
" google_img = cv2.imread(pair[\"google\"])\n",
" google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)\n",
" google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n",
" \n",
" yandex_img = cv2.imread(pair[\"yandex\"])\n",
" yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)\n",
" yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n",
" \n",
" self._google_images.append(google_img)\n",
" self._yandex_images.append(yandex_img)\n",
"\n",
" def _init_cache(self):\n",
" self._access_counts = [0] * len(self.image_pairs)\n",
" self._cached_google = [None] * len(self.image_pairs)\n",
" self._cached_yandex = [None] * len(self.image_pairs)\n",
" self._cached_homography = [None] * len(self.image_pairs)\n",
"\n",
" def _generate_augmented(self, idx):\n",
" google_img = self._google_images[idx].copy()\n",
" yandex_img = self._yandex_images[idx].copy()\n",
"\n",
" params1 = generate_random_homography_params()\n",
" params2 = generate_random_homography_params()\n",
" H1 = homography_params_to_matrix(params1, self.K)\n",
" H2 = homography_params_to_matrix(params2, self.K)\n",
" H_combined = np.linalg.inv(H1) @ H2\n",
" \n",
" google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))\n",
" yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))\n",
" \n",
" target_params = matrix_to_homography_params(H_combined, self.K)\n",
" \n",
" return google_warped, yandex_warped, H_combined, target_params\n",
"\n",
" def __len__(self):\n",
" return len(self.image_pairs)\n",
"\n",
" def __getitem__(self, idx):\n",
" pair = self.image_pairs[idx]\n",
" google_img = Image.open(pair[\"google\"]).convert(\"RGB\").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)\n",
" yandex_img = Image.open(pair[\"yandex\"]).convert(\"RGB\").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)\n",
"\n",
" if self.augment:\n",
" params1 = generate_random_homography_params()\n",
" params2 = generate_random_homography_params()\n",
" H1 = homography_params_to_matrix(params1, self.K)\n",
" H2 = homography_params_to_matrix(params2, self.K)\n",
" H_combined = np.linalg.inv(H1) @ H2\n",
" yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, self.image_size))\n",
" google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, self.image_size))\n",
" target_params = matrix_to_homography_params(H_combined, self.K)\n",
" target_matrix = H_combined\n",
" self._access_counts[idx] += 1\n",
" \n",
" use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0\n",
" \n",
" if use_cache:\n",
" google_img = self._cached_google[idx]\n",
" yandex_img = self._cached_yandex[idx]\n",
" target_matrix = self._cached_homography[idx]\n",
" target_params = matrix_to_homography_params(target_matrix, self.K)\n",
" elif self.augment:\n",
" google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n",
" if self.cache_level > 0:\n",
" self._cached_google[idx] = google_img\n",
" self._cached_yandex[idx] = yandex_img\n",
" self._cached_homography[idx] = target_matrix\n",
" else:\n",
" google_img = self._google_images[idx]\n",
" yandex_img = self._yandex_images[idx]\n",
" target_params = np.zeros(6, dtype=np.float32)\n",
" target_matrix = np.eye(3, dtype=np.float32)\n",
"\n",
" google_img = Image.fromarray(google_img)\n",
" yandex_img = Image.fromarray(yandex_img)\n",
"\n",
" if self.transform:\n",
" google_img = self.transform(google_img)\n",
" yandex_img = self.transform(yandex_img)\n",
@@ -156,18 +227,21 @@
"\n",
"\n",
"def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0, \n",
" image_size=(256, 256), augment_train=True):\n",
" transform = transforms.Compose([transforms.ToTensor()])\n",
" image_size=(256, 256), augment_train=True, cache_level=5):\n",
" transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
" ])\n",
" \n",
" full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)\n",
" aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)\n",
" full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)\n",
" aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)\n",
"\n",
" indices = list(range(len(full_ds)))\n",
" random.shuffle(indices)\n",
" split = int(train_split * len(indices))\n",
" \n",
" train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])\n",
" val_ds = Subset(full_ds, indices[split:])\n",
" val_ds = Subset(aug_ds, indices[split:])\n",
"\n",
" return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),\n",
" DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))\n",
@@ -186,11 +260,24 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads paired images from dual camera sources\n- Applies random homography transformations\n- Supports configurable train/val split\n\n**Returns:**\n"
"source": [
"## Model\n",
"\n",
"`HomographyCNN6` — CNN architecture for homography estimation.\n",
"\n",
"**Output:** 6 parameters\n",
"- `rx, ry, rz` — rotation angles (radians)\n",
"- `tx, ty` — translation offsets\n",
"- `scale` — isotropic scale factor\n",
"\n",
"**Architecture:**\n",
"- Dual-branch CNN (Google + Yandex images)\n",
"- Shared backbone (configurable: resnet18/34/50)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -238,11 +325,25 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 6 parameters\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch CNN (Google + Yandex images)\n- Shared backbone (configurable: resnet18/34/50)\n"
"source": [
"## Training\n",
"\n",
"`HomographyTrainer` — training loop with validation and checkpointing.\n",
"\n",
"**Features:**\n",
"- Epoch-based training with tqdm progress bar\n",
"- Adam optimizer with configurable LR\n",
"- Validation after each epoch\n",
"- Best model auto-save\n",
"- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n",
"\n",
"**Checkpoint saving:**\n",
"- `best_model.pt` — lowest validation loss\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -260,10 +361,19 @@
" self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"])\n",
" self.writer = None\n",
" self.best_val_loss = float(\"inf\")\n",
" self.train_losses = []\n",
" self.val_losses = []\n",
" self.train_mse_trans = []\n",
" self.train_mse_angle = []\n",
" self.train_mse_scale = []\n",
" self.val_mse_trans = []\n",
" self.val_mse_angle = []\n",
" self.val_mse_scale = []\n",
"\n",
" def train_epoch(self, epoch):\n",
" self.model.train()\n",
" total_loss, total_samples = 0, 0\n",
" mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n",
" pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n",
" for batch_idx, batch in enumerate(pbar):\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
@@ -278,13 +388,23 @@
"\n",
" total_loss += loss.item() * google_img.size(0)\n",
" total_samples += google_img.size(0)\n",
" \n",
" mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)\n",
" mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)\n",
" mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)\n",
" \n",
" pbar.set_postfix({\"loss\": loss.item()})\n",
"\n",
" self.train_mse_trans.append(mse_trans_sum / total_samples)\n",
" self.train_mse_angle.append(mse_angle_sum / total_samples)\n",
" self.train_mse_scale.append(mse_scale_sum / total_samples)\n",
" \n",
" return {\"loss\": total_loss / total_samples}\n",
"\n",
" def validate(self):\n",
" self.model.eval()\n",
" total_loss, total_samples = 0, 0\n",
" mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0\n",
" with torch.no_grad():\n",
" for batch in tqdm(self.val_loader, desc=\"Validation\"):\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
@@ -294,6 +414,15 @@
" loss = self.criterion(output, target)\n",
" total_loss += loss.item() * google_img.size(0)\n",
" total_samples += google_img.size(0)\n",
" \n",
" mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)\n",
" mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)\n",
" mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)\n",
" \n",
" self.val_mse_trans.append(mse_trans_sum / total_samples)\n",
" self.val_mse_angle.append(mse_angle_sum / total_samples)\n",
" self.val_mse_scale.append(mse_scale_sum / total_samples)\n",
" \n",
" return {\"loss\": total_loss / total_samples}\n",
"\n",
" def train(self, num_epochs):\n",
@@ -304,7 +433,10 @@
" for epoch in range(1, num_epochs + 1):\n",
" train_metrics = self.train_epoch(epoch)\n",
" val_metrics = self.validate()\n",
" self.train_losses.append(train_metrics[\"loss\"])\n",
" self.val_losses.append(val_metrics[\"loss\"])\n",
" print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n",
" print(f\" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}\")\n",
"\n",
" if val_metrics[\"loss\"] < self.best_val_loss:\n",
" self.best_val_loss = val_metrics[\"loss\"]\n",
@@ -330,14 +462,25 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n"
"source": [
"## Analysis\n",
"\n",
"Visualization and evaluation tools:\n",
"\n",
"- Training metrics plots (loss curves)\n",
"- Prediction visualization on sample images\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")\n",
"os.makedirs(IMG_DIR, exist_ok=True)\n",
"\n",
"\n",
"def analyze_training(trainer):\n",
@@ -348,58 +491,283 @@
"\n",
" print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n",
"\n",
" best_model_path = os.path.join(config[\"output_dir\"], \"checkpoints\", \"best_model.pt\")\n",
" if os.path.exists(best_model_path):\n",
" checkpoint = torch.load(best_model_path, map_location=trainer.device)\n",
" trainer.model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
" print(f\"\\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})\")\n",
" \n",
" trainer.model.eval()\n",
" \n",
" n_samples = 50\n",
" names = [\"rx\", \"ry\", \"rz\", \"tx\", \"ty\", \"scale\"]\n",
" \n",
" with torch.no_grad():\n",
" batch = next(iter(trainer.val_loader))\n",
" google_img = batch[\"google_img\"].to(trainer.device)\n",
" yandex_img = batch[\"yandex_img\"].to(trainer.device)\n",
" target_params = batch[\"homography_params\"].to(trainer.device)\n",
"\n",
" pred_params = trainer.model(google_img, yandex_img)\n",
"\n",
" print(f\"\\nSample predictions (first 3 of batch):\")\n",
" print(f\"{'Param':<8} {'Target':>12} {'Predicted':>12} {'Error':>12}\")\n",
" print(\"-\" * 46)\n",
" names = [\"rx\", \"ry\", \"rz\", \"tx\", \"ty\", \"scale\"]\n",
" for i in range(6):\n",
" t = target_params[0, i].item()\n",
" p = pred_params[0, i].item()\n",
" print(f\"{names[i]:<8} {t:>12.4f} {p:>12.4f} {abs(t-p):>12.4f}\")\n",
"\n",
" print(f\"\\nBatch mean abs error: {torch.mean(torch.abs(pred_params - target_params)).item():.4f}\")\n",
"\n",
" print(\"\\n=== Visualization ===\")\n",
" fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
" img1 = google_img[0].cpu()\n",
" img2 = yandex_img[0].cpu()\n",
" axes[0].imshow(img1.permute(1, 2, 0))\n",
" axes[0].set_title(\"Google\")\n",
" axes[0].axis(\"off\")\n",
" axes[1].imshow(img2.permute(1, 2, 0))\n",
" axes[1].set_title(\"Yandex\")\n",
" axes[1].axis(\"off\")\n",
" axes[2].bar(names, pred_params[0].cpu().numpy())\n",
" axes[2].set_title(\"Predicted params\")\n",
" axes[2].axhline(y=0, color=\"k\", lw=0.5)\n",
" all_errors = [[] for _ in range(6)]\n",
" all_targets = [[] for _ in range(6)]\n",
" all_preds = [[] for _ in range(6)]\n",
" \n",
" for i in range(n_samples):\n",
" try:\n",
" batch = next(iter(trainer.val_loader))\n",
" except StopIteration:\n",
" break\n",
" google_img = batch[\"google_img\"].to(trainer.device)\n",
" yandex_img = batch[\"yandex_img\"].to(trainer.device)\n",
" target_params = batch[\"homography_params\"].to(trainer.device)\n",
" pred_params = trainer.model(google_img, yandex_img)\n",
" \n",
" for j in range(6):\n",
" all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item())\n",
" all_targets[j].append(target_params[0, j].item())\n",
" all_preds[j].append(pred_params[0, j].item())\n",
" \n",
" mean_errors = [np.mean(all_errors[i]) for i in range(6)]\n",
" std_errors = [np.std(all_errors[i]) for i in range(6)]\n",
" \n",
" if len(trainer.train_losses) > 0:\n",
" epochs = range(1, len(trainer.train_losses) + 1)\n",
" fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
" \n",
" axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n",
" axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
" axes[0, 0].set_xlabel(\"Epoch\")\n",
" axes[0, 0].set_ylabel(\"Loss\")\n",
" axes[0, 0].set_title(\"Training & Validation Loss\")\n",
" axes[0, 0].legend()\n",
" axes[0, 0].grid(True, alpha=0.3)\n",
" \n",
" axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
" axes[0, 1].set_xlabel(\"Epoch\")\n",
" axes[0, 1].set_ylabel(\"Loss\")\n",
" axes[0, 1].set_title(\"Validation Loss\")\n",
" axes[0, 1].legend()\n",
" axes[0, 1].grid(True, alpha=0.3)\n",
" \n",
" axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n",
" axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n",
" axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n",
" axes[1, 0].set_xlabel(\"Epoch\")\n",
" axes[1, 0].set_ylabel(\"MSE\")\n",
" axes[1, 0].set_title(\"Validation MSE by Category\")\n",
" axes[1, 0].legend()\n",
" axes[1, 0].grid(True, alpha=0.3)\n",
" \n",
" x_pos = np.arange(6)\n",
" axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
" axes[1, 1].set_xticks(x_pos)\n",
" axes[1, 1].set_xticklabels(names)\n",
" axes[1, 1].set_ylabel(\"Mean Absolute Error\")\n",
" axes[1, 1].set_title(f\"Mean Absolute Error per Parameter ({n_samples} samples)\")\n",
" axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, \"training_loss_plots.png\"), dpi=150)\n",
" print(\"Saved training_loss_plots.png\")\n",
" plt.show()\n",
" \n",
" fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n",
" for j in range(6):\n",
" row = j // 3\n",
" col = j % 3\n",
" axes[row, col].bar(range(n_samples), all_errors[j], color=\"steelblue\", alpha=0.7)\n",
" axes[row, col].set_xlabel(\"Sample\")\n",
" axes[row, col].set_ylabel(\"Absolute Error\")\n",
" axes[row, col].set_title(f\"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}\")\n",
" axes[row, col].grid(True, alpha=0.3, axis=\"y\")\n",
" plt.suptitle(f\"Mean Absolute Error per Parameter ({n_samples} samples)\", fontsize=14)\n",
" plt.tight_layout()\n",
" plt.savefig(\"prediction_sample.png\")\n",
" print(\"Saved prediction_sample.png\")\n",
" plt.savefig(os.path.join(IMG_DIR, \"mae_per_parameter.png\"), dpi=150)\n",
" print(\"Saved mae_per_parameter.png\")\n",
" plt.show()\n",
" \n",
" fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
" \n",
" x_pos = np.arange(6)\n",
" axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
" axes[0].set_xticks(x_pos)\n",
" axes[0].set_xticklabels(names)\n",
" axes[0].set_ylabel(\"Mean Absolute Error\")\n",
" axes[0].set_title(\"Mean Absolute Error per Parameter (with std)\")\n",
" axes[0].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)\n",
" colors = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n",
" for patch, color in zip(bp[\"boxes\"], colors):\n",
" patch.set_facecolor(color)\n",
" patch.set_alpha(0.7)\n",
" axes[1].set_ylabel(\"Absolute Error\")\n",
" axes[1].set_title(f\"Error Distribution per Parameter ({n_samples} samples)\")\n",
" axes[1].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, \"mae_boxplot.png\"), dpi=150)\n",
" print(\"Saved mae_boxplot.png\")\n",
" plt.show()\n",
" \n",
" print(\"\\n=== Sample Predictions (20 pairs) ===\")\n",
" n_vis_samples = 20\n",
" \n",
" with torch.no_grad():\n",
" for sample_idx in range(n_vis_samples):\n",
" try:\n",
" batch = next(iter(trainer.val_loader))\n",
" except StopIteration:\n",
" break\n",
" google_img = batch[\"google_img\"].to(trainer.device)\n",
" yandex_img = batch[\"yandex_img\"].to(trainer.device)\n",
" target_params = batch[\"homography_params\"].to(trainer.device)\n",
" pred_params = trainer.model(google_img, yandex_img)\n",
" \n",
" errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy()\n",
" targets = target_params[0].cpu().numpy()\n",
" preds = pred_params[0].cpu().numpy()\n",
" \n",
" fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
" \n",
" axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))\n",
" axes[0, 0].set_title(f\"Google Image\")\n",
" axes[0, 0].axis(\"off\")\n",
" \n",
" axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))\n",
" axes[0, 1].set_title(f\"Yandex Image\")\n",
" axes[0, 1].axis(\"off\")\n",
" \n",
" x_pos = np.arange(6)\n",
" width = 0.35\n",
" axes[1, 0].bar(x_pos - width/2, targets, width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n",
" axes[1, 0].bar(x_pos + width/2, preds, width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n",
" axes[1, 0].set_xticks(x_pos)\n",
" axes[1, 0].set_xticklabels(names)\n",
" axes[1, 0].set_ylabel(\"Parameter Value\")\n",
" axes[1, 0].set_title(\"Target vs Predicted\")\n",
" axes[1, 0].legend()\n",
" axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n",
" \n",
" axes[1, 1].bar(x_pos, errors, color=[\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"], alpha=0.8)\n",
" axes[1, 1].set_xticks(x_pos)\n",
" axes[1, 1].set_xticklabels(names)\n",
" axes[1, 1].set_ylabel(\"Absolute Error\")\n",
" axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(errors):.4f})\")\n",
" axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n",
" for i, e in enumerate(errors):\n",
" axes[1, 1].text(i, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
" \n",
" plt.suptitle(f\"Sample {sample_idx + 1}\", fontsize=14)\n",
" plt.tight_layout()\n",
" plt.savefig(os.path.join(IMG_DIR, f\"prediction_sample_{sample_idx + 1:02d}.png\"), dpi=100)\n",
" plt.show()\n",
" print(f\"Saved prediction_sample_{sample_idx + 1:02d}.png\")\n",
" \n",
" print(f\"\\nPrediction errors over {n_samples} samples:\")\n",
" print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n",
" print(\"-\" * 52)\n",
" for i in range(6):\n",
" mean_err = np.mean(all_errors[i])\n",
" std_err = np.std(all_errors[i])\n",
" min_err = np.min(all_errors[i])\n",
" max_err = np.max(all_errors[i])\n",
" print(f\"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}\")\n",
"\n",
" return {\"best_val_loss\": trainer.best_val_loss}\n",
" return {\n",
" \"best_val_loss\": trainer.best_val_loss,\n",
" \"train_losses\": trainer.train_losses,\n",
" \"val_losses\": trainer.val_losses,\n",
" \"val_mse_trans\": trainer.val_mse_trans,\n",
" \"val_mse_angle\": trainer.val_mse_angle,\n",
" \"val_mse_scale\": trainer.val_mse_scale,\n",
" }\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n"
"source": [
"## Main Pipeline\n",
"\n",
"Executes the full training workflow:\n",
"1. Load dataset info\n",
"2. Create data loaders\n",
"3. Initialize model\n",
"4. Train with validation\n",
"5. Analyze and export results\n",
"\n",
"**Outputs:**\n",
"- Model checkpoints in `runs/checkpoints/`\n",
"- TensorBoard logs in `runs/`\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2026-04-05 12:35:20,728 - ==================================================\n",
"2026-04-05 12:35:20,730 - SiaN Training Pipeline\n",
"2026-04-05 12:35:20,731 - ==================================================\n",
"2026-04-05 12:35:32,423 - Dataset: 327 samples, keys=['google_img', 'yandex_img', 'homography_matrix', 'homography_params']\n",
"2026-04-05 12:35:54,074 - Data loaders created: train=261, val=66\n",
"2026-04-05 12:35:54,366 - Model created with 12,358,470 parameters\n",
"2026-04-05 12:35:54,368 - Using device: cpu\n",
"2026-04-05 12:35:54,374 - Starting training...\n",
"Epoch 1: 0%| | 0/9 [00:00<?, ?it/s]c:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
" super().__init__(loader)\n",
"Epoch 1: 100%|██████████| 9/9 [00:35<00:00, 3.98s/it, loss=0.795]\n",
"Validation: 100%|██████████| 3/3 [00:03<00:00, 1.13s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.6337, Val Loss: 0.6173\n",
" MSE - Trans: 0.0323, Angle: 1.2063, Scale: 0.0204\n",
"Best model saved (val loss: 0.6173)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2: 44%|████▍ | 4/9 [00:18<00:22, 4.52s/it, loss=0.631]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 33\u001b[39m\n\u001b[32m 31\u001b[39m trainer = HomographyTrainer(model, train_loader, val_loader, device)\n\u001b[32m 32\u001b[39m logger.info(\u001b[33m\"\u001b[39m\u001b[33mStarting training...\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mepochs\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 34\u001b[39m logger.info(\u001b[33m\"\u001b[39m\u001b[33mTraining completed\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 36\u001b[39m logger.info(\u001b[33m\"\u001b[39m\u001b[33mAnalyzing model...\u001b[39m\u001b[33m\"\u001b[39m)\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 81\u001b[39m, in \u001b[36mHomographyTrainer.train\u001b[39m\u001b[34m(self, num_epochs)\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28mself\u001b[39m.writer = SummaryWriter(log_dir)\n\u001b[32m 80\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m1\u001b[39m, num_epochs + \u001b[32m1\u001b[39m):\n\u001b[32m---> \u001b[39m\u001b[32m81\u001b[39m train_metrics = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 82\u001b[39m val_metrics = \u001b[38;5;28mself\u001b[39m.validate()\n\u001b[32m 83\u001b[39m \u001b[38;5;28mself\u001b[39m.train_losses.append(train_metrics[\u001b[33m\"\u001b[39m\u001b[33mloss\u001b[39m\u001b[33m\"\u001b[39m])\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 31\u001b[39m, in \u001b[36mHomographyTrainer.train_epoch\u001b[39m\u001b[34m(self, epoch)\u001b[39m\n\u001b[32m 28\u001b[39m target = batch[\u001b[33m\"\u001b[39m\u001b[33mhomography_params\u001b[39m\u001b[33m\"\u001b[39m].to(\u001b[38;5;28mself\u001b[39m.device)\n\u001b[32m 30\u001b[39m \u001b[38;5;28mself\u001b[39m.optimizer.zero_grad()\n\u001b[32m---> \u001b[39m\u001b[32m31\u001b[39m output = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgoogle_img\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43myandex_img\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 32\u001b[39m loss = \u001b[38;5;28mself\u001b[39m.criterion(output, target)\n\u001b[32m 33\u001b[39m loss.backward()\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 20\u001b[39m, in \u001b[36mHomographyCNN6.forward\u001b[39m\u001b[34m(self, img1, img2)\u001b[39m\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, img1, img2):\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m f1 = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbackbone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m f2 = \u001b[38;5;28mself\u001b[39m.backbone(img2)\n\u001b[32m 22\u001b[39m combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=\u001b[32m1\u001b[39m)\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torchvision\\models\\resnet.py:285\u001b[39m, in \u001b[36mResNet.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 284\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m285\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_forward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torchvision\\models\\resnet.py:274\u001b[39m, in \u001b[36mResNet._forward_impl\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 271\u001b[39m x = \u001b[38;5;28mself\u001b[39m.maxpool(x)\n\u001b[32m 273\u001b[39m x = \u001b[38;5;28mself\u001b[39m.layer1(x)\n\u001b[32m--> \u001b[39m\u001b[32m274\u001b[39m x = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mlayer2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 275\u001b[39m x = \u001b[38;5;28mself\u001b[39m.layer3(x)\n\u001b[32m 276\u001b[39m x = \u001b[38;5;28mself\u001b[39m.layer4(x)\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\container.py:253\u001b[39m, in \u001b[36mSequential.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 249\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 250\u001b[39m \u001b[33;03mRuns the forward pass.\u001b[39;00m\n\u001b[32m 251\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 252\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m253\u001b[39m \u001b[38;5;28minput\u001b[39m = \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torchvision\\models\\resnet.py:96\u001b[39m, in \u001b[36mBasicBlock.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 93\u001b[39m out = \u001b[38;5;28mself\u001b[39m.bn1(out)\n\u001b[32m 94\u001b[39m out = \u001b[38;5;28mself\u001b[39m.relu(out)\n\u001b[32m---> \u001b[39m\u001b[32m96\u001b[39m out = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mconv2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 97\u001b[39m out = \u001b[38;5;28mself\u001b[39m.bn2(out)\n\u001b[32m 99\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.downsample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1776\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1775\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1776\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1787\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1784\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1787\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1789\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1790\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:553\u001b[39m, in \u001b[36mConv2d.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 552\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m553\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:548\u001b[39m, in \u001b[36mConv2d._conv_forward\u001b[39m\u001b[34m(self, input, weight, bias)\u001b[39m\n\u001b[32m 535\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.padding_mode != \u001b[33m\"\u001b[39m\u001b[33mzeros\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 536\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m F.conv2d(\n\u001b[32m 537\u001b[39m F.pad(\n\u001b[32m 538\u001b[39m \u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m._reversed_padding_repeated_twice, mode=\u001b[38;5;28mself\u001b[39m.padding_mode\n\u001b[32m (...)\u001b[39m\u001b[32m 545\u001b[39m \u001b[38;5;28mself\u001b[39m.groups,\n\u001b[32m 546\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m548\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 549\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mgroups\u001b[49m\n\u001b[32m 550\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"\n",
"\n",
@@ -450,18 +818,13 @@
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze and export results\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!zip artefacts.zip runs/gan_training/checkpoints/best_model.pt\n",
"!zip artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*\n",
"\n"
]
}
@@ -473,7 +836,15 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},

View File

@@ -1,7 +1,13 @@
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from .utils import config
IMG_DIR = os.path.join(config["output_dir"], "images")
os.makedirs(IMG_DIR, exist_ok=True)
def analyze_training(trainer):
print("=== Training Analysis ===\n")
@@ -11,42 +17,189 @@ def analyze_training(trainer):
print(f"\nBest val loss: {trainer.best_val_loss:.4f}")
best_model_path = os.path.join(config["output_dir"], "checkpoints", "best_model.pt")
if os.path.exists(best_model_path):
checkpoint = torch.load(best_model_path, map_location=trainer.device)
trainer.model.load_state_dict(checkpoint["model_state_dict"])
print(f"\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})")
trainer.model.eval()
n_samples = 50
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
with torch.no_grad():
batch = next(iter(trainer.val_loader))
google_img = batch["google_img"].to(trainer.device)
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
all_errors = [[] for _ in range(6)]
all_targets = [[] for _ in range(6)]
all_preds = [[] for _ in range(6)]
pred_params = trainer.model(google_img, yandex_img)
for i in range(n_samples):
try:
batch = next(iter(trainer.val_loader))
except StopIteration:
break
google_img = batch["google_img"].to(trainer.device)
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
pred_params = trainer.model(google_img, yandex_img)
print(f"\nSample predictions (first 3 of batch):")
print(f"{'Param':<8} {'Target':>12} {'Predicted':>12} {'Error':>12}")
print("-" * 46)
names = ["rx", "ry", "rz", "tx", "ty", "scale"]
for i in range(6):
t = target_params[0, i].item()
p = pred_params[0, i].item()
print(f"{names[i]:<8} {t:>12.4f} {p:>12.4f} {abs(t-p):>12.4f}")
for j in range(6):
all_errors[j].append(torch.abs(pred_params[0, j] - target_params[0, j]).item())
all_targets[j].append(target_params[0, j].item())
all_preds[j].append(pred_params[0, j].item())
print(f"\nBatch mean abs error: {torch.mean(torch.abs(pred_params - target_params)).item():.4f}")
mean_errors = [np.mean(all_errors[i]) for i in range(6)]
std_errors = [np.std(all_errors[i]) for i in range(6)]
print("\n=== Visualization ===")
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
img1 = google_img[0].cpu()
img2 = yandex_img[0].cpu()
axes[0].imshow(img1.permute(1, 2, 0))
axes[0].set_title("Google")
axes[0].axis("off")
axes[1].imshow(img2.permute(1, 2, 0))
axes[1].set_title("Yandex")
axes[1].axis("off")
axes[2].bar(names, pred_params[0].cpu().numpy())
axes[2].set_title("Predicted params")
axes[2].axhline(y=0, color="k", lw=0.5)
if len(trainer.train_losses) > 0:
epochs = range(1, len(trainer.train_losses) + 1)
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes[0, 0].plot(epochs, trainer.train_losses, "b-", label="Train Loss")
axes[0, 0].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_title("Training & Validation Loss")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
axes[0, 1].plot(epochs, trainer.val_losses, "r-", label="Val Loss")
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("Loss")
axes[0, 1].set_title("Validation Loss")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
axes[1, 0].plot(epochs, trainer.val_mse_trans, "g-", label="Translation (tx, ty)")
axes[1, 0].plot(epochs, trainer.val_mse_angle, "m-", label="Angle (rx, ry, rz)")
axes[1, 0].plot(epochs, trainer.val_mse_scale, "c-", label="Scale")
axes[1, 0].set_xlabel("Epoch")
axes[1, 0].set_ylabel("MSE")
axes[1, 0].set_title("Validation MSE by Category")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
x_pos = np.arange(6)
axes[1, 1].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(names)
axes[1, 1].set_ylabel("Mean Absolute Error")
axes[1, 1].set_title(f"Mean Absolute Error per Parameter ({n_samples} samples)")
axes[1, 1].grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(os.path.join(IMG_DIR, "training_loss_plots.png"), dpi=150)
print("Saved training_loss_plots.png")
plt.show()
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
for j in range(6):
row = j // 3
col = j % 3
axes[row, col].bar(range(n_samples), all_errors[j], color="steelblue", alpha=0.7)
axes[row, col].set_xlabel("Sample")
axes[row, col].set_ylabel("Absolute Error")
axes[row, col].set_title(f"{names[j]}: Mean={np.mean(all_errors[j]):.4f}, Std={np.std(all_errors[j]):.4f}")
axes[row, col].grid(True, alpha=0.3, axis="y")
plt.suptitle(f"Mean Absolute Error per Parameter ({n_samples} samples)", fontsize=14)
plt.tight_layout()
plt.savefig("prediction_sample.png")
print("Saved prediction_sample.png")
plt.savefig(os.path.join(IMG_DIR, "mae_per_parameter.png"), dpi=150)
print("Saved mae_per_parameter.png")
plt.show()
return {"best_val_loss": trainer.best_val_loss}
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
x_pos = np.arange(6)
axes[0].bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(names)
axes[0].set_ylabel("Mean Absolute Error")
axes[0].set_title("Mean Absolute Error per Parameter (with std)")
axes[0].grid(True, alpha=0.3, axis="y")
bp = axes[1].boxplot([all_errors[i] for i in range(6)], labels=names, patch_artist=True)
colors = ["c", "m", "y", "g", "b", "r"]
for patch, color in zip(bp["boxes"], colors):
patch.set_facecolor(color)
patch.set_alpha(0.7)
axes[1].set_ylabel("Absolute Error")
axes[1].set_title(f"Error Distribution per Parameter ({n_samples} samples)")
axes[1].grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(os.path.join(IMG_DIR, "mae_boxplot.png"), dpi=150)
print("Saved mae_boxplot.png")
plt.show()
print("\n=== Sample Predictions (20 pairs) ===")
n_vis_samples = 20
with torch.no_grad():
for sample_idx in range(n_vis_samples):
try:
batch = next(iter(trainer.val_loader))
except StopIteration:
break
google_img = batch["google_img"].to(trainer.device)
yandex_img = batch["yandex_img"].to(trainer.device)
target_params = batch["homography_params"].to(trainer.device)
pred_params = trainer.model(google_img, yandex_img)
errors = torch.abs(pred_params[0] - target_params[0]).cpu().numpy()
targets = target_params[0].cpu().numpy()
preds = pred_params[0].cpu().numpy()
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes[0, 0].imshow(google_img[0].cpu().permute(1, 2, 0))
axes[0, 0].set_title(f"Google Image")
axes[0, 0].axis("off")
axes[0, 1].imshow(yandex_img[0].cpu().permute(1, 2, 0))
axes[0, 1].set_title(f"Yandex Image")
axes[0, 1].axis("off")
x_pos = np.arange(6)
width = 0.35
axes[1, 0].bar(x_pos - width/2, targets, width, label="Target", color="steelblue", alpha=0.8)
axes[1, 0].bar(x_pos + width/2, preds, width, label="Predicted", color="coral", alpha=0.8)
axes[1, 0].set_xticks(x_pos)
axes[1, 0].set_xticklabels(names)
axes[1, 0].set_ylabel("Parameter Value")
axes[1, 0].set_title("Target vs Predicted")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3, axis="y")
axes[1, 1].bar(x_pos, errors, color=["c", "m", "y", "g", "b", "r"], alpha=0.8)
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(names)
axes[1, 1].set_ylabel("Absolute Error")
axes[1, 1].set_title(f"Prediction Error (Mean: {np.mean(errors):.4f})")
axes[1, 1].grid(True, alpha=0.3, axis="y")
for i, e in enumerate(errors):
axes[1, 1].text(i, e + 0.01, f"{e:.3f}", ha="center", va="bottom", fontsize=8)
plt.suptitle(f"Sample {sample_idx + 1}", fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(IMG_DIR, f"prediction_sample_{sample_idx + 1:02d}.png"), dpi=100)
plt.show()
print(f"Saved prediction_sample_{sample_idx + 1:02d}.png")
print(f"\nPrediction errors over {n_samples} samples:")
print(f"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}")
print("-" * 52)
for i in range(6):
mean_err = np.mean(all_errors[i])
std_err = np.std(all_errors[i])
min_err = np.min(all_errors[i])
max_err = np.max(all_errors[i])
print(f"{names[i]:<8} {mean_err:>12.4f} {std_err:>12.4f} {min_err:>8.4f} {max_err:>8.4f}")
return {
"best_val_loss": trainer.best_val_loss,
"train_losses": trainer.train_losses,
"val_losses": trainer.val_losses,
"val_mse_trans": trainer.val_mse_trans,
"val_mse_angle": trainer.val_mse_angle,
"val_mse_scale": trainer.val_mse_scale,
}

View File

@@ -14,13 +14,16 @@ from .utils import config, get_camera_matrix, generate_random_homography_params,
class YaGoDataset(Dataset):
def __init__(self, root_dir: str, transform=None, augment: bool = True,
image_size: Tuple[int, int] = (256, 256)):
image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):
self.root_dir = root_dir
self.transform = transform
self.augment = augment
self.image_size = image_size
self.cache_level = cache_level
self.K = get_camera_matrix(image_size[1], image_size[0])
self.image_pairs = self._discover_image_pairs()
self._load_images_to_memory()
self._init_cache()
def _discover_image_pairs(self):
pairs = []
@@ -32,28 +35,72 @@ class YaGoDataset(Dataset):
pairs.append({"idx": int(idx), "google": os.path.join(self.root_dir, f), "yandex": yandex_path})
return sorted(pairs, key=lambda x: x["idx"])
def _load_images_to_memory(self):
self._google_images = []
self._yandex_images = []
for pair in self.image_pairs:
google_img = cv2.imread(pair["google"])
google_img = cv2.cvtColor(google_img, cv2.COLOR_BGR2RGB)
google_img = cv2.resize(google_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)
yandex_img = cv2.imread(pair["yandex"])
yandex_img = cv2.cvtColor(yandex_img, cv2.COLOR_BGR2RGB)
yandex_img = cv2.resize(yandex_img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)
self._google_images.append(google_img)
self._yandex_images.append(yandex_img)
def _init_cache(self):
self._access_counts = [0] * len(self.image_pairs)
self._cached_google = [None] * len(self.image_pairs)
self._cached_yandex = [None] * len(self.image_pairs)
self._cached_homography = [None] * len(self.image_pairs)
def _generate_augmented(self, idx):
google_img = self._google_images[idx].copy()
yandex_img = self._yandex_images[idx].copy()
params1 = generate_random_homography_params()
params2 = generate_random_homography_params()
H1 = homography_params_to_matrix(params1, self.K)
H2 = homography_params_to_matrix(params2, self.K)
H_combined = np.linalg.inv(H1) @ H2
google_warped = cv2.warpPerspective(google_img, H2, (self.image_size[1], self.image_size[0]))
yandex_warped = cv2.warpPerspective(yandex_img, H1, (self.image_size[1], self.image_size[0]))
target_params = matrix_to_homography_params(H_combined, self.K)
return google_warped, yandex_warped, H_combined, target_params
def __len__(self):
return len(self.image_pairs)
def __getitem__(self, idx):
pair = self.image_pairs[idx]
google_img = Image.open(pair["google"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
yandex_img = Image.open(pair["yandex"]).convert("RGB").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
self._access_counts[idx] += 1
if self.augment:
params1 = generate_random_homography_params()
params2 = generate_random_homography_params()
H1 = homography_params_to_matrix(params1, self.K)
H2 = homography_params_to_matrix(params2, self.K)
H_combined = np.linalg.inv(H1) @ H2
yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, self.image_size))
google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, self.image_size))
target_params = matrix_to_homography_params(H_combined, self.K)
target_matrix = H_combined
use_cache = self.augment and self.cache_level > 0 and self._access_counts[idx] > 1 and (self._access_counts[idx] - 1) % self.cache_level != 0
if use_cache:
google_img = self._cached_google[idx]
yandex_img = self._cached_yandex[idx]
target_matrix = self._cached_homography[idx]
target_params = matrix_to_homography_params(target_matrix, self.K)
elif self.augment:
google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)
if self.cache_level > 0:
self._cached_google[idx] = google_img
self._cached_yandex[idx] = yandex_img
self._cached_homography[idx] = target_matrix
else:
google_img = self._google_images[idx]
yandex_img = self._yandex_images[idx]
target_params = np.zeros(6, dtype=np.float32)
target_matrix = np.eye(3, dtype=np.float32)
google_img = Image.fromarray(google_img)
yandex_img = Image.fromarray(yandex_img)
if self.transform:
google_img = self.transform(google_img)
yandex_img = self.transform(yandex_img)
@@ -67,11 +114,14 @@ class YaGoDataset(Dataset):
def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
image_size=(256, 256), augment_train=True):
transform = transforms.Compose([transforms.ToTensor()])
image_size=(256, 256), augment_train=True, cache_level=5):
transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)
aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)
full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size, cache_level=cache_level)
aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size, cache_level=cache_level)
indices = list(range(len(full_ds)))
random.shuffle(indices)

View File

@@ -21,10 +21,19 @@ class HomographyTrainer:
self.optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
self.writer = None
self.best_val_loss = float("inf")
self.train_losses = []
self.val_losses = []
self.train_mse_trans = []
self.train_mse_angle = []
self.train_mse_scale = []
self.val_mse_trans = []
self.val_mse_angle = []
self.val_mse_scale = []
def train_epoch(self, epoch):
self.model.train()
total_loss, total_samples = 0, 0
mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
for batch_idx, batch in enumerate(pbar):
google_img = batch["google_img"].to(self.device)
@@ -39,13 +48,23 @@ class HomographyTrainer:
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
pbar.set_postfix({"loss": loss.item()})
self.train_mse_trans.append(mse_trans_sum / total_samples)
self.train_mse_angle.append(mse_angle_sum / total_samples)
self.train_mse_scale.append(mse_scale_sum / total_samples)
return {"loss": total_loss / total_samples}
def validate(self):
self.model.eval()
total_loss, total_samples = 0, 0
mse_trans_sum, mse_angle_sum, mse_scale_sum = 0, 0, 0
with torch.no_grad():
for batch in tqdm(self.val_loader, desc="Validation"):
google_img = batch["google_img"].to(self.device)
@@ -55,6 +74,15 @@ class HomographyTrainer:
loss = self.criterion(output, target)
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
mse_trans_sum += torch.mean((output[:, 3:5] - target[:, 3:5]) ** 2).item() * google_img.size(0)
mse_angle_sum += torch.mean((output[:, 0:3] - target[:, 0:3]) ** 2).item() * google_img.size(0)
mse_scale_sum += torch.mean((output[:, 5:6] - target[:, 5:6]) ** 2).item() * google_img.size(0)
self.val_mse_trans.append(mse_trans_sum / total_samples)
self.val_mse_angle.append(mse_angle_sum / total_samples)
self.val_mse_scale.append(mse_scale_sum / total_samples)
return {"loss": total_loss / total_samples}
def train(self, num_epochs):
@@ -65,7 +93,10 @@ class HomographyTrainer:
for epoch in range(1, num_epochs + 1):
train_metrics = self.train_epoch(epoch)
val_metrics = self.validate()
self.train_losses.append(train_metrics["loss"])
self.val_losses.append(val_metrics["loss"])
print(f"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}")
print(f" MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}")
if val_metrics["loss"] < self.best_val_loss:
self.best_val_loss = val_metrics["loss"]