# 48 lines, 1000 B, Python
# _schema.py
|
|
|
|
# === IMPORTS ===
|
|
import os
|
|
import random
|
|
import logging
|
|
from typing import Tuple
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import matplotlib.pyplot as plt
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader, Dataset, Subset
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from torchvision import transforms, models
|
|
from tqdm import tqdm
|
|
|
|
# code: ./src/utils.py
|
|
# markdown
|
|
"""# SiaN Model"""
|
|
|
|
# code: ./src/dataloader.py
|
|
# markdown
|
|
"""Dataset for Google/Yandex image pairs with homography augmentation."""
|
|
|
|
# code: ./src/model.py
|
|
# markdown
|
|
"""HomographyCNN6 predicts 6 params: rx, ry, rz, tx, ty, scale."""
|
|
|
|
# code: ./src/train.py
|
|
# markdown
|
|
"""HomographyTrainer manages the training loop, including validation."""
|
|
|
|
# code: ./src/analyze.py
|
|
# markdown
|
|
"""Visualization and analysis of predictions."""
|
|
|
|
# code: ./src/main.py
|
|
# markdown
|
|
"""Run main() to execute the full pipeline."""
|
|
|
|
# inline:
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    # NOTE(review): `main` is not defined anywhere in the visible source;
    # presumably it comes from the elided ./src/main.py section. Confirm it
    # is defined (or imported) above this guard, otherwise this raises NameError.
    main()
|