110 lines
2.5 KiB
Python
110 lines
2.5 KiB
Python
# _schema.py
|
|
|
|
# === IMPORTS ===
|
|
import os
|
|
import random
|
|
import logging
|
|
from typing import Tuple
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import matplotlib.pyplot as plt
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader, Dataset, Subset
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from torchvision import transforms, models
|
|
from tqdm import tqdm
|
|
|
|
# markdown
|
|
"""# Configuration
|
|
|
|
Global settings for:
|
|
- Data paths and image parameters
|
|
- Training hyperparameters
|
|
- Model architecture options
|
|
|
|
Contains the `config` dictionary used across all modules."""
|
|
# code: ./src/utils.py
|
|
|
|
# markdown
|
|
"""## Dataset
|
|
|
|
Loader for Google/Yandex image pairs with random homography augmentation.
|
|
|
|
**Features:**
|
|
- Loads image pairs from the two sources (Google and Yandex)
|
|
- Applies random homography transformations
|
|
- Supports configurable train/val split
|
|
|
|
**Returns:**
|
|
- A batch dict with keys `google_img`, `yandex_img`, and `homography_params`"""
|
|
# code: ./src/dataloader.py
|
|
|
|
# code: ./src/test_dataloader.py
|
|
|
|
|
|
# markdown
|
|
"""## Model
|
|
|
|
`HomographyCNN6` — CNN architecture for homography estimation.
|
|
|
|
**Output:** 6 parameters
|
|
- `rx, ry, rz` — rotation angles (radians)
|
|
- `tx, ty` — translation offsets
|
|
- `scale` — isotropic scale factor
|
|
|
|
**Architecture:**
|
|
- Dual-branch CNN (Google + Yandex images)
|
|
- Shared backbone (configurable: resnet18/34/50)
|
|
- Fusion head with dropout regularization"""
|
|
# code: ./src/model.py
|
|
|
|
# markdown
|
|
"""## Training
|
|
|
|
`HomographyTrainer` — training loop with validation and checkpointing.
|
|
|
|
**Features:**
|
|
- Epoch-based training with a tqdm progress bar
|
|
- Adam optimizer with a configurable learning rate
|
|
- Validation after each epoch
|
|
- Automatic saving of the best model
|
|
- Periodic checkpoints (every N epochs via `save_every_n_epochs`)
|
|
|
|
**Checkpoint saving:**
|
|
- `best_model.pt` — lowest validation loss
|
|
- `checkpoint_epoch_N.pt` — periodic saves"""
|
|
# code: ./src/train.py
|
|
|
|
# markdown
|
|
"""## Analysis
|
|
|
|
Visualization and evaluation tools:
|
|
|
|
- Training metrics plots (loss curves)
|
|
- Prediction visualization on sample images
|
|
- Error analysis and statistics"""
|
|
# code: ./src/analyze.py
|
|
|
|
# markdown
|
|
"""## Main Pipeline
|
|
|
|
Executes the full training workflow:
|
|
1. Load dataset info
|
|
2. Create data loaders
|
|
3. Initialize model
|
|
4. Train with validation
|
|
5. Analyze and export results
|
|
|
|
**Outputs:**
|
|
- Model checkpoints in `runs/checkpoints/`
|
|
- TensorBoard logs in `runs/`
|
|
- Analysis plots"""
|
|
# code: ./src/main.py
|
|
|
|
# # shell:
|
|
# !zip -r artefacts.zip runs/checkpoints/best_model.pt runs/images/ runs/events.*
|