write SiaN

This commit is contained in:
2026-04-04 19:41:16 +03:00
parent 702c53caac
commit 4b398f6c9a
11 changed files with 737 additions and 3150 deletions

View File

@@ -1,434 +0,0 @@
import os
import random
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
class HomographyDataset(Dataset):
    """
    Dataset of paired Google/Yandex map images with a 3x3 homography label.

    Pairs are discovered in ``root_dir``: every ``<int>_google.png`` file whose
    companion ``{idx:04d}_yandex.png`` exists becomes one sample.

    NOTE(review): the homography attached to a sample comes from
    ``generate_random_homography`` (or from the on-disk cache of a previous
    draw). Nothing in this class derives it from the images or warps an image
    with it -- only the *augmentation* homography is applied to the pixels.
    Confirm that this is the intended ground truth.
    """

    def __init__(
        self,
        root_dir: str,
        transform=None,
        augment: bool = True,
        max_samples: Optional[int] = None,
        image_size: Tuple[int, int] = (700, 700),
        cache_homographies: bool = True,
    ):
        """
        Initialize the HomographyDataset.

        Args:
            root_dir: Directory containing the image pairs.
            transform: Optional callable applied to each PIL image; when it is
                falsy the images are converted to float CHW tensors in [0, 1].
            augment: Whether ``__getitem__`` warps both images with a random
                homography and composes it onto the label.
            max_samples: Keep only the first ``max_samples`` pairs
                (None keeps all of them).
            image_size: Target size as (height, width).
            cache_homographies: Whether per-sample homographies are stored in
                and re-read from ``<root_dir>/homography_cache``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.cache_homographies = cache_homographies

        # Find all image pairs
        self.image_pairs = self._discover_image_pairs()
        if max_samples is not None:
            self.image_pairs = self.image_pairs[:max_samples]
        print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")

        # Create directory for cached homographies if needed
        if cache_homographies:
            self.homography_cache_dir = os.path.join(root_dir, "homography_cache")
            os.makedirs(self.homography_cache_dir, exist_ok=True)

    def _discover_image_pairs(self) -> List[Dict[str, Any]]:
        """Return one dict (idx, google_path, yandex_path) per complete pair."""
        image_pairs = []
        google_files = sorted(
            f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
        )
        for google_file in google_files:
            # The numeric prefix before the first underscore is the pair index;
            # files whose prefix is not an integer are ignored.
            try:
                idx = int(google_file.split("_")[0])
            except ValueError:
                continue

            # The Yandex companion is always looked up with a 4-digit index.
            yandex_path = os.path.join(self.root_dir, f"{idx:04d}_yandex.png")
            if os.path.exists(yandex_path):
                image_pairs.append(
                    {
                        "idx": idx,
                        "google_path": os.path.join(self.root_dir, google_file),
                        "yandex_path": yandex_path,
                    }
                )
        return image_pairs

    def __len__(self) -> int:
        """Return the number of image pairs in the dataset."""
        return len(self.image_pairs)

    def _load_pair(self, pair_info: Dict[str, Any]) -> Tuple["Image.Image", "Image.Image"]:
        """Load both images of a pair as RGB and resize them to ``image_size``."""
        # PIL expects (width, height) while image_size is (height, width).
        target = (self.image_size[1], self.image_size[0])
        loaded = []
        for key in ("google_path", "yandex_path"):
            # The context manager releases the file handle as soon as the
            # pixel data has been copied by convert().
            with Image.open(pair_info[key]) as raw:
                rgb = raw.convert("RGB")
            loaded.append(rgb.resize(target, Image.BILINEAR))
        return loaded[0], loaded[1]

    @staticmethod
    def _to_unit_tensor(img: "Image.Image") -> torch.Tensor:
        """Convert an HWC uint8 PIL image to a float CHW tensor in [0, 1]."""
        return torch.from_numpy(np.array(img)).float().permute(2, 0, 1) / 255.0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a sample from the dataset.

        Returns a dictionary with:
            - 'google_img': Google map image tensor
            - 'yandex_img': Yandex map image tensor
            - 'homography': homography label as a float 3x3 tensor
            - 'idx': index parsed from the file name (long tensor)
        """
        pair_info = self.image_pairs[idx]
        google_img, yandex_img = self._load_pair(pair_info)

        # Get or generate homography matrix
        homography_matrix = self._get_homography_matrix(pair_info["idx"])

        # Apply data augmentation if enabled
        if self.augment:
            google_img, yandex_img, homography_matrix = self._apply_augmentation(
                google_img, yandex_img, homography_matrix
            )

        # Convert images to tensors
        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)
        else:
            google_img = self._to_unit_tensor(google_img)
            yandex_img = self._to_unit_tensor(yandex_img)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": torch.from_numpy(homography_matrix).float(),
            "idx": torch.tensor(pair_info["idx"], dtype=torch.long),
        }

    def _get_homography_matrix(self, idx: int) -> np.ndarray:
        """
        Return the homography for pair ``idx``.

        With caching enabled, a previously saved matrix is loaded from disk;
        otherwise a new random matrix is generated (and saved when caching is
        enabled). Without caching every call draws a new matrix.
        """
        cache_path = None
        if self.cache_homographies:
            cache_path = os.path.join(
                self.homography_cache_dir, f"{idx:04d}_homography.npy"
            )
            if os.path.exists(cache_path):
                return np.load(cache_path)

        homography_matrix = self.generate_random_homography()
        if cache_path is not None:
            np.save(cache_path, homography_matrix)
        return homography_matrix

    def generate_random_homography(self) -> np.ndarray:
        """
        Generate a random 3x3 transformation matrix.

        A similarity transform (rotation in [-30, 30] degrees, scale in
        [0.8, 1.2], translation in [-50, 50] per axis) is perturbed with
        uniform noise in [-0.001, 0.001].

        NOTE(review): the noise is added to the top two rows only, so the
        bottom row stays [0, 0, 1] and the result is affine, not projective.
        Left unchanged because a suitable magnitude for real projective terms
        is not derivable from this file.

        Returns:
            np.ndarray: 3x3 matrix.
        """
        angle = np.random.uniform(-30, 30)  # rotation in degrees
        scale = np.random.uniform(0.8, 1.2)  # scaling factor
        tx = np.random.uniform(-50, 50)  # translation in x
        ty = np.random.uniform(-50, 50)  # translation in y

        theta = np.radians(angle)
        cos_t = scale * np.cos(theta)
        sin_t = scale * np.sin(theta)
        affine_matrix = np.array(
            [
                [cos_t, -sin_t, tx],
                [sin_t, cos_t, ty],
                [0, 0, 1],
            ]
        )

        # Small perturbation of the first two rows (third row is left as is).
        perspective = np.random.uniform(-0.001, 0.001, (2, 3))
        perspective = np.vstack([perspective, [0, 0, 0]])
        return affine_matrix + perspective

    def _apply_augmentation(
        self,
        google_img: "Image.Image",
        yandex_img: "Image.Image",
        base_homography: np.ndarray,
    ) -> Tuple["Image.Image", "Image.Image", np.ndarray]:
        """
        Warp both images with one random homography and update the label.

        Args:
            google_img: Google map image
            yandex_img: Yandex map image
            base_homography: Label before augmentation

        Returns:
            Tuple of (augmented_google_img, augmented_yandex_img,
            aug_homography @ base_homography)
        """
        aug_homography = self.generate_random_homography()
        combined_homography = aug_homography @ base_homography
        # The same warp is applied to both images.
        google_aug = self._apply_homography_to_image(google_img, aug_homography)
        yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography)
        return google_aug, yandex_aug, combined_homography

    def _apply_homography_to_image(
        self, img: "Image.Image", homography: np.ndarray
    ) -> "Image.Image":
        """
        Warp a PIL image with a 3x3 homography, keeping its size.

        Pixels mapped from outside the source are filled by reflection.
        """
        img_np = np.array(img)
        h, w = img_np.shape[:2]
        transformed = cv2.warpPerspective(
            img_np,
            homography,
            (w, h),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT,
        )
        return Image.fromarray(transformed)

    def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
        """
        Get a sample without augmentation or tensor conversion.

        Returns resized PIL images, the homography as a NumPy array, the
        file-name index and both file paths.
        """
        pair_info = self.image_pairs[idx]
        google_img, yandex_img = self._load_pair(pair_info)
        homography_matrix = self._get_homography_matrix(pair_info["idx"])
        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": homography_matrix,
            "idx": pair_info["idx"],
            "google_path": pair_info["google_path"],
            "yandex_path": pair_info["yandex_path"],
        }
