autopilot/models/SiaN/_schema.py
2026-04-04 22:57:41 +03:00

# _schema.py
# === IMPORTS ===
import os
import random
import logging
from typing import Tuple
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from tqdm import tqdm
# code: ./src/utils.py
# markdown
"""# Configuration
Global settings for:
- Data paths and image parameters
- Training hyperparameters
- Model architecture options
Contains the `config` dictionary used across all modules."""
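# markdown
"""A minimal sketch of what such a `config` dictionary might look like. Only `save_every_n_epochs` and the three backbone names appear elsewhere in this file; every other key, value, and the `validate_config` helper are hypothetical placeholders, not the project's actual settings.

```python
# Hypothetical configuration: keys and values are illustrative assumptions.
ALLOWED_BACKBONES = ("resnet18", "resnet34", "resnet50")

config = {
    # Data paths and image parameters
    "data_dir": "data/pairs",      # assumed location of the image pairs
    "image_size": (256, 256),      # assumed (height, width)
    "val_split": 0.2,              # assumed fraction held out for validation
    # Training hyperparameters
    "batch_size": 16,
    "learning_rate": 1e-4,
    "num_epochs": 50,
    "save_every_n_epochs": 5,      # option named in the Training section
    # Model architecture options
    "backbone": "resnet18",
    "dropout": 0.3,
}


def validate_config(cfg):
    # Fail early on values the rest of the pipeline could not use.
    if cfg["backbone"] not in ALLOWED_BACKBONES:
        raise ValueError("unsupported backbone: " + str(cfg["backbone"]))
    if not 0.0 < cfg["val_split"] < 1.0:
        raise ValueError("val_split must be strictly between 0 and 1")
    if cfg["save_every_n_epochs"] < 1:
        raise ValueError("save_every_n_epochs must be at least 1")
    return cfg


validate_config(config)
```

Validating once at import time keeps a mistyped option from surfacing only after training has started.
"""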
# code: ./src/dataloader.py
# markdown
"""## Dataset
Google/Yandex image pair loader with homography augmentation.
**Features:**
- Loads paired images from the two sources (Google and Yandex)
- Applies random homography transformations
- Supports configurable train/val split
**Returns:**
- Batch dict with `google_img`, `yandex_img`, `homography_params`"""
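# markdown
"""The augmentation and split described above could be sketched as follows. The sampling ranges, the `PARAM_ORDER` tuple, and both function names are assumptions made for illustration; the real loader may sample and split differently.

```python
import random

# Hypothetical sampling ranges; the actual augmentation limits are not given.
PARAM_RANGES = {
    "rx": (-0.1, 0.1),      # radians
    "ry": (-0.1, 0.1),      # radians
    "rz": (-0.3, 0.3),      # radians
    "tx": (-0.1, 0.1),      # assumed to be a fraction of the image width
    "ty": (-0.1, 0.1),      # assumed to be a fraction of the image height
    "scale": (0.9, 1.1),
}
PARAM_ORDER = ("rx", "ry", "rz", "tx", "ty", "scale")


def sample_homography_params(rng):
    # Draw one value per parameter, uniformly inside its range.
    return [rng.uniform(*PARAM_RANGES[name]) for name in PARAM_ORDER]


def split_indices(n_items, val_fraction, seed=0):
    # Shuffle indices reproducibly, then cut off the validation share.
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    n_val = int(round(n_items * val_fraction))
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(10, val_fraction=0.2, seed=1)
params = sample_homography_params(random.Random(0))
```

The two index lists can be passed to `torch.utils.data.Subset`, which this file already imports, to build the train and validation datasets.
"""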
# code: ./src/model.py
# markdown
"""## Model
`HomographyCNN6` — CNN architecture for homography estimation.
**Output:** 6 parameters
- `rx, ry, rz` — rotation angles (radians)
- `tx, ty` — translation offsets
- `scale` — isotropic scale factor
**Architecture:**
- Dual-branch CNN (Google + Yandex images)
- Shared backbone (configurable: resnet18/34/50)
- Fusion head with dropout regularization"""
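# markdown
"""One way the six outputs could be turned into a 3x3 transform is sketched below. It assumes normalized image coordinates (camera intrinsics equal to the identity), the rotation order `Rz * Ry * Rx`, and a similarity (scale plus translation) applied after the rotation. None of these conventions is confirmed by this file, and `params_to_matrix` and `warp_point` are hypothetical names.

```python
import math


def matmul3(a, b):
    # Plain 3x3 matrix product on nested lists.
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def params_to_matrix(rx, ry, rz, tx, ty, scale):
    # Assumed convention: H = S(scale, tx, ty) * Rz * Ry * Rx.
    cx, sx = math.cos(rx), math.sin(rx)
    cy, sy = math.cos(ry), math.sin(ry)
    cz, sz = math.cos(rz), math.sin(rz)
    rot_x = [[1, 0, 0], [0, cx, -sx], [0, sx, cx]]
    rot_y = [[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]]
    rot_z = [[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]]
    rotation = matmul3(rot_z, matmul3(rot_y, rot_x))
    similarity = [[scale, 0, tx], [0, scale, ty], [0, 0, 1]]
    return matmul3(similarity, rotation)


def warp_point(h, x, y):
    # Apply the transform to (x, y) and divide by the homogeneous coordinate.
    u = h[0][0] * x + h[0][1] * y + h[0][2]
    v = h[1][0] * x + h[1][1] * y + h[1][2]
    w = h[2][0] * x + h[2][1] * y + h[2][2]
    return u / w, v / w


# Scale by 2, then shift by (2, 3): the point (1, 1) lands on (4, 5).
moved = warp_point(params_to_matrix(0, 0, 0, 2, 3, 2), 1, 1)
```

A general homography has eight degrees of freedom, so a six-parameter output like this one covers only a constrained family of transforms.
"""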
# code: ./src/train.py
# markdown
"""## Training
`HomographyTrainer` — training loop with validation and checkpointing.
**Features:**
- Epoch-based training with tqdm progress bar
- Adam optimizer with a configurable learning rate
- Validation after each epoch
- Automatic saving of the best model (lowest validation loss)
- Periodic checkpoints (every N epochs via `save_every_n_epochs`)
**Checkpoint saving:**
- `best_model.pt` — lowest validation loss
- `checkpoint_epoch_N.pt` — periodic saves"""
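# markdown
"""The checkpoint rules above can be sketched as a small decision function. The helper name `checkpoints_to_save`, the 1-based epoch numbering, and the sample losses are assumptions; the file names and `save_every_n_epochs` come from the description above.

```python
def checkpoints_to_save(epoch, val_loss, best_val_loss, save_every_n_epochs):
    # Returns the file names to write this epoch and the updated best loss.
    files = []
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        files.append("best_model.pt")
    if save_every_n_epochs > 0 and epoch % save_every_n_epochs == 0:
        files.append(f"checkpoint_epoch_{epoch}.pt")
    return files, best_val_loss


# Simulated validation losses for four epochs (made-up numbers).
val_losses = [0.9, 0.7, 0.8, 0.6]
best = float("inf")
history = []
for epoch, loss in enumerate(val_losses, start=1):
    files, best = checkpoints_to_save(epoch, loss, best, save_every_n_epochs=2)
    history.append(files)
```

In a real trainer, each returned name would be handed to `torch.save` together with the model and optimizer state.
"""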
# code: ./src/analyze.py
# markdown
"""## Analysis
Visualization and evaluation tools:
- Training metrics plots (loss curves)
- Prediction visualization on sample images
- Error analysis and statistics"""
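# markdown
"""As one possible form of the error statistics, the sketch below computes the mean absolute error per output parameter. The choice of metric, the helper name `per_parameter_mae`, and the sample values are assumptions; the analysis module may report different statistics.

```python
PARAM_NAMES = ("rx", "ry", "rz", "tx", "ty", "scale")


def per_parameter_mae(predictions, targets):
    # predictions and targets are equal-length lists of 6-value rows.
    if not predictions or len(predictions) != len(targets):
        raise ValueError("need two non-empty lists of equal length")
    totals = [0.0] * len(PARAM_NAMES)
    for pred, target in zip(predictions, targets):
        for i, (p, t) in enumerate(zip(pred, target)):
            totals[i] += abs(p - t)
    return {name: total / len(predictions)
            for name, total in zip(PARAM_NAMES, totals)}


# Two made-up samples: one exact, one off in rx and scale.
preds = [[0, 0, 0, 0, 0, 1], [0.2, 0, 0, 0, 0, 1]]
truth = [[0, 0, 0, 0, 0, 1], [0.0, 0, 0, 0, 0, 2]]
mae = per_parameter_mae(preds, truth)
```

Reporting the error per parameter keeps angles (radians), translations, and scale apart, since averaging across different units would not be meaningful.
"""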
# code: ./src/main.py
# markdown
"""## Main Pipeline
Executes the full training workflow:
1. Load dataset info
2. Create data loaders
3. Initialize model
4. Train with validation
5. Analyze and export results
**Outputs:**
- Model checkpoints in `runs/checkpoints/`
- TensorBoard logs in `runs/`
- Analysis plots"""
# # shell:
# !zip artefacts.zip runs/gan_training/checkpoints/best_model.pt