858 lines
41 KiB
Plaintext
858 lines
41 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import random\n",
|
|
"import logging\n",
|
|
"from typing import Tuple\n",
|
|
"import cv2\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.optim as optim\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from PIL import Image\n",
|
|
"from torch.utils.data import DataLoader, Dataset, Subset\n",
|
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
|
"from torchvision import transforms, models\n",
|
|
"from tqdm import tqdm\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": "# Configuration\n\nGlobal settings for:\n- Data paths and image parameters\n- Training hyperparameters\n- Model architecture options\n"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"config = {\n",
"    # Paths default to the original workstation layout; set YAGO_DATA_DIR /\n",
"    # YAGO_OUTPUT_DIR to run the notebook on another machine.\n",
"    \"data_dir\": os.environ.get(\n",
"        \"YAGO_DATA_DIR\", r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\"\n",
"    ),\n",
"    \"image_size\": (256, 256),  # (height, width) every image is resized to\n",
"    \"batch_size\": 32,\n",
"    \"train_split\": 0.8,  # fraction of image pairs used for training\n",
"    \"num_workers\": 0,  # >0 gives each DataLoader worker its own copy of the dataset's augmentation cache\n",
"    \"epochs\": 10,\n",
"    \"learning_rate\": 2e-4,\n",
"    \"dropout_rate\": 0.5,\n",
"    \"backbone\": \"resnet18\",  # torchvision ResNet builder name (resnet18/34/50)\n",
"    \"output_dir\": os.environ.get(\n",
"        \"YAGO_OUTPUT_DIR\", r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\"\n",
"    ),\n",
"    # NOTE: larger than \"epochs\", so these defaults never write a periodic checkpoint.\n",
"    \"save_every_n_epochs\": 15,\n",
"    \"seed\": 42,  # seeds the python / numpy / torch RNGs below\n",
"}\n",
"\n",
"# Seed every RNG the pipeline draws from so \"Restart & Run All\" is reproducible.\n",
"random.seed(config[\"seed\"])\n",
"np.random.seed(config[\"seed\"])\n",
"torch.manual_seed(config[\"seed\"])\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_camera_matrix(w, h):\n",
"    \"\"\"Return a 3x3 pinhole intrinsic matrix for a ``w`` x ``h`` image.\n",
"\n",
"    Focal lengths are half the image size and the principal point is the image\n",
"    centre, so ``K^-1`` maps pixel coordinates in ``[0, w] x [0, h]`` to ``[-1, 1]``.\n",
"    \"\"\"\n",
"    half_w, half_h = w / 2, h / 2\n",
"    intrinsics = np.eye(3, dtype=np.float32)\n",
"    intrinsics[0, 0] = intrinsics[0, 2] = half_w\n",
"    intrinsics[1, 1] = intrinsics[1, 2] = half_h\n",
"    return intrinsics\n",
|
|
"\n",
|
|
"\n",
|
|
"def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n",
"    \"\"\"Draw one random ``[tx, ty, rx, ry, rz, scale]`` vector from the global NumPy RNG.\n",
"\n",
"    ``tx`` and ``ty`` are uniform in ``[-translation_range, translation_range]``; each\n",
"    rotation is uniform in ``[-angle_range, angle_range]`` degrees and returned in\n",
"    radians; ``scale`` is uniform in ``scale_range``.\n",
"    \"\"\"\n",
"    # Draw order (scale, tx, ty, rx, ry, rz) fixes which values a given seed yields.\n",
"    scale = np.random.uniform(*scale_range)\n",
"    tx, ty = [np.random.uniform(-translation_range, translation_range) for _ in range(2)]\n",
"    rx, ry, rz = [np.radians(np.random.uniform(-angle_range, angle_range)) for _ in range(3)]\n",
"    return np.array([tx, ty, rx, ry, rz, scale])\n",
|
|
"\n",
|
|
"\n",
|
|
"def homography_params_to_matrix(params, K):\n",
"    \"\"\"Build the pixel-space homography ``K @ Rx @ Ry @ Rz @ T @ K^-1``.\n",
"\n",
"    Args:\n",
"        params: ``(tx, ty, rx, ry, rz, scale)`` with angles in radians.\n",
"        K: 3x3 camera intrinsic matrix.\n",
"\n",
"    Returns:\n",
"        3x3 homography, where ``T = [[1, 0, tx], [0, 1, ty], [0, 0, scale]]``.\n",
"    \"\"\"\n",
"    tx, ty, rx, ry, rz, scale = params\n",
"\n",
"    def f32(rows):\n",
"        return np.array(rows, dtype=np.float32)\n",
"\n",
"    cos_x, sin_x = np.cos(rx), np.sin(rx)\n",
"    cos_y, sin_y = np.cos(ry), np.sin(ry)\n",
"    cos_z, sin_z = np.cos(rz), np.sin(rz)\n",
"\n",
"    rot_x = f32([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])\n",
"    rot_y = f32([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])\n",
"    rot_z = f32([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])\n",
"    # Translation sits in the last column, scale in the bottom-right entry.\n",
"    shift_scale = f32([[1, 0, tx], [0, 1, ty], [0, 0, scale]])\n",
"\n",
"    return K @ rot_x @ rot_y @ rot_z @ shift_scale @ np.linalg.inv(K)\n",
|
|
"\n",
|
|
"\n",
|
|
"def matrix_to_homography_params(H, K):\n",
"    \"\"\"Recover ``[tx, ty, rx, ry, rz, scale]`` from a pixel-space homography.\n",
"\n",
"    Inverse of ``homography_params_to_matrix``: ``E = K^-1 @ H @ K`` is expected to\n",
"    be ``R @ T`` with ``R = Rx @ Ry @ Rz`` and\n",
"    ``T = [[1, 0, tx], [0, 1, ty], [0, 0, scale]]``. The first two columns of ``E``\n",
"    are then the first two columns of ``R``, and ``R.T @ E[:, 2] = (tx, ty, scale)``.\n",
"\n",
"    A matrix that is not exactly of that form (for example the composition\n",
"    ``inv(H1) @ H2`` of two such homographies) is handled by projecting the\n",
"    rotation estimate onto the nearest rotation matrix (SVD), so the returned\n",
"    parameters are then only an approximation.\n",
"\n",
"    Args:\n",
"        H: 3x3 homography in pixel coordinates.\n",
"        K: 3x3 camera intrinsic matrix that was used to build ``H``.\n",
"\n",
"    Returns:\n",
"        ``np.float32`` array ``[tx, ty, rx, ry, rz, scale]`` (angles in radians).\n",
"    \"\"\"\n",
"    K64 = np.asarray(K, dtype=np.float64)\n",
"    E = np.linalg.inv(K64) @ np.asarray(H, dtype=np.float64) @ K64\n",
"\n",
"    # Columns 0 and 1 of E are R's first two columns; complete them to a 3x3 matrix\n",
"    # and snap it to the closest rotation (a no-op when E is exactly R @ T).\n",
"    col0, col1 = E[:, 0], E[:, 1]\n",
"    U, _, Vt = np.linalg.svd(np.column_stack([col0, col1, np.cross(col0, col1)]))\n",
"    R = U @ Vt\n",
"    if np.linalg.det(R) < 0:  # guard against a reflection for degenerate input\n",
"        U[:, -1] *= -1\n",
"        R = U @ Vt\n",
"\n",
"    tx, ty, scale = R.T @ E[:, 2]\n",
"\n",
"    # For R = Rx @ Ry @ Rz:\n",
"    #   R[0, 2] = sin(ry)\n",
"    #   R[0, 0] = cos(ry) cos(rz),   R[0, 1] = -cos(ry) sin(rz)\n",
"    #   R[1, 2] = -sin(rx) cos(ry),  R[2, 2] = cos(rx) cos(ry)\n",
"    ry = np.arcsin(np.clip(R[0, 2], -1.0, 1.0))\n",
"    rz = np.arctan2(-R[0, 1], R[0, 0])\n",
"    rx = np.arctan2(-R[1, 2], R[2, 2])\n",
"    return np.array([tx, ty, rx, ry, rz, scale], dtype=np.float32)\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
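"source": "## Dataset\n\nGoogle/Yandex image pair loader with homography augmentation.\n\n**Features:**\n- Loads `<idx>_google.png` / `<idx>_yandex.png` image pairs from one directory\n- Applies random homography transformations (one per image when `augment=True`)\n- Supports configurable train/val split\n\n**Returns:** one dict per sample with\n- `google_img`, `yandex_img` — the (transformed) images\n- `homography_matrix` — 3x3 relative homography `inv(H1) @ H2` (identity without augmentation)\n- `homography_params` — `[tx, ty, rx, ry, rz, scale]` decomposition of that matrix\n"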
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"class YaGoDataset(Dataset):\n",
"    \"\"\"Google/Yandex image pairs with random relative-homography augmentation.\n",
"\n",
"    Expects ``<idx>_google.png`` / ``<idx>_yandex.png`` files in ``root_dir`` with an\n",
"    integer ``idx``. Every pair is decoded, converted to RGB and resized to\n",
"    ``image_size`` (height, width) once, in ``__init__``, and kept in memory.\n",
"\n",
"    With ``augment=True`` the two images are warped by independent random\n",
"    homographies and the target is their relative homography ``inv(H1) @ H2``.\n",
"    With ``cache_level > 0`` a generated augmentation is reused for\n",
"    ``cache_level`` consecutive accesses of the same index before a new one is drawn.\n",
"\n",
"    Items are dicts with ``google_img``, ``yandex_img`` (after ``transform``),\n",
"    ``homography_matrix`` (3x3) and ``homography_params``\n",
"    (``[tx, ty, rx, ry, rz, scale]``).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, root_dir: str, transform=None, augment: bool = True,\n",
"                 image_size: Tuple[int, int] = (256, 256), cache_level: int = 5):\n",
"        self.root_dir = root_dir\n",
"        self.transform = transform\n",
"        self.augment = augment\n",
"        self.image_size = image_size\n",
"        self.cache_level = cache_level\n",
"        self.K = get_camera_matrix(image_size[1], image_size[0])\n",
"        self.image_pairs = self._discover_image_pairs()\n",
"        self._load_images_to_memory()\n",
"        self._init_cache()\n",
"\n",
"    def _discover_image_pairs(self):\n",
"        \"\"\"Return ``{\"idx\", \"google\", \"yandex\"}`` dicts for complete pairs, sorted by idx.\"\"\"\n",
"        pairs = []\n",
"        for f in os.listdir(self.root_dir):\n",
"            if f.endswith(\"_google.png\"):\n",
"                idx = f.split(\"_\")[0]\n",
"                yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n",
"                if os.path.exists(yandex_path):\n",
"                    pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n",
"        return sorted(pairs, key=lambda x: x[\"idx\"])\n",
"\n",
"    def _read_rgb(self, path):\n",
"        \"\"\"Load ``path`` as an RGB uint8 array resized to ``image_size``.\"\"\"\n",
"        img = cv2.imread(path)\n",
"        if img is None:\n",
"            # cv2.imread reports missing/unreadable files by returning None, which\n",
"            # would otherwise surface later as an obscure cvtColor error.\n",
"            raise FileNotFoundError(f\"Could not read image: {path}\")\n",
"        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"        return cv2.resize(img, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)\n",
"\n",
"    def _load_images_to_memory(self):\n",
"        \"\"\"Decode every pair once so ``__getitem__`` never touches the disk.\"\"\"\n",
"        self._google_images = []\n",
"        self._yandex_images = []\n",
"        for pair in self.image_pairs:\n",
"            self._google_images.append(self._read_rgb(pair[\"google\"]))\n",
"            self._yandex_images.append(self._read_rgb(pair[\"yandex\"]))\n",
"\n",
"    def _init_cache(self):\n",
"        \"\"\"Reset per-index access counters and cached augmentations.\"\"\"\n",
"        n = len(self.image_pairs)\n",
"        self._access_counts = [0] * n\n",
"        self._cached_google = [None] * n\n",
"        self._cached_yandex = [None] * n\n",
"        self._cached_homography = [None] * n\n",
"        self._cached_params = [None] * n\n",
"\n",
"    def _generate_augmented(self, idx):\n",
"        \"\"\"Warp pair ``idx`` with two random homographies.\n",
"\n",
"        Google is warped by ``H2`` and Yandex by ``H1``; returns both warped images,\n",
"        the target ``inv(H1) @ H2`` and its parameter vector.\n",
"        \"\"\"\n",
"        H1 = homography_params_to_matrix(generate_random_homography_params(), self.K)\n",
"        H2 = homography_params_to_matrix(generate_random_homography_params(), self.K)\n",
"        H_combined = np.linalg.inv(H1) @ H2\n",
"\n",
"        dsize = (self.image_size[1], self.image_size[0])  # cv2 expects (width, height)\n",
"        # warpPerspective writes to a new array, so the stored originals are not modified.\n",
"        google_warped = cv2.warpPerspective(self._google_images[idx], H2, dsize)\n",
"        yandex_warped = cv2.warpPerspective(self._yandex_images[idx], H1, dsize)\n",
"\n",
"        target_params = matrix_to_homography_params(H_combined, self.K)\n",
"        return google_warped, yandex_warped, H_combined, target_params\n",
"\n",
"    def __len__(self):\n",
"        return len(self.image_pairs)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        self._access_counts[idx] += 1\n",
"        count = self._access_counts[idx]\n",
"        # A fresh augmentation is drawn on access 1, cache_level + 1, 2 * cache_level + 1, ...;\n",
"        # the accesses in between reuse the cached one.\n",
"        use_cache = (self.augment and self.cache_level > 0 and count > 1\n",
"                     and (count - 1) % self.cache_level != 0)\n",
"\n",
"        if use_cache:\n",
"            google_img = self._cached_google[idx]\n",
"            yandex_img = self._cached_yandex[idx]\n",
"            target_matrix = self._cached_homography[idx]\n",
"            target_params = self._cached_params[idx]\n",
"        elif self.augment:\n",
"            google_img, yandex_img, target_matrix, target_params = self._generate_augmented(idx)\n",
"            if self.cache_level > 0:\n",
"                self._cached_google[idx] = google_img\n",
"                self._cached_yandex[idx] = yandex_img\n",
"                self._cached_homography[idx] = target_matrix\n",
"                self._cached_params[idx] = target_params\n",
"        else:\n",
"            google_img = self._google_images[idx]\n",
"            yandex_img = self._yandex_images[idx]\n",
"            target_params = np.array([0, 0, 0, 0, 0, 1], dtype=np.float32)\n",
"            target_matrix = np.eye(3, dtype=np.float32)\n",
"\n",
"        google_img = Image.fromarray(google_img)\n",
"        yandex_img = Image.fromarray(yandex_img)\n",
"\n",
"        if self.transform:\n",
"            google_img = self.transform(google_img)\n",
"            yandex_img = self.transform(yandex_img)\n",
"\n",
"        return {\n",
"            \"google_img\": google_img,\n",
"            \"yandex_img\": yandex_img,\n",
"            \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n",
"            \"homography_params\": torch.from_numpy(target_params).float(),\n",
"        }\n",
|
|
"\n",
|
|
"\n",
|
|
"def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,\n",
"                        image_size=(256, 256), augment_train=True, cache_level=5, seed=42):\n",
"    \"\"\"Build train/validation ``DataLoader``s over ``YaGoDataset``.\n",
"\n",
"    Args:\n",
"        root_dir: directory with the ``<idx>_google.png`` / ``<idx>_yandex.png`` pairs.\n",
"        batch_size: batch size of both loaders.\n",
"        train_split: fraction of pairs assigned to the training set.\n",
"        num_workers: ``DataLoader`` worker processes.\n",
"        image_size: (height, width) every image is resized to.\n",
"        augment_train: if False the training set yields un-warped pairs with identity targets.\n",
"        cache_level: forwarded to ``YaGoDataset`` (augmentation reuse count).\n",
"        seed: seed for the train/val shuffle. The default makes the split identical\n",
"            across calls (e.g. training vs. a later analysis), so training pairs never\n",
"            leak into validation; ``None`` draws a fresh random split each call.\n",
"\n",
"    Returns:\n",
"        ``(train_loader, val_loader)``; validation always uses augmented pairs.\n",
"    \"\"\"\n",
"    transform = transforms.Compose([\n",
"        transforms.ToTensor(),\n",
"        # NOTE: ImageNet normalisation is disabled although the pretrained backbone\n",
"        # expects it; enabling it also requires un-normalising images before plotting.\n",
"        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"    ])\n",
"\n",
"    def make_dataset(augment):\n",
"        return YaGoDataset(root_dir, transform=transform, augment=augment,\n",
"                           image_size=image_size, cache_level=cache_level)\n",
"\n",
"    # Each dataset instance loads every image into memory, so the un-augmented copy\n",
"    # is only built when the training set actually uses it.\n",
"    aug_ds = make_dataset(augment=True)\n",
"    train_source = aug_ds if augment_train else make_dataset(augment=False)\n",
"\n",
"    indices = list(range(len(aug_ds)))\n",
"    random.Random(seed).shuffle(indices)\n",
"    split = int(train_split * len(indices))\n",
"\n",
"    train_ds = Subset(train_source, indices[:split])\n",
"    val_ds = Subset(aug_ds, indices[split:])\n",
"\n",
"    # pin_memory only helps with a CUDA device (PyTorch warns when requested without one).\n",
"    pin_memory = torch.cuda.is_available()\n",
"    return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory),\n",
"            DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory))\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_dataset_info():\n",
"    \"\"\"Summarise the dataset at ``config[\"data_dir\"]``.\n",
"\n",
"    Returns a dict with the number of pairs, the keys of one sample and that\n",
"    sample's ``homography_params`` as a NumPy array. For a directory without\n",
"    pairs the sample entries are ``[]`` and ``None`` instead of raising IndexError.\n",
"    \"\"\"\n",
"    ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n",
"    if len(ds) == 0:\n",
"        return {\"size\": 0, \"sample_keys\": [], \"sample_params\": None}\n",
"    # Index once: every access advances the sample's augmentation counter.\n",
"    sample = ds[0]\n",
"    return {\n",
"        \"size\": len(ds),\n",
"        \"sample_keys\": list(sample.keys()),\n",
"        \"sample_params\": sample[\"homography_params\"].numpy(),\n",
"    }\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
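"source": "## Model\n\n`HomographyCNN6` — CNN architecture for homography estimation.\n\n**Output:** 9 values `[tx, ty, sin rx, cos rx, sin ry, cos ry, sin rz, cos rz, scale]`, which `decode_output` turns into 6 parameters:\n- `rx, ry, rz` — rotation angles (radians)\n- `tx, ty` — translation offsets\n- `scale` — isotropic scale factor\n\n**Architecture:**\n- Dual-branch (Siamese) CNN over the Google + Yandex images\n- Shared backbone (configurable: resnet18/34/50)\n- Regression head on the concatenated features `[f1, f2, |f1 - f2|, f1 * f2]`\n"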
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"def angular_difference(pred_angles, target_angles):\n",
"    \"\"\"Element-wise absolute angular distance in radians, wrapped into ``[0, pi]``.\"\"\"\n",
"    delta = pred_angles - target_angles\n",
"    wrapped = torch.atan2(delta.sin(), delta.cos())  # maps delta into [-pi, pi]\n",
"    return wrapped.abs()\n",
|
|
"\n",
|
|
"\n",
|
|
"class HomographyCNN6(nn.Module):\n",
"    \"\"\"Siamese CNN that regresses the relative homography between two images.\n",
"\n",
"    Both images pass through one shared torchvision backbone; the head receives\n",
"    ``[f1, f2, |f1 - f2|, f1 * f2]`` and emits 9 values\n",
"    ``[tx, ty, sin rx, cos rx, sin ry, cos ry, sin rz, cos rz, scale]``\n",
"    (``decode_output`` converts them to ``[tx, ty, rx, ry, rz, scale]``).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n",
"        super().__init__()\n",
"        # A weight *name* is resolved by each builder, so resnet34/50 get their own\n",
"        # ImageNet weights instead of the ResNet18 enum (which only fits resnet18).\n",
"        backbone = getattr(models, backbone_name)(weights=\"IMAGENET1K_V1\" if pretrained else None)\n",
"        if input_channels != 3:\n",
"            # Swap the stem so the backbone accepts another channel count; this layer\n",
"            # is then randomly initialised.\n",
"            stem = backbone.conv1\n",
"            backbone.conv1 = nn.Conv2d(\n",
"                input_channels, stem.out_channels, kernel_size=stem.kernel_size,\n",
"                stride=stem.stride, padding=stem.padding, bias=stem.bias is not None,\n",
"            )\n",
"        self.feature_dim = backbone.fc.in_features\n",
"        backbone.fc = nn.Identity()\n",
"        self.backbone = backbone\n",
"\n",
"        # Widths chain 4F -> 1024 -> 512 -> 256 -> 9 so consecutive layers match.\n",
"        self.head = nn.Sequential(\n",
"            nn.Linear(self.feature_dim * 4, 1024),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Dropout(dropout_rate),\n",
"            nn.Linear(1024, 512),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Dropout(dropout_rate),\n",
"            nn.Linear(512, 256),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Dropout(dropout_rate),\n",
"            nn.Linear(256, 9),\n",
"        )\n",
"\n",
"    def _normalize_sin_cos(self, _sin, _cos, eps=1e-8):\n",
"        \"\"\"Scale a (sin, cos) pair to unit length; ``eps`` avoids 0/0 -> NaN.\"\"\"\n",
"        _len = torch.sqrt(_sin ** 2 + _cos ** 2).clamp_min(eps)\n",
"        return _sin / _len, _cos / _len\n",
"\n",
"    def forward(self, img1, img2):\n",
"        \"\"\"Return the ``(batch, 9)`` encoding for each image pair.\"\"\"\n",
"        f1 = self.backbone(img1)\n",
"        f2 = self.backbone(img2)\n",
"        combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
"        raw = self.head(combined)\n",
"\n",
"        # Constrain the head output (not the backbone features): translations to\n",
"        # [-10, 10], every angle to a unit-length (sin, cos) pair; scale is left free.\n",
"        # The result is assembled with torch.cat instead of in-place slice writes.\n",
"        translation = torch.tanh(raw[:, 0:2]) * 10\n",
"        angle_pairs = []\n",
"        for i in (2, 4, 6):\n",
"            sin_i, cos_i = self._normalize_sin_cos(torch.tanh(raw[:, i]), torch.tanh(raw[:, i + 1]))\n",
"            angle_pairs.append(torch.stack([sin_i, cos_i], dim=1))\n",
"        return torch.cat([translation, *angle_pairs, raw[:, 8:9]], dim=1)\n",
"\n",
"    def decode_output(self, output):\n",
"        \"\"\"Convert the 9-value encoding to ``[tx, ty, rx, ry, rz, scale]`` (angles via atan2).\"\"\"\n",
"        tx = output[:, 0]\n",
"        ty = output[:, 1]\n",
"        scale = output[:, 8]\n",
"\n",
"        angle1 = torch.atan2(output[:, 2], output[:, 3])\n",
"        angle2 = torch.atan2(output[:, 4], output[:, 5])\n",
"        angle3 = torch.atan2(output[:, 6], output[:, 7])\n",
"\n",
"        return torch.stack([tx, ty, angle1, angle2, angle3, scale], dim=1)\n",
"\n",
"    def get_components(self, output):\n",
"        \"\"\"Decoded parameters as a dict of 1-D tensors keyed by parameter name.\"\"\"\n",
"        decoded = self.decode_output(output)\n",
"        return {\n",
"            \"tx\": decoded[:, 0],\n",
"            \"ty\": decoded[:, 1],\n",
"            \"rx\": decoded[:, 2],\n",
"            \"ry\": decoded[:, 3],\n",
"            \"rz\": decoded[:, 4],\n",
"            \"scale\": decoded[:, 5],\n",
"        }\n",
|
|
"\n",
|
|
"\n",
|
|
"class HomographyLoss6(nn.Module):\n",
"    \"\"\"Training loss for ``HomographyCNN6``.\n",
"\n",
"    ``pred`` is the model's 9-value output\n",
"    ``[tx, ty, sin rx, cos rx, sin ry, cos ry, sin rz, cos rz, scale]``; ``target`` is\n",
"    ``[tx, ty, rx, ry, rz, scale]`` with angles in radians. Translation and scale use\n",
"    MSE; each angle uses ``1 - dot(pred (sin, cos), target (sin, cos))``, which equals\n",
"    ``1 - cos(pred_angle - target_angle)`` for unit-length predicted pairs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, angle_loss_weight=1.0, trans_loss_weight=1.0, scale_loss_weight=1.0):\n",
"        super().__init__()\n",
"        self.criterion = nn.MSELoss()\n",
"        self.angle_loss_weight = angle_loss_weight\n",
"        self.trans_loss_weight = trans_loss_weight\n",
"        self.scale_loss_weight = scale_loss_weight\n",
"\n",
"    def forward(self, pred, target):\n",
"        \"\"\"Return the weighted scalar loss for a batch.\"\"\"\n",
"        tx_loss = self.criterion(pred[:, 0], target[:, 0])\n",
"        ty_loss = self.criterion(pred[:, 1], target[:, 1])\n",
"\n",
"        # Angle k (rx, ry, rz) is a (sin, cos) pair in pred columns 2 + 2k, 3 + 2k\n",
"        # and a plain angle in target column 2 + k.\n",
"        angle_loss = 0\n",
"        for k in range(3):\n",
"            target_angle = target[:, 2 + k]\n",
"            dot = pred[:, 2 + 2 * k] * torch.sin(target_angle) + pred[:, 3 + 2 * k] * torch.cos(target_angle)\n",
"            angle_loss = angle_loss + (1 - dot).mean()\n",
"\n",
"        scale_loss = self.criterion(pred[:, 8], target[:, 5])\n",
"\n",
"        return (\n",
"            self.trans_loss_weight * (tx_loss + ty_loss)\n",
"            + self.angle_loss_weight * angle_loss\n",
"            + self.scale_loss_weight * scale_loss\n",
"        )\n",
"\n",
"    @staticmethod\n",
"    def _decode(pred):\n",
"        \"\"\"Convert a 9-value prediction to ``[tx, ty, rx, ry, rz, scale]``.\"\"\"\n",
"        angles = [torch.atan2(pred[:, i], pred[:, i + 1]) for i in (2, 4, 6)]\n",
"        return torch.stack([pred[:, 0], pred[:, 1], *angles, pred[:, 8]], dim=1)\n",
"\n",
"    def compute_mse_components(self, pred, target):\n",
"        \"\"\"Mean squared errors grouped into translation, angle and scale.\n",
"\n",
"        Args:\n",
"            pred: raw 9-value model output, or already-decoded 6-value parameters\n",
"                ``[tx, ty, rx, ry, rz, scale]``.\n",
"            target: ``[tx, ty, rx, ry, rz, scale]`` targets.\n",
"\n",
"        Returns:\n",
"            dict of Python floats: ``trans`` (mean tx/ty MSE), ``angle`` (mean of the\n",
"            wrapped rx/ry/rz squared errors, rad^2) and ``scale``.\n",
"        \"\"\"\n",
"        with torch.no_grad():\n",
"            decoded = self._decode(pred) if pred.shape[1] == 9 else pred\n",
"            tx_mse = self.criterion(decoded[:, 0], target[:, 0]).item()\n",
"            ty_mse = self.criterion(decoded[:, 1], target[:, 1]).item()\n",
"            # Wrap angle errors to [0, pi] so e.g. -pi vs. +pi counts as no error.\n",
"            angle_mses = [angular_difference(decoded[:, k], target[:, k]).pow(2).mean().item() for k in (2, 3, 4)]\n",
"            scale_mse = self.criterion(decoded[:, 5], target[:, 5]).item()\n",
"\n",
"        return {\n",
"            'trans': (tx_mse + ty_mse) / 2,\n",
"            'angle': sum(angle_mses) / 3,\n",
"            'scale': scale_mse\n",
"        }\n",
|
|
"\n",
|
|
"\n",
|
|
"def count_parameters(model):\n",
"    \"\"\"Total number of scalar parameters in ``model`` (trainable and frozen alike).\"\"\"\n",
"    total = 0\n",
"    for param in model.parameters():\n",
"        total += param.numel()\n",
"    return total\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": "## Training\n\n`HomographyTrainer` — training loop with validation and checkpointing.\n\n**Features:**\n- Epoch-based training with tqdm progress bar\n- Adam optimizer with configurable LR\n- Validation after each epoch\n- Best model auto-save\n- Periodic checkpoints (every N epochs via `save_every_n_epochs`)\n\n**Checkpoint saving:**\n- `best_model.pt` — lowest validation loss\n"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"class HomographyTrainer:\n",
"    \"\"\"Training loop with validation, TensorBoard logging and checkpointing.\n",
"\n",
"    Reads ``learning_rate``, ``output_dir`` and ``save_every_n_epochs`` from the\n",
"    global ``config``. Per-epoch histories are kept in ``train_losses`` /\n",
"    ``val_losses`` and the ``train_mse_*`` / ``val_mse_*`` lists.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model, train_loader, val_loader, device):\n",
"        self.model = model.to(device)\n",
"        self.train_loader = train_loader\n",
"        self.val_loader = val_loader\n",
"        self.device = device\n",
"        self.criterion = HomographyLoss6()\n",
"        self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"], weight_decay=1e-4)\n",
"        self.writer = None\n",
"        self.best_val_loss = float(\"inf\")\n",
"        self.train_losses = []\n",
"        self.val_losses = []\n",
"        self.train_mse_trans = []\n",
"        self.train_mse_angle = []\n",
"        self.train_mse_scale = []\n",
"        self.val_mse_trans = []\n",
"        self.val_mse_angle = []\n",
"        self.val_mse_scale = []\n",
"\n",
"    def _run_epoch(self, loader, training, desc):\n",
"        \"\"\"One pass over ``loader``; returns ``(mean_loss, mean_mse_components)``.\n",
"\n",
"        Steps the optimizer after every batch when ``training`` is True. An empty\n",
"        loader yields NaN values instead of raising ZeroDivisionError.\n",
"        \"\"\"\n",
"        total_loss, total_samples = 0.0, 0\n",
"        mse_sums = {\"trans\": 0.0, \"angle\": 0.0, \"scale\": 0.0}\n",
"        pbar = tqdm(loader, desc=desc)\n",
"        for batch in pbar:\n",
"            google_img = batch[\"google_img\"].to(self.device)\n",
"            yandex_img = batch[\"yandex_img\"].to(self.device)\n",
"            target = batch[\"homography_params\"].to(self.device)\n",
"            batch_n = google_img.size(0)\n",
"\n",
"            with torch.set_grad_enabled(training):\n",
"                output = self.model(google_img, yandex_img)\n",
"                loss = self.criterion(output, target)\n",
"            if training:\n",
"                self.optimizer.zero_grad()\n",
"                loss.backward()\n",
"                self.optimizer.step()\n",
"\n",
"            total_loss += loss.item() * batch_n\n",
"            total_samples += batch_n\n",
"\n",
"            # Metrics only: no autograd graph is needed.\n",
"            with torch.no_grad():\n",
"                decoded_output = self.model.decode_output(output.detach())\n",
"                mse_components = self.criterion.compute_mse_components(decoded_output, target)\n",
"            for key in mse_sums:\n",
"                mse_sums[key] += mse_components[key] * batch_n\n",
"\n",
"            if training:\n",
"                pbar.set_postfix({\"loss\": loss.item()})\n",
"\n",
"        if total_samples == 0:\n",
"            return float(\"nan\"), {key: float(\"nan\") for key in mse_sums}\n",
"        return total_loss / total_samples, {key: value / total_samples for key, value in mse_sums.items()}\n",
"\n",
"    def train_epoch(self, epoch):\n",
"        \"\"\"Train for one epoch; returns ``{\"loss\": mean training loss}``.\"\"\"\n",
"        self.model.train()\n",
"        loss, mse = self._run_epoch(self.train_loader, training=True, desc=f\"Epoch {epoch}\")\n",
"        self.train_mse_trans.append(mse[\"trans\"])\n",
"        self.train_mse_angle.append(mse[\"angle\"])\n",
"        self.train_mse_scale.append(mse[\"scale\"])\n",
"        return {\"loss\": loss}\n",
"\n",
"    def validate(self):\n",
"        \"\"\"Evaluate on the validation loader; returns ``{\"loss\": mean validation loss}``.\"\"\"\n",
"        self.model.eval()\n",
"        loss, mse = self._run_epoch(self.val_loader, training=False, desc=\"Validation\")\n",
"        self.val_mse_trans.append(mse[\"trans\"])\n",
"        self.val_mse_angle.append(mse[\"angle\"])\n",
"        self.val_mse_scale.append(mse[\"scale\"])\n",
"        return {\"loss\": loss}\n",
"\n",
"    def _log_epoch(self, epoch):\n",
"        \"\"\"Write the latest epoch's losses and MSE components to TensorBoard.\"\"\"\n",
"        scalars = {\n",
"            \"loss/train\": self.train_losses[-1],\n",
"            \"loss/val\": self.val_losses[-1],\n",
"            \"mse_trans/train\": self.train_mse_trans[-1],\n",
"            \"mse_trans/val\": self.val_mse_trans[-1],\n",
"            \"mse_angle/train\": self.train_mse_angle[-1],\n",
"            \"mse_angle/val\": self.val_mse_angle[-1],\n",
"            \"mse_scale/train\": self.train_mse_scale[-1],\n",
"            \"mse_scale/val\": self.val_mse_scale[-1],\n",
"        }\n",
"        for tag, value in scalars.items():\n",
"            self.writer.add_scalar(tag, value, epoch)\n",
"        self.writer.flush()\n",
"\n",
"    def train(self, num_epochs):\n",
"        \"\"\"Run ``num_epochs`` epochs with validation, TensorBoard logging and checkpoints.\n",
"\n",
"        ``best_model.pt`` is written whenever the validation loss improves and\n",
"        ``checkpoint_epoch_<N>.pt`` every ``config[\"save_every_n_epochs\"]`` epochs.\n",
"        \"\"\"\n",
"        log_dir = config[\"output_dir\"]\n",
"        os.makedirs(log_dir, exist_ok=True)\n",
"        self.writer = SummaryWriter(log_dir)\n",
"        try:\n",
"            for epoch in range(1, num_epochs + 1):\n",
"                train_metrics = self.train_epoch(epoch)\n",
"                val_metrics = self.validate()\n",
"                self.train_losses.append(train_metrics[\"loss\"])\n",
"                self.val_losses.append(val_metrics[\"loss\"])\n",
"                self._log_epoch(epoch)\n",
"                print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n",
"                print(f\"  MSE - Trans: {self.val_mse_trans[-1]:.4f}, Angle: {self.val_mse_angle[-1]:.4f}, Scale: {self.val_mse_scale[-1]:.4f}\")\n",
"\n",
"                if val_metrics[\"loss\"] < self.best_val_loss:\n",
"                    self.best_val_loss = val_metrics[\"loss\"]\n",
"                    self.save_checkpoint(epoch, is_best=True, val_loss=val_metrics[\"loss\"])\n",
"                    print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n",
"\n",
"                if epoch % config[\"save_every_n_epochs\"] == 0:\n",
"                    self.save_checkpoint(epoch, is_best=False, val_loss=val_metrics[\"loss\"])\n",
"                    print(f\"Checkpoint saved at epoch {epoch}\")\n",
"        finally:\n",
"            # Close even when training is interrupted so buffered events reach disk.\n",
"            self.writer.close()\n",
"\n",
"    def save_checkpoint(self, epoch, is_best=False, val_loss=None):\n",
"        \"\"\"Save model and optimizer state under ``<output_dir>/checkpoints``.\n",
"\n",
"        Args:\n",
"            epoch: epoch number stored in the checkpoint (and used in its file name).\n",
"            is_best: write ``best_model.pt`` instead of ``checkpoint_epoch_<epoch>.pt``.\n",
"            val_loss: validation loss recorded in the checkpoint; defaults to\n",
"                ``self.best_val_loss``.\n",
"        \"\"\"\n",
"        ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n",
"        os.makedirs(ckpt_dir, exist_ok=True)\n",
"        ckpt = {\n",
"            \"epoch\": epoch,\n",
"            \"model_state_dict\": self.model.state_dict(),\n",
"            \"optimizer_state_dict\": self.optimizer.state_dict(),\n",
"            \"val_loss\": self.best_val_loss if val_loss is None else val_loss,\n",
"        }\n",
"        filename = \"best_model.pt\" if is_best else f\"checkpoint_epoch_{epoch}.pt\"\n",
"        torch.save(ckpt, os.path.join(ckpt_dir, filename))\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": "## Analysis\n\nVisualization and evaluation tools:\n\n- Training metrics plots (loss curves)\n- Prediction visualization on sample images\n"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"\n",
|
|
"IMG_DIR = os.path.join(config[\"output_dir\"], \"images\")  # analysis figures are saved here\n",
"os.makedirs(IMG_DIR, exist_ok=True)\n",
"\n",
"_PARAM_NAMES = [\"tx\", \"ty\", \"rx\", \"ry\", \"rz\", \"scale\"]\n",
"_PARAM_COLORS = [\"c\", \"m\", \"y\", \"g\", \"b\", \"r\"]\n",
"\n",
"\n",
"def _sample_errors(decoded_pred, target_params):\n",
"    \"\"\"Per-sample absolute errors as an ``(N, 6)`` array; angle errors wrapped to [0, pi].\"\"\"\n",
"    errors = torch.abs(decoded_pred - target_params)\n",
"    errors[:, 2:5] = angular_difference(decoded_pred[:, 2:5], target_params[:, 2:5])\n",
"    return errors.cpu().numpy()\n",
"\n",
"\n",
"def _analysis_loader(trainer):\n",
"    \"\"\"Loader over the trainer's validation pairs with fresh (uncached) augmentations.\n",
"\n",
"    Reuses ``trainer.val_loader``'s exact indices when its dataset is a ``Subset`` of\n",
"    a ``YaGoDataset``, so no training pair leaks into the analysis; otherwise the\n",
"    loaders are rebuilt from ``config``.\n",
"    \"\"\"\n",
"    val_subset = trainer.val_loader.dataset\n",
"    if isinstance(val_subset, Subset) and isinstance(val_subset.dataset, YaGoDataset):\n",
"        source = val_subset.dataset\n",
"        # cache_level=0: every access draws a new random homography.\n",
"        fresh_ds = YaGoDataset(source.root_dir, transform=source.transform, augment=True,\n",
"                               image_size=source.image_size, cache_level=0)\n",
"        return DataLoader(Subset(fresh_ds, val_subset.indices), batch_size=config[\"batch_size\"],\n",
"                          shuffle=False, num_workers=config[\"num_workers\"])\n",
"    _, val_loader = create_data_loaders(\n",
"        root_dir=config[\"data_dir\"],\n",
"        batch_size=config[\"batch_size\"],\n",
"        train_split=config[\"train_split\"],\n",
"        num_workers=config[\"num_workers\"],\n",
"        image_size=config[\"image_size\"],\n",
"        augment_train=True,\n",
"        cache_level=0,\n",
"    )\n",
"    return val_loader\n",
"\n",
"\n",
"def _collect_errors(trainer, loader, n_samples):\n",
"    \"\"\"Absolute errors ``(n, 6)`` for the first ``n <= n_samples`` samples of ``loader``.\"\"\"\n",
"    chunks, remaining = [], n_samples\n",
"    with torch.no_grad():\n",
"        for batch in loader:\n",
"            if remaining <= 0:\n",
"                break\n",
"            google_img = batch[\"google_img\"][:remaining].to(trainer.device)\n",
"            yandex_img = batch[\"yandex_img\"][:remaining].to(trainer.device)\n",
"            target = batch[\"homography_params\"][:remaining].to(trainer.device)\n",
"            decoded = trainer.model.decode_output(trainer.model(google_img, yandex_img))\n",
"            chunks.append(_sample_errors(decoded, target))\n",
"            remaining -= target.size(0)\n",
"    return np.concatenate(chunks) if chunks else np.empty((0, len(_PARAM_NAMES)))\n",
"\n",
"\n",
"def _save_and_close(fig, filename, dpi):\n",
"    \"\"\"Lay out, save to ``IMG_DIR``, display and close ``fig``.\"\"\"\n",
"    fig.tight_layout()\n",
"    fig.savefig(os.path.join(IMG_DIR, filename), dpi=dpi)\n",
"    print(f\"Saved {filename}\")\n",
"    plt.show()\n",
"    plt.close(fig)  # dozens of figures are created; closing keeps memory bounded\n",
"\n",
"\n",
"def _bar_mae(ax, mean_errors, std_errors, title):\n",
"    \"\"\"Bar chart of per-parameter mean absolute error with std error bars.\"\"\"\n",
"    x_pos = np.arange(len(_PARAM_NAMES))\n",
"    ax.bar(x_pos, mean_errors, yerr=std_errors, capsize=5, color=_PARAM_COLORS, alpha=0.8)\n",
"    ax.set_xticks(x_pos)\n",
"    ax.set_xticklabels(_PARAM_NAMES)\n",
"    ax.set_ylabel(\"Mean Absolute Error\")\n",
"    ax.set_title(title)\n",
"    ax.grid(True, alpha=0.3, axis=\"y\")\n",
"\n",
"\n",
"def _plot_training_curves(trainer, errors):\n",
"    \"\"\"Loss curves, validation MSE by category and per-parameter MAE in one figure.\"\"\"\n",
"    epochs = range(1, len(trainer.train_losses) + 1)\n",
"    fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
"\n",
"    axes[0, 0].plot(epochs, trainer.train_losses, \"b-\", label=\"Train Loss\")\n",
"    axes[0, 0].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
"    axes[0, 0].set(xlabel=\"Epoch\", ylabel=\"Loss\", title=\"Training & Validation Loss\")\n",
"\n",
"    axes[0, 1].plot(epochs, trainer.val_losses, \"r-\", label=\"Val Loss\")\n",
"    axes[0, 1].set(xlabel=\"Epoch\", ylabel=\"Loss\", title=\"Validation Loss\")\n",
"\n",
"    axes[1, 0].plot(epochs, trainer.val_mse_trans, \"g-\", label=\"Translation (tx, ty)\")\n",
"    axes[1, 0].plot(epochs, trainer.val_mse_angle, \"m-\", label=\"Angle (rx, ry, rz)\")\n",
"    axes[1, 0].plot(epochs, trainer.val_mse_scale, \"c-\", label=\"Scale\")\n",
"    axes[1, 0].set(xlabel=\"Epoch\", ylabel=\"MSE\", title=\"Validation MSE by Category\")\n",
"\n",
"    for ax in (axes[0, 0], axes[0, 1], axes[1, 0]):\n",
"        ax.legend()\n",
"        ax.grid(True, alpha=0.3)\n",
"\n",
"    _bar_mae(axes[1, 1], errors.mean(axis=0), errors.std(axis=0),\n",
"             f\"Mean Absolute Error per Parameter ({len(errors)} samples)\")\n",
"    _save_and_close(fig, \"training_loss_plots.png\", dpi=150)\n",
"\n",
"\n",
"def _plot_per_sample_errors(errors):\n",
"    \"\"\"One bar chart per parameter showing every sample's absolute error.\"\"\"\n",
"    fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n",
"    for j, (ax, name) in enumerate(zip(axes.flat, _PARAM_NAMES)):\n",
"        column = errors[:, j]\n",
"        ax.bar(range(len(column)), column, color=\"steelblue\", alpha=0.7)\n",
"        ax.set(xlabel=\"Sample\", ylabel=\"Absolute Error\",\n",
"               title=f\"{name}: Mean={column.mean():.4f}, Std={column.std():.4f}\")\n",
"        ax.grid(True, alpha=0.3, axis=\"y\")\n",
"    fig.suptitle(f\"Absolute Error per Sample and Parameter ({len(errors)} samples)\", fontsize=14)\n",
"    _save_and_close(fig, \"mae_per_parameter.png\", dpi=150)\n",
"\n",
"\n",
"def _plot_error_summary(errors):\n",
"    \"\"\"Mean/std bar chart next to a box plot of the error distribution.\"\"\"\n",
"    fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
"    _bar_mae(axes[0], errors.mean(axis=0), errors.std(axis=0),\n",
"             \"Mean Absolute Error per Parameter (with std)\")\n",
"\n",
"    # Tick labels are set on the axis instead of boxplot(labels=...), which newer\n",
"    # Matplotlib releases deprecate.\n",
"    bp = axes[1].boxplot([errors[:, j] for j in range(errors.shape[1])], patch_artist=True)\n",
"    axes[1].set_xticklabels(_PARAM_NAMES)\n",
"    for patch, color in zip(bp[\"boxes\"], _PARAM_COLORS):\n",
"        patch.set_facecolor(color)\n",
"        patch.set_alpha(0.7)\n",
"    axes[1].set_ylabel(\"Absolute Error\")\n",
"    axes[1].set_title(f\"Error Distribution per Parameter ({len(errors)} samples)\")\n",
"    axes[1].grid(True, alpha=0.3, axis=\"y\")\n",
"    _save_and_close(fig, \"mae_boxplot.png\", dpi=150)\n",
"\n",
"\n",
"def _plot_prediction_samples(trainer, loader, n_vis_samples):\n",
"    \"\"\"Save one figure per sample: both images, target vs. predicted params and errors.\"\"\"\n",
"    x_pos = np.arange(len(_PARAM_NAMES))\n",
"    width = 0.35\n",
"    shown = 0\n",
"    with torch.no_grad():\n",
"        for batch in loader:\n",
"            if shown >= n_vis_samples:\n",
"                break\n",
"            take = min(n_vis_samples - shown, batch[\"google_img\"].size(0))\n",
"            google_img = batch[\"google_img\"][:take].to(trainer.device)\n",
"            yandex_img = batch[\"yandex_img\"][:take].to(trainer.device)\n",
"            target = batch[\"homography_params\"][:take].to(trainer.device)\n",
"            decoded = trainer.model.decode_output(trainer.model(google_img, yandex_img))\n",
"            errors = _sample_errors(decoded, target)\n",
"            targets, preds = target.cpu().numpy(), decoded.cpu().numpy()\n",
"\n",
"            for i in range(take):\n",
"                shown += 1\n",
"                fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
"                for ax, img, title in ((axes[0, 0], google_img[i], \"Google Image\"),\n",
"                                       (axes[0, 1], yandex_img[i], \"Yandex Image\")):\n",
"                    ax.imshow(img.cpu().permute(1, 2, 0).numpy())\n",
"                    ax.set_title(title)\n",
"                    ax.axis(\"off\")\n",
"\n",
"                axes[1, 0].bar(x_pos - width / 2, targets[i], width, label=\"Target\", color=\"steelblue\", alpha=0.8)\n",
"                axes[1, 0].bar(x_pos + width / 2, preds[i], width, label=\"Predicted\", color=\"coral\", alpha=0.8)\n",
"                axes[1, 0].set_xticks(x_pos)\n",
"                axes[1, 0].set_xticklabels(_PARAM_NAMES)\n",
"                axes[1, 0].set_ylabel(\"Parameter Value\")\n",
"                axes[1, 0].set_title(\"Target vs Predicted\")\n",
"                axes[1, 0].legend()\n",
"                axes[1, 0].grid(True, alpha=0.3, axis=\"y\")\n",
"\n",
"                sample_errors = errors[i]\n",
"                axes[1, 1].bar(x_pos, sample_errors, color=_PARAM_COLORS, alpha=0.8)\n",
"                axes[1, 1].set_xticks(x_pos)\n",
"                axes[1, 1].set_xticklabels(_PARAM_NAMES)\n",
"                axes[1, 1].set_ylabel(\"Absolute Error\")\n",
"                axes[1, 1].set_title(f\"Prediction Error (Mean: {np.mean(sample_errors):.4f})\")\n",
"                axes[1, 1].grid(True, alpha=0.3, axis=\"y\")\n",
"                for i_e, e in enumerate(sample_errors):\n",
"                    axes[1, 1].text(i_e, e + 0.01, f\"{e:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n",
"\n",
"                fig.suptitle(f\"Sample {shown}\", fontsize=14)\n",
"                _save_and_close(fig, f\"prediction_sample_{shown:02d}.png\", dpi=100)\n",
"\n",
"\n",
"def _print_error_table(errors):\n",
"    \"\"\"Print mean/std/min/max absolute error per parameter.\"\"\"\n",
"    print(f\"\\nPrediction errors over {len(errors)} samples:\")\n",
"    print(f\"{'Param':<8} {'Mean Error':>12} {'Std Error':>12} {'Min':>8} {'Max':>8}\")\n",
"    print(\"-\" * 52)\n",
"    for j, name in enumerate(_PARAM_NAMES):\n",
"        column = errors[:, j]\n",
"        print(f\"{name:<8} {column.mean():>12.4f} {column.std():>12.4f} {column.min():>8.4f} {column.max():>8.4f}\")\n",
"\n",
"\n",
"def analyze_training(trainer):\n",
"    \"\"\"Report on a finished training run.\n",
"\n",
"    Prints the best validation loss, loads ``best_model.pt`` (if present) into\n",
"    ``trainer.model``, evaluates up to 50 validation samples with fresh random\n",
"    augmentations, saves the figures to ``IMG_DIR`` and returns the trainer's\n",
"    loss / MSE histories.\n",
"    \"\"\"\n",
"    print(\"=== Training Analysis ===\\n\")\n",
"\n",
"    if trainer.writer:\n",
"        print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n",
"\n",
"    print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n",
"\n",
"    best_model_path = os.path.join(config[\"output_dir\"], \"checkpoints\", \"best_model.pt\")\n",
"    if os.path.exists(best_model_path):\n",
"        checkpoint = torch.load(best_model_path, map_location=trainer.device)\n",
"        trainer.model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"        print(f\"\\nLoaded best model from epoch {checkpoint['epoch']} (val loss: {checkpoint['val_loss']:.4f})\")\n",
"\n",
"    trainer.model.eval()\n",
"\n",
"    history = {\n",
"        \"best_val_loss\": trainer.best_val_loss,\n",
"        \"train_losses\": trainer.train_losses,\n",
"        \"val_losses\": trainer.val_losses,\n",
"        \"val_mse_trans\": trainer.val_mse_trans,\n",
"        \"val_mse_angle\": trainer.val_mse_angle,\n",
"        \"val_mse_scale\": trainer.val_mse_scale,\n",
"    }\n",
"\n",
"    n_samples = 50\n",
"    n_vis_samples = 20\n",
"    val_loader_for_analysis = _analysis_loader(trainer)\n",
"    errors = _collect_errors(trainer, val_loader_for_analysis, n_samples)\n",
"    if len(errors) == 0:\n",
"        print(\"\\nNo validation samples available - skipping error analysis.\")\n",
"        return history\n",
"\n",
"    if len(trainer.train_losses) > 0:\n",
"        _plot_training_curves(trainer, errors)\n",
"    _plot_per_sample_errors(errors)\n",
"    _plot_error_summary(errors)\n",
"\n",
"    print(f\"\\n=== Sample Predictions ({n_vis_samples} pairs) ===\")\n",
"    _plot_prediction_samples(trainer, val_loader_for_analysis, n_vis_samples)\n",
"\n",
"    _print_error_table(errors)\n",
"    return history\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
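"source": "## Main Pipeline\n\nExecutes the full training workflow:\n1. Load dataset info\n2. Create data loaders\n3. Initialize model\n4. Train with validation\n5. Analyze results (prediction-sample plots and per-parameter error statistics)\n\n**Outputs:**\n- Model checkpoints in `runs/checkpoints/`\n- TensorBoard logs in `runs/`\n- Prediction sample plots (`prediction_sample_XX.png`) in the image directory (`IMG_DIR`)\n\nThe next cell bundles these artefacts into `artefacts.zip`.\n"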
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# force=True replaces any handlers the Jupyter kernel may already have attached to the\n",
"# root logger; without it basicConfig is a silent no-op and INFO messages can be hidden.\n",
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\", force=True)\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"# Seed Python, NumPy and PyTorch RNGs before any data/model work so that random\n",
"# homography sampling, any random data splitting/shuffling and weight init are\n",
"# repeatable across runs. Add a \"seed\" entry to `config` to override the default.\n",
"SEED = config.get(\"seed\", 42)\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"logger.info(\"=\" * 50)\n",
"logger.info(\"SiaN Training Pipeline\")\n",
"logger.info(\"=\" * 50)\n",
"logger.info(f\"Random seed: {SEED}\")\n",
"\n",
"dataset_info = get_dataset_info()\n",
"logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
"\n",
"train_loader, val_loader = create_data_loaders(\n",
"    root_dir=config[\"data_dir\"],\n",
"    batch_size=config[\"batch_size\"],\n",
"    train_split=config[\"train_split\"],\n",
"    num_workers=config[\"num_workers\"],\n",
"    image_size=config[\"image_size\"],\n",
")\n",
"logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
"\n",
"model = HomographyCNN6(\n",
"    input_channels=3,\n",
"    backbone_name=config[\"backbone\"],\n",
"    pretrained=True,\n",
"    dropout_rate=config[\"dropout_rate\"],\n",
")\n",
"logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"logger.info(f\"Using device: {device}\")\n",
"\n",
"trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
"logger.info(\"Starting training...\")\n",
"trainer.train(config[\"epochs\"])\n",
"logger.info(\"Training completed\")\n",
"\n",
"logger.info(\"Analyzing model...\")\n",
"results = analyze_training(trainer)\n",
"logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
"\n",
"logger.info(\"=\" * 50)\n",
"logger.info(\"Pipeline completed successfully\")\n",
"logger.info(\"=\" * 50)\n",
"\n"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import zipfile\n",
"from pathlib import Path\n",
"\n",
"# Bundle the best checkpoint, prediction plots and TensorBoard event files.\n",
"# Pure-Python replacement for `!zip`: the shell tool is unavailable on Windows, and\n",
"# without -r it stored only the empty `runs/images/` directory entry, not its files.\n",
"RUNS_DIR = Path(\"runs\")\n",
"ARCHIVE_PATH = Path(\"artefacts.zip\")\n",
"\n",
"artefact_paths = [\n",
"    RUNS_DIR / \"checkpoints\" / \"best_model.pt\",\n",
"    *sorted((RUNS_DIR / \"images\").rglob(\"*\")),\n",
"    *sorted(RUNS_DIR.glob(\"events.*\")),\n",
"]\n",
"\n",
"with zipfile.ZipFile(ARCHIVE_PATH, \"w\", compression=zipfile.ZIP_DEFLATED) as archive:\n",
"    for path in artefact_paths:\n",
"        if path.is_file():\n",
"            # Keep the runs/... relative layout inside the archive, as `zip` did.\n",
"            archive.write(path, arcname=path.as_posix())\n",
"        elif not path.exists():\n",
"            print(f\"Skipping missing artefact: {path}\")\n",
"\n",
"print(f\"Wrote {ARCHIVE_PATH} ({ARCHIVE_PATH.stat().st_size:,} bytes)\")\n"
]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python",
|
|
"version": "3.11.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |