add auto codegen

This commit is contained in:
2026-04-04 21:32:50 +03:00
parent 703ea8dbaf
commit b2cc714d79
12 changed files with 901 additions and 1320 deletions

View File

@@ -1 +1,2 @@
runs
runs
*.gen.py

23
models/SiaN/_schema.md Normal file
View File

@@ -0,0 +1,23 @@
# _schema.md
## Format
```
# === IMPORTS ===
<all imports>
# code: ./src/file.py
# markdown
"""Description"""
# inline:
<custom code>
```
## Directives
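- `# code:` - include a file from `src/` as a code cell (its import lines and `if __name__` blocks are stripped)
- `# markdown` - markdown cell; the text is the triple-quoted string on the following line(s)
- `# inline:` - custom code cell; runs until the next directive (blank lines and comment lines are dropped)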
`build.py` generates the notebook `notebook.gen.ipynb` from `_schema.py`.

47
models/SiaN/_schema.py Normal file
View File

@@ -0,0 +1,47 @@
# _schema.py
# === IMPORTS ===
import os
import random
import logging
from typing import Tuple
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from tqdm import tqdm
# code: ./src/utils.py
# markdown
"""# SiaN Model"""
# code: ./src/dataloader.py
# markdown
"""Dataset for Google/Yandex image pairs with homography augmentation."""
# code: ./src/model.py
# markdown
"""HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale."""
# code: ./src/train.py
# markdown
"""HomographyTrainer manages training loop with validation."""
# code: ./src/analyze.py
# markdown
"""Visualization and analysis of predictions."""
# code: ./src/main.py
# markdown
"""Run main() to execute the full pipeline."""
# inline:
if __name__ == "__main__":
main()

235
models/SiaN/build.py Normal file
View File

@@ -0,0 +1,235 @@
import json
import os
import re
_CODE_DIRECTIVE = re.compile(r'# code:\s*(.+)')
_TRIPLE_QUOTE = '"""'


def _is_directive(stripped):
    """Return True when a stripped line starts a new directive or section."""
    return (
        stripped == '# markdown'
        or stripped.startswith(('# code:', '# inline:', '# ==='))
    )


def _collect_inline(lines, start):
    """Gather the code lines of an ``# inline:`` block.

    Collection starts at ``lines[start]`` and stops at the next directive or
    at the end of input. Blank lines and non-directive comment lines are
    dropped; kept lines retain their original indentation.

    Returns:
        ``(code_lines, next_index)`` where ``next_index`` is the index of the
        line that stopped the collection.
    """
    code_lines = []
    i = start
    while i < len(lines):
        stripped = lines[i].strip()
        if _is_directive(stripped):
            break
        if stripped and not stripped.startswith('#'):
            code_lines.append(lines[i])
        i += 1
    return code_lines, i


def _read_markdown(lines, start):
    """Read the triple-quoted text that follows a ``# markdown`` directive.

    Returns:
        ``(text, next_index)``. ``text`` is ``None`` when ``lines[start]``
        does not open a triple-quoted string (nothing is consumed), and ``""``
        when the string is never closed (only the opening line is consumed).
    """
    first = lines[start].strip()
    if not first.startswith(_TRIPLE_QUOTE):
        return None, start

    # One-line form: """text""" (needs both an opening and a closing delimiter).
    if len(first) >= 6 and first.endswith(_TRIPLE_QUOTE):
        return first[3:-3].strip(), start + 1

    # Multi-line form: keep everything up to the closing delimiter, including
    # any text that precedes it on the closing line.
    body = [first[3:]]
    for j in range(start + 1, len(lines)):
        head, closing, _ = lines[j].partition(_TRIPLE_QUOTE)
        if closing:
            body.append(head)
            return '\n'.join(body).strip(), j + 1
        body.append(lines[j])

    # Unterminated string: emit an empty description and move on.
    return "", start + 1


def parse_schema(schema_path):
    """Parse a schema file into import lines and an ordered list of cell items.

    Args:
        schema_path: Path of the schema file (read as UTF-8).

    Returns:
        ``(imports, items)``. ``imports`` is the list of stripped, non-comment
        lines found in the ``# === IMPORTS ===`` section. ``items`` is a list
        of ``(kind, content)`` tuples in file order, where ``kind`` is
        ``'code'`` (content is the referenced path), ``'inline'`` (content is
        the code text) or ``'markdown'`` (content is the description text).
    """
    with open(schema_path, 'r', encoding='utf-8') as f:
        lines = f.read().split('\n')

    imports = []
    items = []
    in_imports_section = False
    i = 0
    while i < len(lines):
        stripped = lines[i].strip()

        # Section headers: only the IMPORTS header opens the imports section.
        if stripped.startswith('# ==='):
            in_imports_section = stripped == '# === IMPORTS ==='
            i += 1
            continue

        if in_imports_section:
            if stripped and not stripped.startswith('#'):
                imports.append(stripped)
                i += 1
                continue
            # Any comment line (typically the first directive) ends the section.
            if stripped.startswith('#'):
                in_imports_section = False

        if stripped.startswith('# code:'):
            match = _CODE_DIRECTIVE.match(stripped)
            if match:
                items.append(('code', match.group(1).strip()))
            i += 1
            continue

        if stripped.startswith('# inline:'):
            code_lines, i = _collect_inline(lines, i + 1)
            if code_lines:
                items.append(('inline', '\n'.join(code_lines).rstrip()))
            continue

        if stripped == '# markdown':
            i += 1
            if i < len(lines):
                text, i = _read_markdown(lines, i)
                if text is not None:
                    items.append(('markdown', text))
            continue

        i += 1
    return imports, items
def strip_all_imports(content):
    """Remove import statements and ``if __name__`` blocks from source text.

    Lines whose stripped text starts with ``import `` or ``from `` are removed
    at any indentation, together with the continuation lines of a
    parenthesised or backslash-continued import. An ``if __name__`` block is
    removed up to the first non-blank line at the same or a lower indentation;
    an ``else``/``elif`` branch attached to the guard is removed with it.
    Leading and trailing blank lines of the result are trimmed.

    Detection is textual (prefix matching), not a real parse, so a line inside
    a string that happens to start with ``import `` or ``from `` is removed too.

    Args:
        content: Python source text.

    Returns:
        The filtered source text, joined with ``'\\n'``.
    """
    kept = []
    guard_indent = None   # indent of the `if __name__` line while its body is skipped
    open_parens = 0       # > 0 while inside a parenthesised multi-line import
    backslash = False     # True while inside a backslash-continued import

    for line in content.split('\n'):
        stripped = line.strip()
        indent = len(line) - len(line.lstrip())

        # Swallow the remaining lines of a multi-line import statement.
        if open_parens or backslash:
            code = stripped.split('#', 1)[0].rstrip()
            open_parens = max(0, open_parens + code.count('(') - code.count(')'))
            backslash = code.endswith('\\')
            continue

        if stripped.startswith('if __name__'):
            guard_indent = indent
            continue

        if guard_indent is not None:
            if not stripped or indent > guard_indent:
                continue
            # A dangling else/elif would be a syntax error, so drop it with the guard.
            if indent == guard_indent and (
                stripped.startswith(('elif ', 'elif('))
                or stripped.replace(' ', '').startswith('else:')
            ):
                continue
            # First line at the guard's level (or shallower) ends the block.
            guard_indent = None

        if stripped.startswith(('import ', 'from ')):
            code = stripped.split('#', 1)[0].rstrip()
            open_parens = max(0, code.count('(') - code.count(')'))
            backslash = code.endswith('\\')
            continue

        kept.append(line)

    while kept and not kept[-1].strip():
        kept.pop()
    first = 0
    while first < len(kept) and not kept[first].strip():
        first += 1
    return '\n'.join(kept[first:])
def read_src_file(ref, base_dir):
    """Read a source file referenced by a ``# code:`` directive.

    Args:
        ref: Reference such as ``./src/utils.py``, ``src/utils.py`` or
            ``utils``; a leading ``./src/`` or ``src/`` is optional and the
            ``.py`` suffix is appended when missing.
        base_dir: Directory that contains the ``src`` folder.

    Returns:
        The file's text, decoded as UTF-8.

    Raises:
        FileNotFoundError: If the resolved file does not exist.
    """
    rel_path = ref
    # Strip only a leading prefix; a blanket replace would also mangle
    # nested names that merely contain "src/" (e.g. "mysrc/helper.py").
    for prefix in ('./src/', 'src/'):
        if rel_path.startswith(prefix):
            rel_path = rel_path[len(prefix):]
            break
    # A leading "/" would make os.path.join discard base_dir.
    rel_path = rel_path.lstrip('/')

    full_path = os.path.join(base_dir, 'src', rel_path)
    if not full_path.endswith('.py'):
        full_path += '.py'
    with open(full_path, 'r', encoding='utf-8') as f:
        return f.read()
def build_notebook(schema_path, output_path=None):
    """Generate a Jupyter notebook (nbformat 4.5) from a schema file.

    Args:
        schema_path: Path of the schema file to parse.
        output_path: Destination of the notebook; defaults to
            ``notebook.gen.ipynb`` next to the schema file.

    The imports listed in the schema become the first code cell. Each schema
    item then becomes one cell, in order. A ``code`` item whose source file is
    missing is skipped with a printed warning.
    """

    def code_cell(source):
        # Key order is kept stable so the serialized JSON does not change.
        return {
            "cell_type": "code",
            "execution_count": None,
            "metadata": {},
            "outputs": [],
            "source": source,
        }

    def cell_source(text):
        # Drop trailing blank lines, then end the cell with one empty line.
        rows = text.split('\n')
        while rows and not rows[-1].strip():
            rows.pop()
        if rows:
            rows.append('')
        return [row + "\n" for row in rows]

    schema_dir = os.path.dirname(os.path.abspath(schema_path))
    if output_path is None:
        output_path = os.path.join(schema_dir, 'notebook.gen.ipynb')

    imports, items = parse_schema(schema_path)

    cells = []
    if imports:
        cells.append(code_cell([statement + "\n" for statement in imports]))

    for item_type, item_content in items:
        if item_type == 'markdown':
            cells.append({
                "cell_type": "markdown",
                "metadata": {},
                "source": item_content + "\n",
            })
        elif item_type == 'inline':
            cells.append(code_cell(cell_source(item_content)))
        elif item_type == 'code':
            try:
                raw = read_src_file(item_content, schema_dir)
            except FileNotFoundError:
                print(f"Warning: Could not find: {item_content}")
                continue
            cells.append(code_cell(cell_source(strip_all_imports(raw))))

    kernelspec = {
        "display_name": ".venv",
        "language": "python",
        "name": "python3",
    }
    language_info = {
        "name": "python",
        "version": "3.11.0",
    }
    notebook = {
        "cells": cells,
        "metadata": {
            "kernelspec": kernelspec,
            "language_info": language_info,
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(notebook, f, indent=1, ensure_ascii=False)
    print(f"Notebook generated: {output_path}")
if __name__ == "__main__":
    # Resolve the schema next to this script so the build works from any cwd.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    schema_path = os.path.join(script_dir, '_schema.py')
    if not os.path.exists(schema_path):
        print(f"Error: _schema.py not found at {schema_path}")
    else:
        build_notebook(schema_path)

View File

@@ -0,0 +1,501 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import logging\n",
"from typing import Tuple\n",
"import cv2\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"from torch.utils.data import DataLoader, Dataset, Subset\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from torchvision import transforms, models\n",
"from tqdm import tqdm\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"config = {\n",
" \"data_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n",
" \"image_size\": (256, 256),\n",
" \"batch_size\": 32,\n",
" \"train_split\": 0.8,\n",
" \"num_workers\": 0,\n",
" \"epochs\": 100,\n",
" \"learning_rate\": 2e-4,\n",
" \"dropout_rate\": 0.3,\n",
" \"backbone\": \"resnet18\",\n",
" \"output_dir\": r\"C:\\Users\\admin\\Projects\\autopilot\\models\\SiaN\\runs\",\n",
"}\n",
"\n",
"\n",
"def get_camera_matrix(w, h):\n",
" return np.array([[w / 2, 0, w / 2], [0, h / 2, h / 2], [0, 0, 1]], dtype=np.float32)\n",
"\n",
"\n",
"def generate_random_homography_params(angle_range=10, translation_range=0.1, scale_range=(0.9, 1.1)):\n",
" scale = np.random.uniform(*scale_range)\n",
" tx = np.random.uniform(-translation_range, translation_range)\n",
" ty = np.random.uniform(-translation_range, translation_range)\n",
" rx = np.radians(np.random.uniform(-angle_range, angle_range))\n",
" ry = np.radians(np.random.uniform(-angle_range, angle_range))\n",
" rz = np.radians(np.random.uniform(-angle_range, angle_range))\n",
" return np.array([rx, ry, rz, tx, ty, scale])\n",
"\n",
"\n",
"def homography_params_to_matrix(params, K):\n",
" rx, ry, rz, tx, ty, scale = params\n",
" cy, sy = np.cos(rz), np.sin(rz)\n",
" cp, sp = np.cos(ry), np.sin(ry)\n",
" cr, sr = np.cos(rx), np.sin(rx)\n",
" Rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)\n",
" Ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)\n",
" Rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)\n",
" T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, scale]], dtype=np.float32)\n",
" return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)\n",
"\n",
"\n",
"def matrix_to_homography_params(H, K):\n",
" K_inv = np.linalg.inv(K)\n",
" E = K_inv @ H @ K\n",
" scale = np.sqrt(np.linalg.det(E[:2, :2]))\n",
" R = E[:2, :2] / scale\n",
" tx, ty = E[0, 2], E[1, 2]\n",
" rz = np.arctan2(R[1, 0], R[0, 0])\n",
" r20, r21 = E[2, 0], E[2, 1]\n",
" ry = np.arctan2(r20, r21)\n",
" rx = np.arctan2(-E[1, 2], E[1, 1])\n",
" return np.array([rx, ry, rz, tx, ty, scale], dtype=np.float32)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "# SiaN Model\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"\n",
"class YaGoDataset(Dataset):\n",
" def __init__(self, root_dir: str, transform=None, augment: bool = True, \n",
" image_size: Tuple[int, int] = (256, 256)):\n",
" self.root_dir = root_dir\n",
" self.transform = transform\n",
" self.augment = augment\n",
" self.image_size = image_size\n",
" self.K = get_camera_matrix(image_size[1], image_size[0])\n",
" self.image_pairs = self._discover_image_pairs()\n",
"\n",
" def _discover_image_pairs(self):\n",
" pairs = []\n",
" for f in os.listdir(self.root_dir):\n",
" if f.endswith(\"_google.png\"):\n",
" idx = f.split(\"_\")[0]\n",
" yandex_path = os.path.join(self.root_dir, f\"{idx}_yandex.png\")\n",
" if os.path.exists(yandex_path):\n",
" pairs.append({\"idx\": int(idx), \"google\": os.path.join(self.root_dir, f), \"yandex\": yandex_path})\n",
" return sorted(pairs, key=lambda x: x[\"idx\"])\n",
"\n",
" def __len__(self):\n",
" return len(self.image_pairs)\n",
"\n",
" def __getitem__(self, idx):\n",
" pair = self.image_pairs[idx]\n",
" google_img = Image.open(pair[\"google\"]).convert(\"RGB\").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)\n",
" yandex_img = Image.open(pair[\"yandex\"]).convert(\"RGB\").resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)\n",
"\n",
" if self.augment:\n",
" params1 = generate_random_homography_params()\n",
" params2 = generate_random_homography_params()\n",
" H1 = homography_params_to_matrix(params1, self.K)\n",
" H2 = homography_params_to_matrix(params2, self.K)\n",
" H_combined = np.linalg.inv(H1) @ H2\n",
" yandex_img = Image.fromarray(cv2.warpPerspective(np.array(yandex_img), H1, self.image_size))\n",
" google_img = Image.fromarray(cv2.warpPerspective(np.array(google_img), H2, self.image_size))\n",
" target_params = matrix_to_homography_params(H_combined, self.K)\n",
" target_matrix = H_combined\n",
" else:\n",
" target_params = np.zeros(6, dtype=np.float32)\n",
" target_matrix = np.eye(3, dtype=np.float32)\n",
"\n",
" if self.transform:\n",
" google_img = self.transform(google_img)\n",
" yandex_img = self.transform(yandex_img)\n",
"\n",
" return {\n",
" \"google_img\": google_img,\n",
" \"yandex_img\": yandex_img,\n",
" \"homography_matrix\": torch.from_numpy(target_matrix).float(),\n",
" \"homography_params\": torch.from_numpy(target_params).float(),\n",
" }\n",
"\n",
"\n",
"def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0, \n",
" image_size=(256, 256), augment_train=True):\n",
" transform = transforms.Compose([transforms.ToTensor()])\n",
" \n",
" full_ds = YaGoDataset(root_dir, transform=transform, augment=False, image_size=image_size)\n",
" aug_ds = YaGoDataset(root_dir, transform=transform, augment=True, image_size=image_size)\n",
"\n",
" indices = list(range(len(full_ds)))\n",
" random.shuffle(indices)\n",
" split = int(train_split * len(indices))\n",
" \n",
" train_ds = Subset(aug_ds if augment_train else full_ds, indices[:split])\n",
" val_ds = Subset(full_ds, indices[split:])\n",
"\n",
" return (DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),\n",
" DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))\n",
"\n",
"\n",
"def get_dataset_info():\n",
" ds = YaGoDataset(config[\"data_dir\"], augment=True, image_size=config[\"image_size\"])\n",
" return {\n",
" \"size\": len(ds),\n",
" \"sample_keys\": list(ds[0].keys()),\n",
" \"sample_params\": ds[0][\"homography_params\"].numpy()\n",
" }\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "Dataset for Google/Yandex image pairs with homography augmentation.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class HomographyCNN6(nn.Module):\n",
" def __init__(self, input_channels=3, backbone_name=\"resnet18\", pretrained=True, dropout_rate=0.3):\n",
" super().__init__()\n",
" backbone = getattr(models, backbone_name)(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n",
" self.feature_dim = backbone.fc.in_features\n",
" backbone.fc = nn.Identity()\n",
" self.backbone = backbone\n",
"\n",
" self.head = nn.Sequential(\n",
" nn.Linear(self.feature_dim * 4, 512),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
" nn.Linear(512, 256),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
" nn.Linear(256, 6),\n",
" )\n",
"\n",
" def forward(self, img1, img2):\n",
" f1 = self.backbone(img1)\n",
" f2 = self.backbone(img2)\n",
" combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)\n",
" return self.head(combined)\n",
"\n",
"\n",
"class HomographyLoss6(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.criterion = nn.MSELoss()\n",
"\n",
" def forward(self, pred, target):\n",
" return self.criterion(pred, target)\n",
"\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters())\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"\n",
"class HomographyTrainer:\n",
" def __init__(self, model, train_loader, val_loader, device):\n",
" self.model = model.to(device)\n",
" self.train_loader = train_loader\n",
" self.val_loader = val_loader\n",
" self.device = device\n",
" self.criterion = HomographyLoss6()\n",
" self.optimizer = optim.Adam(model.parameters(), lr=config[\"learning_rate\"])\n",
" self.writer = None\n",
" self.best_val_loss = float(\"inf\")\n",
"\n",
" def train_epoch(self, epoch):\n",
" self.model.train()\n",
" total_loss, total_samples = 0, 0\n",
" pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n",
" for batch_idx, batch in enumerate(pbar):\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
" target = batch[\"homography_params\"].to(self.device)\n",
"\n",
" self.optimizer.zero_grad()\n",
" output = self.model(google_img, yandex_img)\n",
" loss = self.criterion(output, target)\n",
" loss.backward()\n",
" self.optimizer.step()\n",
"\n",
" total_loss += loss.item() * google_img.size(0)\n",
" total_samples += google_img.size(0)\n",
" pbar.set_postfix({\"loss\": loss.item()})\n",
"\n",
" return {\"loss\": total_loss / total_samples}\n",
"\n",
" def validate(self):\n",
" self.model.eval()\n",
" total_loss, total_samples = 0, 0\n",
" with torch.no_grad():\n",
" for batch in tqdm(self.val_loader, desc=\"Validation\"):\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
" target = batch[\"homography_params\"].to(self.device)\n",
" output = self.model(google_img, yandex_img)\n",
" loss = self.criterion(output, target)\n",
" total_loss += loss.item() * google_img.size(0)\n",
" total_samples += google_img.size(0)\n",
" return {\"loss\": total_loss / total_samples}\n",
"\n",
" def train(self, num_epochs):\n",
" log_dir = config[\"output_dir\"]\n",
" os.makedirs(log_dir, exist_ok=True)\n",
" self.writer = SummaryWriter(log_dir)\n",
"\n",
" for epoch in range(1, num_epochs + 1):\n",
" train_metrics = self.train_epoch(epoch)\n",
" val_metrics = self.validate()\n",
" print(f\"Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}\")\n",
"\n",
" if val_metrics[\"loss\"] < self.best_val_loss:\n",
" self.best_val_loss = val_metrics[\"loss\"]\n",
" self.save_checkpoint(epoch, is_best=True)\n",
" print(f\"Best model saved (val loss: {val_metrics['loss']:.4f})\")\n",
"\n",
" self.writer.close()\n",
"\n",
" def save_checkpoint(self, epoch, is_best=False):\n",
" ckpt_dir = os.path.join(config[\"output_dir\"], \"checkpoints\")\n",
" os.makedirs(ckpt_dir, exist_ok=True)\n",
" ckpt = {\"epoch\": epoch, \"model_state_dict\": self.model.state_dict(), \"val_loss\": self.best_val_loss}\n",
" torch.save(ckpt, os.path.join(ckpt_dir, f\"checkpoint_epoch_{epoch}.pt\"))\n",
" if is_best:\n",
" torch.save(ckpt, os.path.join(ckpt_dir, \"best_model.pt\"))\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "HomographyTrainer manages training loop with validation.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"def analyze_training(trainer):\n",
" print(\"=== Training Analysis ===\\n\")\n",
"\n",
" if trainer.writer:\n",
" print(\"TensorBoard logs available at:\", trainer.writer.log_dir)\n",
"\n",
" print(f\"\\nBest val loss: {trainer.best_val_loss:.4f}\")\n",
"\n",
" trainer.model.eval()\n",
" with torch.no_grad():\n",
" batch = next(iter(trainer.val_loader))\n",
" google_img = batch[\"google_img\"].to(trainer.device)\n",
" yandex_img = batch[\"yandex_img\"].to(trainer.device)\n",
" target_params = batch[\"homography_params\"].to(trainer.device)\n",
"\n",
" pred_params = trainer.model(google_img, yandex_img)\n",
"\n",
" print(f\"\\nSample predictions (first 3 of batch):\")\n",
" print(f\"{'Param':<8} {'Target':>12} {'Predicted':>12} {'Error':>12}\")\n",
" print(\"-\" * 46)\n",
" names = [\"rx\", \"ry\", \"rz\", \"tx\", \"ty\", \"scale\"]\n",
" for i in range(6):\n",
" t = target_params[0, i].item()\n",
" p = pred_params[0, i].item()\n",
" print(f\"{names[i]:<8} {t:>12.4f} {p:>12.4f} {abs(t-p):>12.4f}\")\n",
"\n",
" print(f\"\\nBatch mean abs error: {torch.mean(torch.abs(pred_params - target_params)).item():.4f}\")\n",
"\n",
" print(\"\\n=== Visualization ===\")\n",
" fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
" img1 = google_img[0].cpu()\n",
" img2 = yandex_img[0].cpu()\n",
" axes[0].imshow(img1.permute(1, 2, 0))\n",
" axes[0].set_title(\"Google\")\n",
" axes[0].axis(\"off\")\n",
" axes[1].imshow(img2.permute(1, 2, 0))\n",
" axes[1].set_title(\"Yandex\")\n",
" axes[1].axis(\"off\")\n",
" axes[2].bar(names, pred_params[0].cpu().numpy())\n",
" axes[2].set_title(\"Predicted params\")\n",
" axes[2].axhline(y=0, color=\"k\", lw=0.5)\n",
" plt.tight_layout()\n",
" plt.savefig(\"prediction_sample.png\")\n",
" print(\"Saved prediction_sample.png\")\n",
" plt.show()\n",
"\n",
" return {\"best_val_loss\": trainer.best_val_loss}\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "Visualization and analysis of predictions.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"\n",
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(message)s\")\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"\n",
"def create_dataset():\n",
" logger.info(\"Creating data loaders...\")\n",
" train_loader, val_loader = create_data_loaders(\n",
" root_dir=config[\"data_dir\"],\n",
" batch_size=config[\"batch_size\"],\n",
" train_split=config[\"train_split\"],\n",
" num_workers=config[\"num_workers\"],\n",
" image_size=config[\"image_size\"],\n",
" )\n",
" logger.info(f\"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}\")\n",
" return train_loader, val_loader\n",
"\n",
"\n",
"def create_model():\n",
" logger.info(\"Creating model...\")\n",
" model = HomographyCNN6(\n",
" input_channels=3,\n",
" backbone_name=config[\"backbone\"],\n",
" pretrained=True,\n",
" dropout_rate=config[\"dropout_rate\"]\n",
" )\n",
" logger.info(f\"Model created with {count_parameters(model):,} parameters\")\n",
" return model\n",
"\n",
"\n",
"def train_model(model, train_loader, val_loader):\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" logger.info(f\"Using device: {device}\")\n",
" \n",
" trainer = HomographyTrainer(model, train_loader, val_loader, device)\n",
" logger.info(\"Starting training...\")\n",
" trainer.train(config[\"epochs\"])\n",
" logger.info(\"Training completed\")\n",
" return trainer\n",
"\n",
"\n",
"def analyze_model(trainer):\n",
" logger.info(\"Analyzing model...\")\n",
" results = analyze_training(trainer)\n",
" logger.info(f\"Analysis complete: best_val_loss={results['best_val_loss']:.4f}\")\n",
" return results\n",
"\n",
"\n",
"def main():\n",
" logger.info(\"=\" * 50)\n",
" logger.info(\"SiaN Training Pipeline\")\n",
" logger.info(\"=\" * 50)\n",
" \n",
" dataset_info = get_dataset_info()\n",
" logger.info(f\"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}\")\n",
" \n",
" train_loader, val_loader = create_dataset()\n",
" model = create_model()\n",
" trainer = train_model(model, train_loader, val_loader)\n",
" results = analyze_model(trainer)\n",
" \n",
" logger.info(\"=\" * 50)\n",
" logger.info(\"Pipeline completed successfully\")\n",
" logger.info(\"=\" * 50)\n",
" \n",
" return trainer, results\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "Run main() to execute the full pipeline.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if __name__ == \"__main__\":\n",
" main()\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,6 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def analyze_training(trainer):
@@ -51,8 +50,3 @@ def analyze_training(trainer):
plt.show()
return {"best_val_loss": trainer.best_val_loss}
if __name__ == "__main__":
from train import trainer
analyze_training(trainer)

View File

@@ -9,7 +9,7 @@ from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
from .utils import config, get_camera_matrix, generate_random_homography_params, homography_params_to_matrix, matrix_to_homography_params
class YaGoDataset(Dataset):
@@ -84,9 +84,10 @@ def create_data_loaders(root_dir, batch_size=32, train_split=0.8, num_workers=0,
DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True))
if __name__ == "__main__":
def get_dataset_info():
ds = YaGoDataset(config["data_dir"], augment=True, image_size=config["image_size"])
print(f"Dataset size: {len(ds)}")
s = ds[0]
print(f"Keys: {list(s.keys())}")
print(f"Params: {s['homography_params'].numpy()}")
return {
"size": len(ds),
"sample_keys": list(ds[0].keys()),
"sample_params": ds[0]["homography_params"].numpy()
}

80
models/SiaN/src/main.py Normal file
View File

@@ -0,0 +1,80 @@
import logging
import torch
from .dataloader import create_data_loaders, get_dataset_info
from .model import HomographyCNN6, count_parameters
from .train import HomographyTrainer
from .analyze import analyze_training
from .utils import config
# Configure root logging at import time: INFO level, "<timestamp> - <message>" lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
# Module-level logger used by the pipeline functions in this file.
logger = logging.getLogger(__name__)
def create_dataset():
    """Build the train and validation data loaders from ``config``.

    Returns:
        The ``(train_loader, val_loader)`` pair produced by
        ``create_data_loaders``.
    """
    logger.info("Creating data loaders...")
    loader_options = {
        "root_dir": config["data_dir"],
        "batch_size": config["batch_size"],
        "train_split": config["train_split"],
        "num_workers": config["num_workers"],
        "image_size": config["image_size"],
    }
    train_loader, val_loader = create_data_loaders(**loader_options)
    logger.info(f"Data loaders created: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}")
    return train_loader, val_loader
def create_model():
    """Instantiate ``HomographyCNN6`` using the backbone and dropout from ``config``.

    Returns:
        The newly constructed model (pretrained weights are always requested).
    """
    logger.info("Creating model...")
    model_options = {
        "input_channels": 3,
        "backbone_name": config["backbone"],
        "pretrained": True,
        "dropout_rate": config["dropout_rate"],
    }
    model = HomographyCNN6(**model_options)
    logger.info(f"Model created with {count_parameters(model):,} parameters")
    return model
def train_model(model, train_loader, val_loader):
    """Train ``model`` for ``config["epochs"]`` epochs and return the trainer.

    Args:
        model: Model handed to ``HomographyTrainer``.
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.

    Returns:
        The ``HomographyTrainer`` instance after training has finished.
    """
    # Prefer the GPU when one is available, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    logger.info(f"Using device: {device}")

    trainer = HomographyTrainer(model, train_loader, val_loader, device)

    logger.info("Starting training...")
    trainer.train(config["epochs"])
    logger.info("Training completed")
    return trainer
def analyze_model(trainer):
    """Run ``analyze_training`` on a finished trainer and log its best validation loss.

    Args:
        trainer: Trainer passed through to ``analyze_training``.

    Returns:
        The result mapping returned by ``analyze_training`` (its
        ``best_val_loss`` entry is logged).
    """
    logger.info("Analyzing model...")

    results = analyze_training(trainer)

    logger.info(f"Analysis complete: best_val_loss={results['best_val_loss']:.4f}")
    return results
def main():
    """Run the full pipeline: dataset info, loaders, model, training, analysis.

    Returns:
        ``(trainer, results)`` — the trained ``HomographyTrainer`` and the
        analysis results.
    """
    banner = "=" * 50

    logger.info(banner)
    logger.info("SiaN Training Pipeline")
    logger.info(banner)

    dataset_info = get_dataset_info()
    logger.info(f"Dataset: {dataset_info['size']} samples, keys={dataset_info['sample_keys']}")

    train_loader, val_loader = create_dataset()
    trainer = train_model(create_model(), train_loader, val_loader)
    results = analyze_model(trainer)

    logger.info(banner)
    logger.info("Pipeline completed successfully")
    logger.info(banner)
    return trainer, results
# Script entry point: run the whole pipeline when executed directly.
# NOTE(review): this module uses relative imports (from .dataloader, .model, ...),
# so it presumably has to be started as a package module (python -m ...) — confirm.
if __name__ == "__main__":
    main()

View File

@@ -37,9 +37,5 @@ class HomographyLoss6(nn.Module):
return self.criterion(pred, target)
if __name__ == "__main__":
model = HomographyCNN6()
img1 = torch.randn(2, 3, 256, 256)
img2 = torch.randn(2, 3, 256, 256)
out = model(img1, img2)
print(f"Output shape: {out.shape}, mean: {out.mean():.3f}")
def count_parameters(model):
return sum(p.numel() for p in model.parameters())

View File

@@ -3,12 +3,13 @@ import os
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from model import HomographyCNN6, HomographyLoss6
from utils import config
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from .dataloader import create_data_loaders
from .model import HomographyCNN6, HomographyLoss6, count_parameters
from .utils import config
class HomographyTrainer:
def __init__(self, model, train_loader, val_loader, device):
@@ -80,26 +81,3 @@ class HomographyTrainer:
torch.save(ckpt, os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch}.pt"))
if is_best:
torch.save(ckpt, os.path.join(ckpt_dir, "best_model.pt"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
)
model = HomographyCNN6(
input_channels=3,
backbone_name=config["backbone"],
pretrained=True,
dropout_rate=config["dropout_rate"]
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
trainer = HomographyTrainer(model, train_loader, val_loader, device)
trainer.train(config["epochs"])