def _seed_worker(worker_id: int) -> None:
    """Reseed NumPy and ``random`` in a DataLoader worker from torch's per-worker seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class _AugmentedSubset(torch.utils.data.Subset):
    """
    Subset that composes a fresh random homography onto each sample's label.

    Defined at module level (not inside ``create_data_loaders``) so that it
    can be pickled when DataLoader workers are started with ``spawn``.

    NOTE(review): only the label is changed; the image tensors are returned
    untouched, so the new label is not reflected in the pixels.
    """

    def __getitem__(self, idx):
        sample = self.dataset[self.indices[idx]]
        aug_homography = torch.from_numpy(
            self.dataset.generate_random_homography()
        ).float()
        return {
            "google_img": sample["google_img"],
            "yandex_img": sample["yandex_img"],
            "homography": aug_homography @ sample["homography"],
            "idx": sample["idx"],
        }


def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 4,
    image_size: Tuple[int, int] = (256, 256),
    augment_train: bool = True,
    augment_val: bool = False,
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders for homography estimation.

    Args:
        root_dir: Directory containing image pairs
        batch_size: Batch size for data loaders
        train_split: Fraction of data to use for training
        num_workers: Number of worker processes for data loading
        image_size: Target image size (height, width)
        augment_train: Whether to randomise the training labels
        augment_val: Whether to randomise the validation labels
        seed: Optional seed for the train/validation shuffle; ``None`` uses
            the global ``random`` state, as before.

    Returns:
        Tuple of (train_loader, val_loader)

    Raises:
        ValueError: If no image pairs are found in ``root_dir``.
    """
    from torch.utils.data import Subset
    from torchvision import transforms

    # ImageNet mean/std normalisation on top of ToTensor.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # The dataset itself never augments; that is done by _AugmentedSubset.
    full_dataset = HomographyDataset(
        root_dir=root_dir,
        transform=transform,
        augment=False,
        image_size=image_size,
        cache_homographies=True,
    )

    dataset_size = len(full_dataset)
    if dataset_size == 0:
        raise ValueError(f"No image pairs found in {root_dir}")
    train_size = int(train_split * dataset_size)

    # Random split of the sample indices.
    indices = list(range(dataset_size))
    if seed is None:
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    train_cls = _AugmentedSubset if augment_train else Subset
    val_cls = _AugmentedSubset if augment_val else Subset
    train_dataset = train_cls(full_dataset, train_indices)
    val_dataset = val_cls(full_dataset, val_indices)

    # Pinned memory is only useful (and only warning-free) when CUDA exists.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=_seed_worker,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=_seed_worker,
    )
    return train_loader, val_loader
