add auto codegen
This commit is contained in:
3
models/SiaN/.gitignore
vendored
3
models/SiaN/.gitignore
vendored
@@ -1 +1,2 @@
|
||||
runs
|
||||
runs
|
||||
*.gen.py
|
||||
23
models/SiaN/_schema.md
Normal file
23
models/SiaN/_schema.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# _schema.md
|
||||
|
||||
## Format
|
||||
|
||||
```
|
||||
# === IMPORTS ===
|
||||
<all imports>
|
||||
|
||||
# code: ./src/file.py
|
||||
# markdown
|
||||
"""Description"""
|
||||
|
||||
# inline:
|
||||
<custom code>
|
||||
```
|
||||
|
||||
## Directives
|
||||
|
||||
- `# code:` - include file from src/
|
||||
- `# markdown` - description block
|
||||
- `# inline:` - custom code cell
|
||||
|
||||
`build.py` generates notebook from `_schema.py`.
|
||||
47
models/SiaN/_schema.py
Normal file
47
models/SiaN/_schema.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# _schema.py
|
||||
|
||||
# === IMPORTS ===
|
||||
import os
|
||||
import random
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset, Subset
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms, models
|
||||
from tqdm import tqdm
|
||||
|
||||
# code: ./src/utils.py
|
||||
# markdown
|
||||
"""# SiaN Model"""
|
||||
|
||||
# code: ./src/dataloader.py
|
||||
# markdown
|
||||
"""Dataset for Google/Yandex image pairs with homography augmentation."""
|
||||
|
||||
# code: ./src/model.py
|
||||
# markdown
|
||||
"""HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale."""
|
||||
|
||||
# code: ./src/train.py
|
||||
# markdown
|
||||
"""HomographyTrainer manages training loop with validation."""
|
||||
|
||||
# code: ./src/analyze.py
|
||||
# markdown
|
||||
"""Visualization and analysis of predictions."""
|
||||
|
||||
# code: ./src/main.py
|
||||
# markdown
|
||||
"""Run main() to execute the full pipeline."""
|
||||
|
||||
# inline:
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
235
models/SiaN/build.py
Normal file
235
models/SiaN/build.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def parse_schema(schema_path):
|
||||
with open(schema_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
imports = []
|
||||
items = []
|
||||
|
||||
in_imports_section = False
|
||||
lines = content.split('\n')
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped == '# === IMPORTS ===':
|
||||
in_imports_section = True
|
||||
i += 1
|
||||
continue
|
||||
elif stripped.startswith('# ==='):
|
||||
in_imports_section = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_imports_section and stripped and not stripped.startswith('#'):
|
||||
imports.append(stripped)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_imports_section and stripped.startswith('#'):
|
||||
in_imports_section = False
|
||||
|
||||
if stripped.startswith('# code:'):
|
||||
match = re.search(r'# code:\s*(.+)', stripped)
|
||||
if match:
|
||||
items.append(('code', match.group(1).strip()))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if stripped.startswith('# inline:'):
|
||||
i += 1
|
||||
code_lines = []
|
||||
while i < len(lines):
|
||||
stripped = lines[i].strip()
|
||||
if stripped.startswith('#'):
|
||||
if stripped.startswith('# code:') or stripped.startswith('# inline:') or stripped == '# markdown' or stripped.startswith('# ==='):
|
||||
break
|
||||
i += 1
|
||||
continue
|
||||
if stripped == '':
|
||||
i += 1
|
||||
continue
|
||||
code_lines.append(lines[i])
|
||||
i += 1
|
||||
if code_lines:
|
||||
items.append(('inline', '\n'.join(code_lines).rstrip()))
|
||||
continue
|
||||
|
||||
if stripped == '# markdown':
|
||||
i += 1
|
||||
if i < len(lines):
|
||||
next_line = lines[i].strip()
|
||||
if next_line.startswith('"""'):
|
||||
if next_line.endswith('"""') and len(next_line) > 3:
|
||||
md_content = next_line[3:-3].strip()
|
||||
i += 1
|
||||
else:
|
||||
end_idx = None
|
||||
for j in range(i + 1, len(lines)):
|
||||
if '"""' in lines[j]:
|
||||
end_idx = j
|
||||
break
|
||||
if end_idx:
|
||||
md_content = '\n'.join(lines[i:end_idx])
|
||||
md_content = md_content.strip('"""').strip()
|
||||
i = end_idx + 1
|
||||
else:
|
||||
md_content = ""
|
||||
i += 1
|
||||
items.append(('markdown', md_content))
|
||||
continue
|
||||
|
||||
i += 1
|
||||
|
||||
return imports, items
|
||||
|
||||
|
||||
def strip_all_imports(content):
|
||||
lines = content.split('\n')
|
||||
result_lines = []
|
||||
skip_block = False
|
||||
if_block_indent = 0
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped.startswith('if __name__') or stripped.startswith('if __name__ =='):
|
||||
skip_block = True
|
||||
if_block_indent = len(line) - len(line.lstrip())
|
||||
continue
|
||||
|
||||
if skip_block:
|
||||
current_indent = len(line) - len(line.lstrip())
|
||||
if line.strip() == '':
|
||||
continue
|
||||
if current_indent < if_block_indent:
|
||||
skip_block = False
|
||||
elif current_indent == if_block_indent and stripped.startswith('if '):
|
||||
skip_block = True
|
||||
if_block_indent = current_indent
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
if stripped.startswith('import ') or stripped.startswith('from '):
|
||||
continue
|
||||
|
||||
result_lines.append(line)
|
||||
|
||||
while result_lines and result_lines[-1].strip() == '':
|
||||
result_lines.pop()
|
||||
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def read_src_file(ref, base_dir):
|
||||
ref_path = ref.replace('./src/', '').replace('src/', '')
|
||||
full_path = os.path.join(base_dir, 'src', ref_path.replace('./src/', '').lstrip('/'))
|
||||
|
||||
if not full_path.endswith('.py'):
|
||||
full_path += '.py'
|
||||
|
||||
with open(full_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def build_notebook(schema_path, output_path=None):
|
||||
schema_dir = os.path.dirname(os.path.abspath(schema_path))
|
||||
if output_path is None:
|
||||
output_path = os.path.join(schema_dir, 'notebook.gen.ipynb')
|
||||
|
||||
imports, items = parse_schema(schema_path)
|
||||
|
||||
cells = []
|
||||
|
||||
if imports:
|
||||
imports_cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [imp + "\n" for imp in imports]
|
||||
}
|
||||
cells.append(imports_cell)
|
||||
|
||||
for item_type, item_content in items:
|
||||
if item_type == 'markdown':
|
||||
md_cell = {
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": item_content + "\n"
|
||||
}
|
||||
cells.append(md_cell)
|
||||
elif item_type == 'inline':
|
||||
lines = item_content.split('\n')
|
||||
while lines and lines[-1].strip() == '':
|
||||
lines.pop()
|
||||
if lines:
|
||||
lines.append('')
|
||||
cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [line + "\n" for line in lines]
|
||||
}
|
||||
cells.append(cell)
|
||||
elif item_type == 'code':
|
||||
try:
|
||||
content = read_src_file(item_content, schema_dir)
|
||||
content = strip_all_imports(content)
|
||||
|
||||
lines = content.split('\n')
|
||||
while lines and lines[-1].strip() == '':
|
||||
lines.pop()
|
||||
if lines:
|
||||
lines.append('')
|
||||
|
||||
cell = {
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [line + "\n" for line in lines]
|
||||
}
|
||||
cells.append(cell)
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Could not find: {item_content}")
|
||||
|
||||
notebook = {
|
||||
"cells": cells,
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.11.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(notebook, f, indent=1, ensure_ascii=False)
|
||||
|
||||
print(f"Notebook generated: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
schema_path = os.path.join(script_dir, '_schema.py')
|
||||
|
||||
if os.path.exists(schema_path):
|
||||
build_notebook(schema_path)
|
||||
else:
|
||||
print(f"Error: _schema.py not found at {schema_path}")
|
||||
501
models/SiaN/notebook.gen.ipynb
Normal file
501
models/SiaN/notebook.gen.ipynb
Normal file
@@ -0,0 +1,501 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import random\n",
|
||||
"import logging\n",
|
||||
"from typing import Tuple\n",
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import Image\n",
|
||||
"from torch.utils.data import DataLoader, Dataset, Subset\n",
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"from torchvision import transforms, models\n",
|
||||
"from tqdm import tqdm\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"config = {\n",
|
||||
" \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n",
|
||||
" \"image_size\": (256, 256),\n",
|
||||
" \"batch_size\": 32,\n",
|
||||
" \"train_split\": 0.8,\n",
|
||||
" \"num_workers\": 0,\n",
|
||||
" \"epochs\": 100,\n",
|
||||
" \"learning_rate\": 2e-4,\n",
|
||||
" \"dropout_rate\": 0.3,\n",
|
||||
" \"backbone\": \"resnet18\",\n",
|
||||
" \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_camera_matrix(w, h):\n",
|
||||
" return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]], dtype=np.float32)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n",
|
||||
" scale = np.random.uniform(*scale_range)\n",
|
||||
" tx = np.random.uniform(-translation_range, translation_range)\n",
|
||||
" ty = np.random.uniform(-translation_range, translation_range)\n",
|
||||
" rx = np.radians(np.random.uniform(-angle_range, angle_range))\n",
|
||||
" ry = np.radians(np.random.uniform(-angle_range, angle_range))\n",
|
||||
" rz = np.radians(np.random.uniform(-angle_range, angle_range))\n",
|
||||
" return np.array([rx, ry, rz, tx, ty, scale])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def homography_params_to_matrix(params, K):\n",
|
||||
" rx, ry, rz, tx, ty, scale = params\n",
|
||||
" cy, sy = np.cos(rz), np.sin(rz)\n",
|
||||
" cp, sp = np.cos(ry), np.sin(ry)\n",
|
||||
" cr, sr = np.cos(rx), np.sin(rx)\n",
|
||||
" Rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)\n",
|
||||
" Ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)\n",
|
||||
" Rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)\n",
|
||||
" T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]], dtype=np.float32)\n",
|
||||
" return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def matrix_to_homography_params(H, K):\n",
|
||||
" K_inv = np.linalg.inv(K)\n",
|
||||
" E = K_inv @ H @ K\n",
|
||||
" scale = np.sqrt(np.linalg.det(E[:2, :2]))\n",
|
||||
" R = E[:2, :2] / scale\n",
|
||||
" tx, ty = E[0, 2], E[1, 2]\n",
|
||||
" rz = np.arctan2(R[1, 0], R[0, 0])\n",
|
||||
" r20, r21 = E[2, 0], E[2, 1]\n",
|
||||
" ry = np.arctan2(r20, r21)\n",
|
||||
" rx = np.arctan2(-E[1, 2], E[1, 1])\n",
|
||||
" return np.array([rx, ry, rz, tx, ty, scale], dtype=np.float32)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "# SiaN Model\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class YaGoDataset(Dataset):\n",
|
||||
" def __init__(self, root_dir: str, transform=None, augment: bool = True, \n",
|
||||
" image_size: Tuple[int, int] = (256, 256)):\n",
|
||||
" self.root_dir = root_dir\n",
|
||||
" self.transform = transform\n",
|
||||
" self.augment = augment\n",
|
||||
" self.image_size = image_size\n",
|
||||
" self.K = get_camera_matrix(image_size[1], image_size[0])\n",
|
||||
" self.image_pairs = self._discover_image_pairs()\n",
|
||||
"\n",
|
||||
" def _discover_image_pairs(self):\n",
|
||||
" pairs = []\n",
|
||||
" for f in os.listdir(self.root_dir):\n",
|
||||
" if f.endswith(\"_google.png\"):\n",
|
||||
" idx = f.split(\"_\")[0]\n",
|
||||
" yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n",
|
||||
" if os.path.exists(yandex_path):\n",
|
||||
" pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n",
|
||||
" return sorted(pairs, key=lambda x: x[\"idx\"])\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.image_pairs)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" pair = self.image_pairs[idx]\n",
|
||||
" google_img = Image.open(pair[\"google\"]).convert(\"RGB\").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)\n",
|
||||
" yandex_img = Image.open(pair[\"yandex\"]).convert(\"RGB\").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)\n",
|
||||
"\n",
|
||||
" if self.augment:\n",
|
||||
" params1 = generate_random_homography_params()\n",
|
||||
" params2 = generate_random_homography_params()\n",
|
||||
" H1 = homography_params_to_matrix(params1, self.K)\n",
|
||||
" H2 = homography_params_to_matrix(params2, self.K)\n",
|
||||
" H_combined = np.linalg.inv(H1) @ H2\n",
|
||||
" yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, self.image_size))\n",
|
||||
" google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, self.image_size))\n",
|
||||
" target_params = matrix_to_homography_params(H_combined, self.K)\n",
|
||||
" target_matrix = H_combined\n",
|
||||
" else:\n",
|
||||
" target_params = np.zeros(6, dtype=np.float32)\n",
|
||||
" target_matrix = np.eye(3, dtype=np.float32)\n",
|
||||
"\n",
|
||||
" if self.transform:\n",
|
||||
" google_img = self.transform(google_img)\n",
|
||||
" yandex_img = self.transform(yandex_img)\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"google_img\": google_img,\n",
|
||||
" \"yandex_img\": yandex_img,\n",
|
||||
" \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n",
|
||||
" \"homography_params\": torch.from_numpy(target_params).float(),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0, \n",
|
||||
" image_size=(256, 256), augment_train=True):\n",
|
||||
" transform = transforms.Compose([transforms.ToTensor()])\n",
|
||||
" \n",
|
||||
" full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)\n",
|
||||
" aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)\n",
|
||||
"\n",
|
||||
" indices = list(range(len(full_ds)))\n",
|
||||
" random.shuffle(indices)\n",
|
||||
" split = int(train_split * len(indices))\n",
|
||||
" \n",
|
||||
" train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])\n",
|
||||
" val_ds = Subset(full_ds, indices[split:])\n",
|
||||
"\n",
|
||||
" return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),\n",
|
||||
" DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_dataset_info():\n",
|
||||
" ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n",
|
||||
" return {\n",
|
||||
" \"size\": len(ds),\n",
|
||||
" \"sample_keys\": list(ds[0].keys()),\n",
|
||||
" \"sample_params\": ds[0][\"homography_params\"].numpy()\n",
|
||||
" }\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Dataset for Google/Yandex image pairs with homography augmentation.\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"class HomographyCNN6(nn.Module):\n",
|
||||
" def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n",
|
||||
" super().__init__()\n",
|
||||
" backbone = getattr(models, backbone_name)(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n",
|
||||
" self.feature_dim = backbone.fc.in_features\n",
|
||||
" backbone.fc = nn.Identity()\n",
|
||||
" self.backbone = backbone\n",
|
||||
"\n",
|
||||
" self.head = nn.Sequential(\n",
|
||||
" nn.Linear(self.feature_dim * 4, 512),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Dropout(dropout_rate),\n",
|
||||
" nn.Linear(512, 256),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Dropout(dropout_rate),\n",
|
||||
" nn.Linear(256, 6),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, img1, img2):\n",
|
||||
" f1 = self.backbone(img1)\n",
|
||||
" f2 = self.backbone(img2)\n",
|
||||
" combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
|
||||
" return self.head(combined)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class HomographyLoss6(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.criterion = nn.MSELoss()\n",
|
||||
"\n",
|
||||
" def forward(self, pred, target):\n",
|
||||
" return self.criterion(pred, target)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters())\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale.\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class HomographyTrainer:\n",
|
||||
" def __init__(self, model, train_loader, val_loader, device):\n",
|
||||
" self.model = model.to(device)\n",
|
||||
" self.train_loader = train_loader\n",
|
||||
" self.val_loader = val_loader\n",
|
||||
" self.device = device\n",
|
||||
" self.criterion = HomographyLoss6()\n",
|
||||
" self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"])\n",
|
||||
" self.writer = None\n",
|
||||
" self.best_val_loss = float(\"inf\")\n",
|
||||
"\n",
|
||||
" def train_epoch(self, epoch):\n",
|
||||
" self.model.train()\n",
|
||||
" total_loss, total_samples = 0, 0\n",
|
||||
" pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n",
|
||||
" for batch_idx, batch in enumerate(pbar):\n",
|
||||
" google_img = batch[\"google_img\"].to(self.device)\n",
|
||||
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
|
||||
" target = batch[\"homography_params\"].to(self.device)\n",
|
||||
"\n",
|
||||
" self.optimizer.zero_grad()\n",
|
||||
" output = self.model(google_img, yandex_img)\n",
|
||||
" loss = self.criterion(output, target)\n",
|
||||
" loss.backward()\n",
|
||||
" self.optimizer.step()\n",
|
||||
"\n",
|
||||
" total_loss += loss.item() * google_img.size(0)\n",
|
||||
" total_samples += google_img.size(0)\n",
|
||||
" pbar.set_postfix({\"loss\": loss.item()})\n",
|
||||
"\n",
|
||||
" return {\"loss\": total_loss / total_samples}\n",
|
||||
"\n",
|
||||
" def validate(self):\n",
|
||||
" self.model.eval()\n",
|
||||
" total_loss, total_samples = 0, 0\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for batch in tqdm(self.val_loader, desc=\"Validation\"):\n",
|
||||
" google_img = batch[\"google_img\"].to(self.device)\n",
|
||||
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
|
||||
" target = batch[\"homography_params\"].to(self.device)\n",
|
||||
" output = self.model(google_img, yandex_img)\n",
|
||||
" loss = self.criterion(output, target)\n",
|
||||
" total_loss += loss.item() * google_img.size(0)\n",
|
||||
" total_samples += google_img.size(0)\n",
|
||||
" return {\"loss\": total_loss / total_samples}\n",
|
||||
"\n",
|
||||
" def train(self, num_epochs):\n",
|
||||
" log_dir = config[\"output_dir\"]\n",
|
||||
" os.makedirs(log_dir, exist_ok=True)\n",
|
||||
" self.writer = SummaryWriter(log_dir)\n",
|
||||
"\n",
|
||||
" for epoch in range(1, num_epochs + 1):\n",
|
||||
" train_metrics = self.train_epoch(epoch)\n",
|
||||
" val_metrics = self.validate()\n",
|
||||
" print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n",
|
||||
"\n",
|
||||
" if val_metrics[\"loss\"] < self.best_val_loss:\n",
|
||||
" self.best_val_loss = val_metrics[\"loss\"]\n",
|
||||
" self.save_checkpoint(epoch, is_best=True)\n",
|
||||
" print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n",
|
||||
"\n",
|
||||
" self.writer.close()\n",
|
||||
"\n",
|
||||
" def save_checkpoint(self, epoch, is_best=False):\n",
|
||||
" ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n",
|
||||
" os.makedirs(ckpt_dir, exist_ok=True)\n",
|
||||
" ckpt = {\"epoch\": epoch, \"model_state_dict\": self.model.state_dict(), \"val_loss\": self.best_val_loss}\n",
|
||||
" torch.save(ckpt, os.path.join(ckpt_dir, f\"checkpoint_epoch_{epoch}.pt\"))\n",
|
||||
" if is_best:\n",
|
||||
" torch.save(ckpt, os.path.join(ckpt_dir, \"best_model.pt\"))\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "HomographyTrainer manages training loop with validation.\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"def analyze_training(trainer):\n",
|
||||
" print(\"=== Training Analysis ===\\n\")\n",
|
||||
"\n",
|
||||
" if trainer.writer:\n",
|
||||
" print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n",
|
||||
"\n",
|
||||
" print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n",
|
||||
"\n",
|
||||
" trainer.model.eval()\n",
|
||||
" with torch.no_grad():\n",
|
||||
" batch = next(iter(trainer.val_loader))\n",
|
||||
" google_img = batch[\"google_img\"].to(trainer.device)\n",
|
||||
" yandex_img = batch[\"yandex_img\"].to(trainer.device)\n",
|
||||
" target_params = batch[\"homography_params\"].to(trainer.device)\n",
|
||||
"\n",
|
||||
" pred_params = trainer.model(google_img, yandex_img)\n",
|
||||
"\n",
|
||||
" print(f\"\\nSample predictions (first 3 of batch):\")\n",
|
||||
" print(f\"{'Param':<8} {'Target':>12} {'Predicted':>12} {'Error':>12}\")\n",
|
||||
" print(\"-\" * 46)\n",
|
||||
" names = [\"rx\", \"ry\", \"rz\", \"tx\", \"ty\", \"scale\"]\n",
|
||||
" for i in range(6):\n",
|
||||
" t = target_params[0, i].item()\n",
|
||||
" p = pred_params[0, i].item()\n",
|
||||
" print(f\"{names[i]:<8} {t:>12.4f} {p:>12.4f} {abs(t-p):>12.4f}\")\n",
|
||||
"\n",
|
||||
" print(f\"\\nBatch mean abs error: {torch.mean(torch.abs(pred_params - target_params)).item():.4f}\")\n",
|
||||
"\n",
|
||||
" print(\"\\n=== Visualization ===\")\n",
|
||||
" fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
|
||||
" img1 = google_img[0].cpu()\n",
|
||||
" img2 = yandex_img[0].cpu()\n",
|
||||
" axes[0].imshow(img1.permute(1, 2, 0))\n",
|
||||
" axes[0].set_title(\"Google\")\n",
|
||||
" axes[0].axis(\"off\")\n",
|
||||
" axes[1].imshow(img2.permute(1, 2, 0))\n",
|
||||
" axes[1].set_title(\"Yandex\")\n",
|
||||
" axes[1].axis(\"off\")\n",
|
||||
" axes[2].bar(names, pred_params[0].cpu().numpy())\n",
|
||||
" axes[2].set_title(\"Predicted params\")\n",
|
||||
" axes[2].axhline(y=0, color=\"k\", lw=0.5)\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.savefig(\"prediction_sample.png\")\n",
|
||||
" print(\"Saved prediction_sample.png\")\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
" return {\"best_val_loss\": trainer.best_val_loss}\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Visualization and analysis of predictions.\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n",
|
||||
"logger = logging.getLogger(__name__)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_dataset():\n",
|
||||
" logger.info(\"Creating data loaders...\")\n",
|
||||
" train_loader, val_loader = create_data_loaders(\n",
|
||||
" root_dir=config[\"data_dir\"],\n",
|
||||
" batch_size=config[\"batch_size\"],\n",
|
||||
" train_split=config[\"train_split\"],\n",
|
||||
" num_workers=config[\"num_workers\"],\n",
|
||||
" image_size=config[\"image_size\"],\n",
|
||||
" )\n",
|
||||
" logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
|
||||
" return train_loader, val_loader\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_model():\n",
|
||||
" logger.info(\"Creating model...\")\n",
|
||||
" model = HomographyCNN6(\n",
|
||||
" input_channels=3,\n",
|
||||
" backbone_name=config[\"backbone\"],\n",
|
||||
" pretrained=True,\n",
|
||||
" dropout_rate=config[\"dropout_rate\"]\n",
|
||||
" )\n",
|
||||
" logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
|
||||
" return model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def train_model(model, train_loader, val_loader):\n",
|
||||
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" logger.info(f\"Using device: {device}\")\n",
|
||||
" \n",
|
||||
" trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
|
||||
" logger.info(\"Starting training...\")\n",
|
||||
" trainer.train(config[\"epochs\"])\n",
|
||||
" logger.info(\"Training completed\")\n",
|
||||
" return trainer\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def analyze_model(trainer):\n",
|
||||
" logger.info(\"Analyzing model...\")\n",
|
||||
" results = analyze_training(trainer)\n",
|
||||
" logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def main():\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" logger.info(\"SiaN Training Pipeline\")\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" \n",
|
||||
" dataset_info = get_dataset_info()\n",
|
||||
" logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
|
||||
" \n",
|
||||
" train_loader, val_loader = create_dataset()\n",
|
||||
" model = create_model()\n",
|
||||
" trainer = train_model(model, train_loader, val_loader)\n",
|
||||
" results = analyze_model(trainer)\n",
|
||||
" \n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" logger.info(\"Pipeline completed successfully\")\n",
|
||||
" logger.info(\"=\" * 50)\n",
|
||||
" \n",
|
||||
" return trainer, results\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Run main() to execute the full pipeline.\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" main()\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.11.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
|
||||
def analyze_training(trainer):
|
||||
@@ -51,8 +50,3 @@ def analyze_training(trainer):
|
||||
plt.show()
|
||||
|
||||
return {"best_val_loss": trainer.best_val_loss}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from train import trainer
|
||||
analyze_training(trainer)
|
||||
@@ -9,7 +9,7 @@ from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset, Subset
|
||||
from torchvision import transforms
|
||||
|
||||
from utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
|
||||
from .utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
|
||||
|
||||
|
||||
class YaGoDataset(Dataset):
|
||||
@@ -84,9 +84,10 @@ def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
|
||||
DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def get_dataset_info():
|
||||
ds = YaGoDataset(config["data_dir"], augment=True, image_size=config["image_size"])
|
||||
print(f"Dataset size: {len(ds)}")
|
||||
s = ds[0]
|
||||
print(f"Keys: {list(s.keys())}")
|
||||
print(f"Params: {s['homography_params'].numpy()}")
|
||||
return {
|
||||
"size": len(ds),
|
||||
"sample_keys": list(ds[0].keys()),
|
||||
"sample_params": ds[0]["homography_params"].numpy()
|
||||
}
|
||||
80
models/SiaN/src/main.py
Normal file
80
models/SiaN/src/main.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from .dataloader import create_data_loaders, get_dataset_info
|
||||
from .model import HomographyCNN6, count_parameters
|
||||
from .train import HomographyTrainer
|
||||
from .analyze import analyze_training
|
||||
from .utils import config
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_dataset():
|
||||
logger.info("Creating data loaders...")
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
train_split=config["train_split"],
|
||||
num_workers=config["num_workers"],
|
||||
image_size=config["image_size"],
|
||||
)
|
||||
logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
def create_model():
|
||||
logger.info("Creating model...")
|
||||
model = HomographyCNN6(
|
||||
input_channels=3,
|
||||
backbone_name=config["backbone"],
|
||||
pretrained=True,
|
||||
dropout_rate=config["dropout_rate"]
|
||||
)
|
||||
logger.info(f"Model created with {count_parameters(model):,} parameters")
|
||||
return model
|
||||
|
||||
|
||||
def train_model(model, train_loader, val_loader):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
trainer = HomographyTrainer(model, train_loader, val_loader, device)
|
||||
logger.info("Starting training...")
|
||||
trainer.train(config["epochs"])
|
||||
logger.info("Training completed")
|
||||
return trainer
|
||||
|
||||
|
||||
def analyze_model(trainer):
|
||||
logger.info("Analyzing model...")
|
||||
results = analyze_training(trainer)
|
||||
logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
logger.info("=" * 50)
|
||||
logger.info("SiaN Training Pipeline")
|
||||
logger.info("=" * 50)
|
||||
|
||||
dataset_info = get_dataset_info()
|
||||
logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")
|
||||
|
||||
train_loader, val_loader = create_dataset()
|
||||
model = create_model()
|
||||
trainer = train_model(model, train_loader, val_loader)
|
||||
results = analyze_model(trainer)
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("Pipeline completed successfully")
|
||||
logger.info("=" * 50)
|
||||
|
||||
return trainer, results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -37,9 +37,5 @@ class HomographyLoss6(nn.Module):
|
||||
return self.criterion(pred, target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = HomographyCNN6()
|
||||
img1 = torch.randn(2, 3, 256, 256)
|
||||
img2 = torch.randn(2, 3, 256, 256)
|
||||
out = model(img1, img2)
|
||||
print(f"Output shape: {out.shape}, mean: {out.mean():.3f}")
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
@@ -3,12 +3,13 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from dataloader import create_data_loaders
|
||||
from model import HomographyCNN6, HomographyLoss6
|
||||
from utils import config
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from .dataloader import create_data_loaders
|
||||
from .model import HomographyCNN6, HomographyLoss6, count_parameters
|
||||
from .utils import config
|
||||
|
||||
|
||||
class HomographyTrainer:
|
||||
def __init__(self, model, train_loader, val_loader, device):
|
||||
@@ -80,26 +81,3 @@ class HomographyTrainer:
|
||||
torch.save(ckpt, os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch}.pt"))
|
||||
if is_best:
|
||||
torch.save(ckpt, os.path.join(ckpt_dir, "best_model.pt"))
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
train_loader, val_loader = create_data_loaders(
|
||||
root_dir=config["data_dir"],
|
||||
batch_size=config["batch_size"],
|
||||
train_split=config["train_split"],
|
||||
num_workers=config["num_workers"],
|
||||
image_size=config["image_size"],
|
||||
)
|
||||
|
||||
model = HomographyCNN6(
|
||||
input_channels=3,
|
||||
backbone_name=config["backbone"],
|
||||
pretrained=True,
|
||||
dropout_rate=config["dropout_rate"]
|
||||
)
|
||||
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
trainer = HomographyTrainer(model, train_loader, val_loader, device)
|
||||
trainer.train(config["epochs"])
|
||||
Reference in New Issue
Block a user