if __name__ == "__main__":
    # Smoke-run: build the dataset, inspect one sample, then build the loaders.
    dataset = HomographyDataset(
        augment=True,
        image_size=(256, 256),
        root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    )
    print(f"Dataset size: {len(dataset)}")

    # First sample: report its keys and tensor shapes.
    sample = dataset[0]
    print(f"Sample keys: {list(sample.keys())}")
    google_shape = sample["google_img"].shape
    print(f"Google image shape: {google_shape}")
    yandex_shape = sample["yandex_img"].shape
    print(f"Yandex image shape: {yandex_shape}")
    homography_shape = sample["homography"].shape
    print(f"Homography shape: {homography_shape}")

    # Train/validation loaders over the same directory.
    train_loader, val_loader = create_data_loaders(
        train_split=0.8,
        batch_size=16,
        root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

152
models/SiaN/model.py Normal file
View File

@@ -0,0 +1,152 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class HomographyCNN(nn.Module):
    """
    Model for estimating homography matrix (3x3) between two images.

    Both images go through the same ResNet backbone (its classification layer
    replaced by an identity); the two feature vectors, their absolute
    difference and their element-wise product are concatenated and regressed
    to nine values by a small MLP head.
    """

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        """
        Args:
            input_channels: Channels of the input images; anything other than
                3 replaces the backbone's first convolution with a new one.
            backbone_name: "resnet18" or "resnet34" (case-insensitive).
            pretrained: Load IMAGENET1K_V1 weights for the backbone.
            dropout_rate: Dropout probability used in the head.
            use_batch_norm: Insert BatchNorm1d after the hidden linear layers.

        Raises:
            ValueError: If ``backbone_name`` is not supported.
        """
        super().__init__()
        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        backbone = self._create_backbone(backbone_name, pretrained)
        # Remember the feature width before dropping the classifier.
        self.feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # f1, f2, |f1 - f2| and f1 * f2 are concatenated -> 4x feature width.
        compare_input_dim = self.feature_dim * 4
        layers = [
            nn.Linear(compare_input_dim, 512),
            nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 9),
        ]
        self.head = nn.Sequential(*layers)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """Build the torchvision ResNet and adapt its first conv if needed."""
        name = name.lower()
        if name == "resnet18":
            model = models.resnet18(
                weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            )
        elif name == "resnet34":
            model = models.resnet34(
                weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            )
        else:
            raise ValueError(f"Unsupported backbone: {name}")

        if self.input_channels != 3:
            # A freshly initialised conv with the same geometry; any
            # pretrained weights of the original first layer are not reused.
            old_conv = model.conv1
            model.conv1 = nn.Conv2d(
                self.input_channels,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None,
            )
        return model

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run one image batch through the shared backbone."""
        return self.backbone(x)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Return the head output for the image pair, viewed as (-1, 3, 3)."""
        f1 = self._extract_features(img1)
        f2 = self._extract_features(img2)
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        combined = torch.cat([f1, f2, diff, prod], dim=1)
        h = self.head(combined)
        return h.view(-1, 3, 3)

    def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Run a gradient-free forward pass in eval mode.

        The previous training mode is restored afterwards, including when
        the forward pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            if was_training:
                self.train()
class HomographyLoss(nn.Module):
    """Mean-squared-error loss between predicted and target homographies."""

    def __init__(self):
        super().__init__()
        # Element-wise MSE with the default mean reduction.
        self.criterion = nn.MSELoss()

    def forward(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``self.criterion`` applied to the prediction and the target."""
        loss = self.criterion(pred_homography, target_homography)
        return loss
def create_homography_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Build a homography model by type name.

    Args:
        model_type: Only "backbone" is recognised; it yields a HomographyCNN.
        input_size: Accepted but not used by this function.
        **kwargs: Forwarded unchanged to the HomographyCNN constructor.

    Raises:
        ValueError: If ``model_type`` is not "backbone".
    """
    # Reject unknown types first, then build the only supported model.
    if model_type != "backbone":
        raise ValueError(f"Unknown model type: {model_type}")
    return HomographyCNN(**kwargs)
if __name__ == "__main__":
    # Smoke test: build the model and push one random batch through it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = HomographyCNN(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    model = model.to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

    batch_size = 4
    height, width = 256, 256
    # Two independent random image batches, created on CPU and then moved.
    img1 = torch.randn(batch_size, 3, height, width)
    img1 = img1.to(device)
    img2 = torch.randn(batch_size, 3, height, width)
    img2 = img2.to(device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")

    print("\nTesting prediction...")
    pred = model.predict_homography(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    # Identity matrices serve as the dummy target.
    identity = torch.eye(3).unsqueeze(0)
    target = identity.expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nAll tests completed successfully!")

View File

@@ -417,7 +417,379 @@
"id": "2dad9a5f",
"metadata": {},
"outputs": [],
"source": []
"source": [
"from typing import Tuple\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchvision import models\n",
"\n",
"\n",
"class HomographyCNN(nn.Module):\n",
" \"\"\"\n",
" Model for estimating homography matrix (3x3) between two images.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" input_channels: int = 3,\n",
" backbone_name: str = \"resnet18\",\n",
" pretrained: bool = True,\n",
" dropout_rate: float = 0.3,\n",
" use_batch_norm: bool = True,\n",
" ):\n",
" super().__init__()\n",
"\n",
" self.input_channels = input_channels\n",
" self.backbone_name = backbone_name\n",
" self.pretrained = pretrained\n",
" self.dropout_rate = dropout_rate\n",
" self.use_batch_norm = use_batch_norm\n",
"\n",
" backbone = self._create_backbone(backbone_name, pretrained)\n",
"\n",
" self.feature_dim = backbone.fc.in_features\n",
" backbone.fc = nn.Identity()\n",
" self.backbone = backbone\n",
"\n",
" compare_input_dim = self.feature_dim * 4\n",
"\n",
" layers = [\n",
" nn.Linear(compare_input_dim, 512),\n",
" nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
"\n",
" nn.Linear(512, 256),\n",
" nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(dropout_rate),\n",
"\n",
" nn.Linear(256, 9),\n",
" ]\n",
" self.head = nn.Sequential(*layers)\n",
"\n",
" def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:\n",
" name = name.lower()\n",
" if name == \"resnet18\":\n",
" model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n",
" elif name == \"resnet34\":\n",
" model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)\n",
" else:\n",
" raise ValueError(f\"Unsupported backbone: {name}\")\n",
" if self.input_channels != 3:\n",
" old_conv = model.conv1\n",
" model.conv1 = nn.Conv2d(\n",
" self.input_channels,\n",
" old_conv.out_channels,\n",
" kernel_size=old_conv.kernel_size,\n",
" stride=old_conv.stride,\n",
" padding=old_conv.padding,\n",
" bias=old_conv.bias is not None,\n",
" )\n",
" return model\n",
"\n",
" def _extract_features(self, x: torch.Tensor) -> torch.Tensor:\n",
" return self.backbone(x)\n",
"\n",
" def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:\n",
" f1 = self._extract_features(img1)\n",
" f2 = self._extract_features(img2)\n",
"\n",
" diff = torch.abs(f1 - f2)\n",
" prod = f1 * f2\n",
" combined = torch.cat([f1, f2, diff, prod], dim=1)\n",
"\n",
" h = self.head(combined)\n",
" h = h.view(-1, 3, 3)\n",
" return h\n",
"\n",
" def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:\n",
" was_training = self.training\n",
" self.eval()\n",
" with torch.no_grad():\n",
" h = self.forward(img1, img2)\n",
" if was_training:\n",
" self.train()\n",
" return h\n",
"\n",
"\n",
"class HomographyLoss(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.criterion = nn.MSELoss()\n",
"\n",
" def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor:\n",
" return self.criterion(pred_homography, target_homography)\n",
"\n",
"\n",
"def create_homography_model(\n",
" model_type: str = \"backbone\",\n",
" input_size: Tuple[int, int] = (256, 256),\n",
" **kwargs,\n",
") -> nn.Module:\n",
" if model_type == \"backbone\":\n",
" return HomographyCNN(**kwargs)\n",
" else:\n",
" raise ValueError(f\"Unknown model type: {model_type}\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" print(f\"Using device: {device}\")\n",
"\n",
" model = HomographyCNN(\n",
" input_channels=3,\n",
" backbone_name=\"resnet18\",\n",
" pretrained=True,\n",
" dropout_rate=0.3,\n",
" use_batch_norm=True,\n",
" ).to(device)\n",
"\n",
" print(f\"Model created with {sum(p.numel() for p in model.parameters()):,} parameters\")\n",
"\n",
" batch_size = 4\n",
" height, width = 256, 256\n",
"\n",
" img1 = torch.randn(batch_size, 3, height, width).to(device)\n",
" img2 = torch.randn(batch_size, 3, height, width).to(device)\n",
"\n",
" print(\"\\nTesting forward pass...\")\n",
" output = model(img1, img2)\n",
" print(f\"Output shape: {output.shape}\")\n",
"\n",
" print(\"\\nTesting prediction...\")\n",
" pred = model.predict_homography(img1, img2)\n",
" print(f\"Prediction shape: {pred.shape}\")\n",
"\n",
" print(\"\\nTesting loss function...\")\n",
" target = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)\n",
" loss_fn = HomographyLoss().to(device)\n",
" loss = loss_fn(output, target)\n",
" print(f\"Loss value: {loss.item():.6f}\")\n",
"\n",
" print(\"\\nAll tests completed successfully!\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e573b201",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"from datetime import datetime\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"class HomographyTrainer:\n",
" def __init__(\n",
" self,\n",
" model: nn.Module,\n",
" train_loader: DataLoader,\n",
" val_loader: DataLoader,\n",
" device: torch.device,\n",
" config: dict,\n",
" ):\n",
" self.model = model.to(device)\n",
" self.train_loader = train_loader\n",
" self.val_loader = val_loader\n",
" self.device = device\n",
" self.config = config\n",
"\n",
" self.criterion = HomographyLoss()\n",
" self.optimizer = optim.Adam(\n",
" model.parameters(),\n",
" lr=config.get(\"learning_rate\", 2e-4),\n",
" betas=(config.get(\"beta1\", 0.5), config.get(\"beta2\", 0.999)),\n",
" )\n",
"\n",
" self.writer = None\n",
" self.best_val_loss = float(\"inf\")\n",
" self.epochs_without_improvement = 0\n",
"\n",
" def train_epoch(self, epoch: int) -> dict:\n",
" self.model.train()\n",
" total_loss = 0\n",
" total_samples = 0\n",
"\n",
" pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n",
" for batch_idx, batch in enumerate(pbar):\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
" target = batch[\"homography\"].to(self.device)\n",
"\n",
" self.optimizer.zero_grad()\n",
"\n",
" output = self.model(google_img, yandex_img)\n",
" loss = self.criterion(output, target)\n",
"\n",
" loss.backward()\n",
" self.optimizer.step()\n",
"\n",
" total_loss += loss.item() * google_img.size(0)\n",
" total_samples += google_img.size(0)\n",
"\n",
" if batch_idx % self.config.get(\"log_interval\", 10) == 0:\n",
" pbar.set_postfix({\"loss\": loss.item()})\n",
"\n",
" if self.writer:\n",
" self.writer.add_scalar(\n",
" \"train/loss\",\n",
" loss.item(),\n",
" epoch * len(self.train_loader) + batch_idx,\n",
" )\n",
"\n",
" avg_loss = total_loss / total_samples\n",
" return {\"loss\": avg_loss}\n",
"\n",
" def validate(self) -> dict:\n",
" self.model.eval()\n",
" total_loss = 0\n",
" total_samples = 0\n",
"\n",
" with torch.no_grad():\n",
" for batch in tqdm(self.val_loader, desc=\"Validation\"):\n",
" google_img = batch[\"google_img\"].to(self.device)\n",
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
" target = batch[\"homography\"].to(self.device)\n",
"\n",
" output = self.model(google_img, yandex_img)\n",
" loss = self.criterion(output, target)\n",
"\n",
" total_loss += loss.item() * google_img.size(0)\n",
" total_samples += google_img.size(0)\n",
"\n",
" avg_loss = total_loss / total_samples\n",
" return {\"loss\": avg_loss}\n",
"\n",
" def train(self, num_epochs: int):\n",
" log_dir = self.config.get(\"output_dir\", \"runs/homography\")\n",
" os.makedirs(log_dir, exist_ok=True)\n",
" self.writer = SummaryWriter(log_dir)\n",
"\n",
" print(f\"Starting training for {num_epochs} epochs\")\n",
" print(f\"Logging to: {log_dir}\")\n",
"\n",
" for epoch in range(1, num_epochs + 1):\n",
" print(f\"\\nEpoch {epoch}/{num_epochs}\")\n",
"\n",
" train_metrics = self.train_epoch(epoch)\n",
" val_metrics = self.validate()\n",
"\n",
" print(f\"Train Loss: {train_metrics['loss']:.4f}\")\n",
" print(f\"Val Loss: {val_metrics['loss']:.4f}\")\n",
"\n",
" if self.writer:\n",
" self.writer.add_scalar(\"epoch/train_loss\", train_metrics[\"loss\"], epoch)\n",
" self.writer.add_scalar(\"epoch/val_loss\", val_metrics[\"loss\"], epoch)\n",
"\n",
" if val_metrics[\"loss\"] < self.best_val_loss:\n",
" self.best_val_loss = val_metrics[\"loss\"]\n",
" self.epochs_without_improvement = 0\n",
" self.save_checkpoint(epoch, val_metrics[\"loss\"], is_best=True)\n",
" print(f\"New best model saved with val loss: {val_metrics['loss']:.4f}\")\n",
" else:\n",
" self.epochs_without_improvement += 1\n",
" self.save_checkpoint(epoch, val_metrics[\"loss\"], is_best=False)\n",
"\n",
" patience = self.config.get(\"early_stopping_patience\", 20)\n",
" if self.epochs_without_improvement >= patience:\n",
" print(f\"Early stopping triggered after {patience} epochs without improvement\")\n",
" break\n",
"\n",
" self.writer.close()\n",
"\n",
" def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):\n",
" checkpoint_dir = os.path.join(\n",
" self.config.get(\"output_dir\", \"runs/homography\"), \"checkpoints\"\n",
" )\n",
" os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
" checkpoint = {\n",
" \"epoch\": epoch,\n",
" \"model_state_dict\": self.model.state_dict(),\n",
" \"optimizer_state_dict\": self.optimizer.state_dict(),\n",
" \"val_loss\": val_loss,\n",
" \"config\": self.config,\n",
" }\n",
"\n",
" checkpoint_path = os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch}.pt\")\n",
" torch.save(checkpoint, checkpoint_path)\n",
"\n",
" if is_best:\n",
" best_path = os.path.join(checkpoint_dir, \"best_model.pt\")\n",
" torch.save(checkpoint, best_path)\n",
"\n",
" def load_checkpoint(self, checkpoint_path: str):\n",
" checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
" self.model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
" self.optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
" return checkpoint[\"epoch\"], checkpoint[\"val_loss\"]\n",
"\n",
"\n",
"def main():\n",
" config_dict = config.copy()\n",
"\n",
" if isinstance(config_dict.get(\"image_size\"), list):\n",
" config_dict[\"image_size\"] = tuple(config_dict[\"image_size\"])\n",
"\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" print(f\"Using device: {device}\")\n",
"\n",
" print(\"Creating data loaders...\")\n",
" train_loader, val_loader = create_data_loaders(\n",
" root_dir=config_dict[\"data_dir\"],\n",
" batch_size=config_dict[\"batch_size\"],\n",
" train_split=config_dict[\"train_split\"],\n",
" num_workers=config_dict[\"num_workers\"],\n",
" image_size=config_dict[\"image_size\"],\n",
" augment_train=True,\n",
" augment_val=False,\n",
" device=device,\n",
" )\n",
"\n",
" print(f\"Train batches: {len(train_loader)}\")\n",
" print(f\"Val batches: {len(val_loader)}\")\n",
"\n",
" print(\"Creating model...\")\n",
" model = create_homography_model(\n",
" model_type=\"backbone\",\n",
" input_channels=3,\n",
" backbone_name=\"resnet18\",\n",
" pretrained=True,\n",
" dropout_rate=0.3,\n",
" use_batch_norm=True,\n",
" )\n",
"\n",
" print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
"\n",
" trainer = HomographyTrainer(\n",
" model=model,\n",
" train_loader=train_loader,\n",
" val_loader=val_loader,\n",
" device=device,\n",
" config=config_dict,\n",
" )\n",
"\n",
" print(\"Starting training...\")\n",
" trainer.train(config_dict[\"epochs\"])\n",
"\n",
" print(\"Training completed!\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" main()\n"
]
}
],
"metadata": {

212
models/SiaN/train.py Normal file
View File

@@ -0,0 +1,212 @@
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import config, create_data_loaders
from model import HomographyCNN, HomographyLoss, create_homography_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class HomographyTrainer:
    """
    Training loop for a two-image homography regressor.

    Handles optimisation with Adam, per-epoch validation, TensorBoard
    logging, checkpointing and early stopping. Settings are read from
    ``config`` with these keys and defaults: ``learning_rate`` (2e-4),
    ``beta1`` (0.5), ``beta2`` (0.999), ``log_interval`` (10),
    ``output_dir`` ("runs/homography"), ``early_stopping_patience`` (20).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move the model to ``device`` and set up loss, optimizer and state."""
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        self.criterion = HomographyLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )

        # The writer is created lazily by train().
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _batch_loss(self, batch: dict):
        """Return (loss, batch_size) for one batch moved to the device."""
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target = batch["homography"].to(self.device)
        output = self.model(google_img, yandex_img)
        return self.criterion(output, target), google_img.size(0)

    def train_epoch(self, epoch: int) -> dict:
        """
        Run one optimisation pass over the training loader.

        Returns:
            ``{"loss": sample-weighted mean loss}``.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.model.train()
        total_loss = 0
        total_samples = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            self.optimizer.zero_grad()
            loss, batch_size = self._batch_loss(batch)
            loss.backward()
            self.optimizer.step()

            # Weight by batch size so a short last batch is averaged fairly.
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if batch_idx % self.config.get("log_interval", 10) == 0:
                pbar.set_postfix({"loss": loss.item()})
                if self.writer:
                    # Global step is epoch * batches_per_epoch + batch index.
                    self.writer.add_scalar(
                        "train/loss",
                        loss.item(),
                        epoch * len(self.train_loader) + batch_idx,
                    )

        avg_loss = total_loss / total_samples
        return {"loss": avg_loss}

    def validate(self) -> dict:
        """
        Evaluate on the validation loader without gradients.

        Returns:
            ``{"loss": sample-weighted mean loss}``.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0
        total_samples = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                loss, batch_size = self._batch_loss(batch)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

        avg_loss = total_loss / total_samples
        return {"loss": avg_loss}

    def train(self, num_epochs: int):
        """
        Train for up to ``num_epochs`` epochs with early stopping.

        A checkpoint is written after every epoch; the best validation loss
        additionally updates ``best_model.pt``. The TensorBoard writer is
        closed when training ends, whether normally or by an exception.
        """
        log_dir = self.config.get("output_dir", "runs/homography")
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")

                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")

                if self.writer:
                    self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
                    self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)

                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
                    print(f"New best model saved with val loss: {val_metrics['loss']:.4f}")
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)

                patience = self.config.get("early_stopping_patience", 20)
                if self.epochs_without_improvement >= patience:
                    print(f"Early stopping triggered after {patience} epochs without improvement")
                    break
        finally:
            # Flush and release the event file even if an epoch failed.
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """
        Save model/optimizer state to ``<output_dir>/checkpoints``.

        Always writes ``checkpoint_epoch_{epoch}.pt``; when ``is_best`` is
        true the same payload is also written to ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/homography"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """
        Restore model and optimizer state from a checkpoint file.

        ``best_val_loss`` and the early-stopping counter are not restored.

        Returns:
            Tuple of (epoch, val_loss) stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]
def main():
    """
    Build the loaders, model and trainer from ``config`` and run training.

    Reads ``data_dir``, ``batch_size``, ``train_split``, ``num_workers``,
    ``image_size`` and ``epochs`` from a copy of the imported ``config``.

    NOTE(review): ``device`` is passed to ``create_data_loaders`` as a
    keyword; confirm that the function imported from ``dataloader`` accepts it.
    """
    config_dict = config.copy()

    # Turn a list-valued image size into a tuple.
    image_size = config_dict.get("image_size")
    if isinstance(image_size, list):
        config_dict["image_size"] = tuple(image_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Creating data loaders...")
    loader_kwargs = {
        "root_dir": config_dict["data_dir"],
        "batch_size": config_dict["batch_size"],
        "train_split": config_dict["train_split"],
        "num_workers": config_dict["num_workers"],
        "image_size": config_dict["image_size"],
        "augment_train": True,
        "augment_val": False,
        "device": device,
    }
    train_loader, val_loader = create_data_loaders(**loader_kwargs)
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    model_kwargs = {
        "input_channels": 3,
        "backbone_name": "resnet18",
        "pretrained": True,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
    }
    model = create_homography_model(model_type="backbone", **model_kwargs)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config_dict,
    )

    print("Starting training...")
    trainer.train(config_dict["epochs"])
    print("Training completed!")


if __name__ == "__main__":
    main()