write SiaN
This commit is contained in:
@@ -1,434 +0,0 @@
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
class HomographyDataset(Dataset):
    """
    Dataset of paired Google / Yandex map images with a 3x3 homography label.

    Pairs are discovered in ``root_dir`` by file name
    (``{idx:04d}_google.png`` / ``{idx:04d}_yandex.png``). Every pair is given
    a randomly generated matrix which, when caching is enabled, is stored on
    disk so the same pair keeps the same matrix across epochs and runs.

    NOTE(review): the per-pair matrix is returned as the label but is never
    applied to either image in this class; only the augmentation warp touches
    the pixels. Confirm that this is the intended ground truth.
    """

    def __init__(
        self,
        root_dir: str,
        transform=None,
        augment: bool = True,
        max_samples: Optional[int] = None,
        image_size: Tuple[int, int] = (700, 700),
        cache_homographies: bool = True,
    ):
        """
        Initialize the HomographyDataset.

        Args:
            root_dir: Directory containing the image pairs.
            transform: Optional callable applied to each PIL image; when it is
                falsy the images are converted to float CHW tensors in [0, 1].
            augment: Whether to warp both images with a random matrix on
                every access.
            max_samples: Keep only the first ``max_samples`` discovered pairs
                (None keeps all of them).
            image_size: Target size for images as (height, width).
            cache_homographies: Whether to persist generated matrices under
                ``<root_dir>/homography_cache``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.cache_homographies = cache_homographies

        pairs = self._discover_image_pairs()
        if max_samples is not None:
            pairs = pairs[:max_samples]
        self.image_pairs = pairs

        print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")

        # The cache directory attribute only exists when caching is enabled.
        if cache_homographies:
            self.homography_cache_dir = os.path.join(root_dir, "homography_cache")
            os.makedirs(self.homography_cache_dir, exist_ok=True)

    def _discover_image_pairs(self) -> List[Dict[str, Any]]:
        """Return one record per Google image that has a matching Yandex image.

        Files are visited in sorted name order. A Google file whose leading
        ``_``-separated token is not an integer is ignored, as is one whose
        ``{idx:04d}_yandex.png`` counterpart does not exist.
        """
        pairs: List[Dict[str, Any]] = []
        for google_file in sorted(os.listdir(self.root_dir)):
            if not google_file.endswith("_google.png"):
                continue
            try:
                idx = int(google_file.split("_")[0])
            except ValueError:
                continue

            yandex_path = os.path.join(self.root_dir, f"{idx:04d}_yandex.png")
            if not os.path.exists(yandex_path):
                continue

            pairs.append(
                {
                    "idx": idx,
                    "google_path": os.path.join(self.root_dir, google_file),
                    "yandex_path": yandex_path,
                }
            )
        return pairs

    def __len__(self) -> int:
        """Return the number of image pairs in the dataset."""
        return len(self.image_pairs)

    def _load_pair(self, pair_info: Dict[str, Any]) -> Tuple["Image.Image", "Image.Image"]:
        """Load both images of a pair as RGB, resized to ``self.image_size``."""
        # PIL's resize takes (width, height); image_size is (height, width).
        target_wh = (self.image_size[1], self.image_size[0])
        loaded = []
        for key in ("google_path", "yandex_path"):
            # The context manager closes the file as soon as the pixels have
            # been copied out by convert().
            with Image.open(pair_info[key]) as raw:
                loaded.append(raw.convert("RGB").resize(target_wh, Image.BILINEAR))
        return loaded[0], loaded[1]

    @staticmethod
    def _pil_to_tensor(img: "Image.Image") -> torch.Tensor:
        """Convert an HWC uint8 PIL image to a float CHW tensor in [0, 1]."""
        return torch.from_numpy(np.array(img)).float().permute(2, 0, 1) / 255.0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a sample from the dataset.

        Returns a dictionary with:
            - 'google_img': Google map image (output of ``transform``, or a
              float CHW tensor when no transform is set)
            - 'yandex_img': Yandex map image, converted the same way
            - 'homography': 3x3 float tensor
            - 'idx': numeric index parsed from the file name (long tensor)
        """
        pair_info = self.image_pairs[idx]
        google_img, yandex_img = self._load_pair(pair_info)
        homography_matrix = self._get_homography_matrix(pair_info["idx"])

        if self.augment:
            google_img, yandex_img, homography_matrix = self._apply_augmentation(
                google_img, yandex_img, homography_matrix
            )

        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)
        else:
            google_img = self._pil_to_tensor(google_img)
            yandex_img = self._pil_to_tensor(yandex_img)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": torch.from_numpy(homography_matrix).float(),
            "idx": torch.tensor(pair_info["idx"], dtype=torch.long),
        }

    def _get_homography_matrix(self, idx: int) -> np.ndarray:
        """
        Return the matrix associated with pair ``idx``.

        With caching enabled, an existing ``{idx:04d}_homography.npy`` file is
        loaded; otherwise a new matrix is generated and stored. Without
        caching a fresh random matrix is generated on every call.
        """
        if not self.cache_homographies:
            return self.generate_random_homography()

        cache_path = os.path.join(
            self.homography_cache_dir, f"{idx:04d}_homography.npy"
        )
        if os.path.exists(cache_path):
            return np.load(cache_path)

        homography_matrix = self.generate_random_homography()

        # Write to a private temp file and move it into place, so that an
        # interrupted or concurrent write can never leave a truncated .npy
        # that later loads would trip over. Saving through a file object
        # keeps numpy from appending its own ".npy" suffix to the temp name.
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        with open(tmp_path, "wb") as handle:
            np.save(handle, homography_matrix)
        os.replace(tmp_path, cache_path)

        return homography_matrix

    def generate_random_homography(self) -> np.ndarray:
        """
        Generate a random 3x3 transform from NumPy's global random state.

        The matrix is a similarity transform (rotation in [-30, 30] degrees,
        uniform scale in [0.8, 1.2], translation in [-50, 50] per axis) whose
        top two rows are additionally jittered by up to +/-0.001 per entry.
        The bottom row is always [0, 0, 1], so the result is affine.

        Returns:
            np.ndarray: 3x3 matrix.
        """
        angle = np.random.uniform(-30, 30)  # rotation in degrees
        scale = np.random.uniform(0.8, 1.2)  # scaling factor
        tx = np.random.uniform(-50, 50)  # translation in x
        ty = np.random.uniform(-50, 50)  # translation in y

        theta = np.radians(angle)
        cos_t = scale * np.cos(theta)
        sin_t = scale * np.sin(theta)
        affine_matrix = np.array(
            [
                [cos_t, -sin_t, tx],
                [sin_t, cos_t, ty],
                [0, 0, 1],
            ]
        )

        # Small jitter on the six affine entries. The appended zero row leaves
        # the projective (bottom) row untouched, so no true perspective term
        # is introduced here.
        jitter = np.vstack([np.random.uniform(-0.001, 0.001, (2, 3)), [0, 0, 0]])

        return affine_matrix + jitter

    def _apply_augmentation(
        self,
        google_img: "Image.Image",
        yandex_img: "Image.Image",
        base_homography: np.ndarray,
    ) -> Tuple["Image.Image", "Image.Image", np.ndarray]:
        """
        Warp both images with one random matrix and update the label to match.

        Args:
            google_img: Google map image
            yandex_img: Yandex map image
            base_homography: Matrix relating the two images before the warp

        Returns:
            Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography)
        """
        aug_homography = self.generate_random_homography()

        # Both images are moved by the same matrix A, so a relation H between
        # their coordinates becomes A @ H @ inv(A) afterwards (change of basis
        # on both sides). Returning A @ H, as before, would only be correct if
        # a single image had been warped.
        combined_homography = (
            aug_homography @ base_homography @ np.linalg.inv(aug_homography)
        )

        google_aug = self._apply_homography_to_image(google_img, aug_homography)
        yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography)

        return google_aug, yandex_aug, combined_homography

    def _apply_homography_to_image(
        self, img: "Image.Image", homography: np.ndarray
    ) -> "Image.Image":
        """
        Apply a 3x3 transform to a single image, keeping its size.

        Args:
            img: PIL Image to transform
            homography: 3x3 matrix passed to ``cv2.warpPerspective``

        Returns:
            Transformed PIL Image (bilinear sampling, reflected borders)
        """
        img_np = np.array(img)
        h, w = img_np.shape[:2]

        transformed = cv2.warpPerspective(
            img_np,
            homography,
            (w, h),  # OpenCV expects (width, height)
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT,
        )

        return Image.fromarray(transformed)

    def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
        """
        Get a sample without augmentation, transform or tensor conversion.

        Returns the resized PIL images, the pair's matrix as a NumPy array,
        the numeric pair index and both source file paths.
        """
        pair_info = self.image_pairs[idx]
        google_img, yandex_img = self._load_pair(pair_info)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": self._get_homography_matrix(pair_info["idx"]),
            "idx": pair_info["idx"],
            "google_path": pair_info["google_path"],
            "yandex_path": pair_info["yandex_path"],
        }
|
||||
|
||||
|
||||
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 4,
    image_size: Tuple[int, int] = (256, 256),
    augment_train: bool = True,
    augment_val: bool = False,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders for homography estimation.

    The discovered pairs are shuffled with the ``random`` module and split
    into a training part (the first ``int(train_split * N)`` indices) and a
    validation part (the rest). Images are converted with ``ToTensor`` and
    normalised with the ImageNet mean / std.

    Args:
        root_dir: Directory containing image pairs
        batch_size: Batch size for data loaders
        train_split: Fraction of data to use for training
        num_workers: Number of worker processes for data loading
        image_size: Target image size (height, width)
        augment_train: Whether to augment training data
        augment_val: Whether to augment validation data

    Returns:
        Tuple of (train_loader, val_loader)

    Raises:
        ValueError: If no image pairs are found in ``root_dir``.
    """
    import copy

    from torch.utils.data import Subset
    from torchvision import transforms

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    base_dataset = HomographyDataset(
        root_dir=root_dir,
        transform=transform,
        augment=False,
        image_size=image_size,
        cache_homographies=True,
    )

    dataset_size = len(base_dataset)
    if dataset_size == 0:
        # Fail here with a readable message instead of deep inside the
        # shuffling sampler of the training loader.
        raise ValueError(f"No image pairs found in {root_dir}")

    def dataset_view(augment: bool) -> HomographyDataset:
        """Return the base dataset, or a shallow copy with augmentation on."""
        if not augment:
            return base_dataset
        # A shallow copy shares the discovered pair list and the cache
        # directory but carries its own ``augment`` flag. Augmentation then
        # goes through the dataset's own code path, which warps the images
        # together with the label. The previous Subset override changed the
        # label only and left the images untouched.
        augmented = copy.copy(base_dataset)
        augmented.augment = True
        return augmented

    train_size = int(train_split * dataset_size)
    indices = list(range(dataset_size))
    random.shuffle(indices)

    train_dataset = Subset(dataset_view(augment_train), indices[:train_size])
    val_dataset = Subset(dataset_view(augment_val), indices[train_size:])

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
    # machines.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual demo: build the dataset from a local folder, inspect one sample,
    # then build the train / validation loaders from the same folder.
    demo_root = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images"

    dataset = HomographyDataset(
        root_dir=demo_root,
        augment=True,
        image_size=(256, 256),
    )
    print(f"Dataset size: {len(dataset)}")

    sample = dataset[0]
    print(f"Sample keys: {list(sample.keys())}")
    for label, key in (
        ("Google image", "google_img"),
        ("Yandex image", "yandex_img"),
        ("Homography", "homography"),
    ):
        print(f"{label} shape: {sample[key].shape}")

    train_loader, val_loader = create_data_loaders(
        root_dir=demo_root,
        batch_size=16,
        train_split=0.8,
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
|
||||
152
models/SiaN/model.py
Normal file
152
models/SiaN/model.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import models
|
||||
|
||||
|
||||
class HomographyCNN(nn.Module):
    """
    Model for estimating homography matrix (3x3) between two images.

    Both images go through the same ResNet backbone (its classification layer
    is replaced by an identity). The two feature vectors, their absolute
    difference and their element-wise product are concatenated and passed to
    an MLP head with nine outputs, reshaped to 3x3.
    """

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        """
        Args:
            input_channels: Channels of the input images; when it is not 3 the
                backbone's first convolution is replaced by a new one.
            backbone_name: "resnet18" or "resnet34" (case-insensitive).
            pretrained: Load the IMAGENET1K_V1 weights for the backbone.
            dropout_rate: Dropout probability used in the head.
            use_batch_norm: Use BatchNorm1d after the hidden linear layers.

        Raises:
            ValueError: If ``backbone_name`` is not supported.
        """
        super().__init__()

        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        backbone = self._create_backbone(backbone_name, pretrained)

        # Remember the feature width before dropping the classifier.
        self.feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # Head input is [f1, f2, |f1 - f2|, f1 * f2].
        in_dim = self.feature_dim * 4

        # The layer order (and so the indices inside the Sequential) is kept
        # as Linear / norm / ReLU / Dropout per hidden stage, so state-dict
        # keys stay the same as before.
        layers = []
        for hidden_dim in (512, 256):
            layers.extend(
                [
                    nn.Linear(in_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim) if use_batch_norm else nn.Identity(),
                    nn.ReLU(inplace=True),
                    nn.Dropout(dropout_rate),
                ]
            )
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, 9))
        self.head = nn.Sequential(*layers)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """Build the torchvision ResNet and adapt its first conv if needed."""
        name = name.lower()
        if name == "resnet18":
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            model = models.resnet18(weights=weights)
        elif name == "resnet34":
            weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            model = models.resnet34(weights=weights)
        else:
            raise ValueError(f"Unsupported backbone: {name}")

        if self.input_channels != 3:
            # The replacement conv is freshly initialised; pretrained weights
            # of the original first layer are not carried over.
            old_conv = model.conv1
            model.conv1 = nn.Conv2d(
                self.input_channels,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None,
            )
        return model

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run one image batch through the shared backbone."""
        return self.backbone(x)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Return a (batch, 3, 3) tensor predicted from the two image batches."""
        f1 = self._extract_features(img1)
        f2 = self._extract_features(img2)

        combined = torch.cat([f1, f2, torch.abs(f1 - f2), f1 * f2], dim=1)

        return self.head(combined).view(-1, 3, 3)

    def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Run inference in eval mode without gradient tracking.

        The module's previous train / eval mode is restored afterwards, also
        when the forward pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(img1, img2)
        finally:
            # Without the finally, an exception during inference would leave
            # a model that was being trained stuck in eval mode.
            if was_training:
                self.train()
|
||||
|
||||
|
||||
class HomographyLoss(nn.Module):
    """Mean squared error between predicted and target matrices."""

    def __init__(self) -> None:
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor:
        """Return the MSE of ``pred_homography`` against ``target_homography``."""
        mse = self.criterion(pred_homography, target_homography)
        return mse
|
||||
|
||||
|
||||
def create_homography_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Build a homography model by type name.

    Args:
        model_type: Only "backbone" is recognised; it builds a HomographyCNN.
        input_size: Accepted but not used by this function.
        **kwargs: Forwarded to the HomographyCNN constructor.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    # Reject unknown types first so the happy path reads straight through.
    if model_type != "backbone":
        raise ValueError(f"Unknown model type: {model_type}")
    return HomographyCNN(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build the model, run a forward pass, an inference call and
    # the loss on random inputs, printing the resulting shapes.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = HomographyCNN(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    model = model.to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model created with {param_count:,} parameters")

    batch_size = 4
    height, width = 256, 256
    input_shape = (batch_size, 3, height, width)

    # Inputs are sampled on the CPU and then moved, one after the other.
    img1 = torch.randn(*input_shape).to(device)
    img2 = torch.randn(*input_shape).to(device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")

    print("\nTesting prediction...")
    pred = model.predict_homography(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    identity = torch.eye(3).unsqueeze(0)
    target = identity.expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nAll tests completed successfully!")
|
||||
@@ -417,7 +417,379 @@
|
||||
"id": "2dad9a5f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"from typing import Tuple\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torchvision import models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class HomographyCNN(nn.Module):\n",
|
||||
" \"\"\"\n",
|
||||
" Model for estimating homography matrix (3x3) between two images.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" input_channels: int = 3,\n",
|
||||
" backbone_name: str = \"resnet18\",\n",
|
||||
" pretrained: bool = True,\n",
|
||||
" dropout_rate: float = 0.3,\n",
|
||||
" use_batch_norm: bool = True,\n",
|
||||
" ):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.input_channels = input_channels\n",
|
||||
" self.backbone_name = backbone_name\n",
|
||||
" self.pretrained = pretrained\n",
|
||||
" self.dropout_rate = dropout_rate\n",
|
||||
" self.use_batch_norm = use_batch_norm\n",
|
||||
"\n",
|
||||
" backbone = self._create_backbone(backbone_name, pretrained)\n",
|
||||
"\n",
|
||||
" self.feature_dim = backbone.fc.in_features\n",
|
||||
" backbone.fc = nn.Identity()\n",
|
||||
" self.backbone = backbone\n",
|
||||
"\n",
|
||||
" compare_input_dim = self.feature_dim * 4\n",
|
||||
"\n",
|
||||
" layers = [\n",
|
||||
" nn.Linear(compare_input_dim, 512),\n",
|
||||
" nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Dropout(dropout_rate),\n",
|
||||
"\n",
|
||||
" nn.Linear(512, 256),\n",
|
||||
" nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),\n",
|
||||
" nn.ReLU(inplace=True),\n",
|
||||
" nn.Dropout(dropout_rate),\n",
|
||||
"\n",
|
||||
" nn.Linear(256, 9),\n",
|
||||
" ]\n",
|
||||
" self.head = nn.Sequential(*layers)\n",
|
||||
"\n",
|
||||
" def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:\n",
|
||||
" name = name.lower()\n",
|
||||
" if name == \"resnet18\":\n",
|
||||
" model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n",
|
||||
" elif name == \"resnet34\":\n",
|
||||
" model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(f\"Unsupported backbone: {name}\")\n",
|
||||
" if self.input_channels != 3:\n",
|
||||
" old_conv = model.conv1\n",
|
||||
" model.conv1 = nn.Conv2d(\n",
|
||||
" self.input_channels,\n",
|
||||
" old_conv.out_channels,\n",
|
||||
" kernel_size=old_conv.kernel_size,\n",
|
||||
" stride=old_conv.stride,\n",
|
||||
" padding=old_conv.padding,\n",
|
||||
" bias=old_conv.bias is not None,\n",
|
||||
" )\n",
|
||||
" return model\n",
|
||||
"\n",
|
||||
" def _extract_features(self, x: torch.Tensor) -> torch.Tensor:\n",
|
||||
" return self.backbone(x)\n",
|
||||
"\n",
|
||||
" def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:\n",
|
||||
" f1 = self._extract_features(img1)\n",
|
||||
" f2 = self._extract_features(img2)\n",
|
||||
"\n",
|
||||
" diff = torch.abs(f1 - f2)\n",
|
||||
" prod = f1 * f2\n",
|
||||
" combined = torch.cat([f1, f2, diff, prod], dim=1)\n",
|
||||
"\n",
|
||||
" h = self.head(combined)\n",
|
||||
" h = h.view(-1, 3, 3)\n",
|
||||
" return h\n",
|
||||
"\n",
|
||||
" def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:\n",
|
||||
" was_training = self.training\n",
|
||||
" self.eval()\n",
|
||||
" with torch.no_grad():\n",
|
||||
" h = self.forward(img1, img2)\n",
|
||||
" if was_training:\n",
|
||||
" self.train()\n",
|
||||
" return h\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class HomographyLoss(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.criterion = nn.MSELoss()\n",
|
||||
"\n",
|
||||
" def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor:\n",
|
||||
" return self.criterion(pred_homography, target_homography)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_homography_model(\n",
|
||||
" model_type: str = \"backbone\",\n",
|
||||
" input_size: Tuple[int, int] = (256, 256),\n",
|
||||
" **kwargs,\n",
|
||||
") -> nn.Module:\n",
|
||||
" if model_type == \"backbone\":\n",
|
||||
" return HomographyCNN(**kwargs)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(f\"Unknown model type: {model_type}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" print(f\"Using device: {device}\")\n",
|
||||
"\n",
|
||||
" model = HomographyCNN(\n",
|
||||
" input_channels=3,\n",
|
||||
" backbone_name=\"resnet18\",\n",
|
||||
" pretrained=True,\n",
|
||||
" dropout_rate=0.3,\n",
|
||||
" use_batch_norm=True,\n",
|
||||
" ).to(device)\n",
|
||||
"\n",
|
||||
" print(f\"Model created with {sum(p.numel() for p in model.parameters()):,} parameters\")\n",
|
||||
"\n",
|
||||
" batch_size = 4\n",
|
||||
" height, width = 256, 256\n",
|
||||
"\n",
|
||||
" img1 = torch.randn(batch_size, 3, height, width).to(device)\n",
|
||||
" img2 = torch.randn(batch_size, 3, height, width).to(device)\n",
|
||||
"\n",
|
||||
" print(\"\\nTesting forward pass...\")\n",
|
||||
" output = model(img1, img2)\n",
|
||||
" print(f\"Output shape: {output.shape}\")\n",
|
||||
"\n",
|
||||
" print(\"\\nTesting prediction...\")\n",
|
||||
" pred = model.predict_homography(img1, img2)\n",
|
||||
" print(f\"Prediction shape: {pred.shape}\")\n",
|
||||
"\n",
|
||||
" print(\"\\nTesting loss function...\")\n",
|
||||
" target = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)\n",
|
||||
" loss_fn = HomographyLoss().to(device)\n",
|
||||
" loss = loss_fn(output, target)\n",
|
||||
" print(f\"Loss value: {loss.item():.6f}\")\n",
|
||||
"\n",
|
||||
" print(\"\\nAll tests completed successfully!\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e573b201",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"from datetime import datetime\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class HomographyTrainer:\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" model: nn.Module,\n",
|
||||
" train_loader: DataLoader,\n",
|
||||
" val_loader: DataLoader,\n",
|
||||
" device: torch.device,\n",
|
||||
" config: dict,\n",
|
||||
" ):\n",
|
||||
" self.model = model.to(device)\n",
|
||||
" self.train_loader = train_loader\n",
|
||||
" self.val_loader = val_loader\n",
|
||||
" self.device = device\n",
|
||||
" self.config = config\n",
|
||||
"\n",
|
||||
" self.criterion = HomographyLoss()\n",
|
||||
" self.optimizer = optim.Adam(\n",
|
||||
" model.parameters(),\n",
|
||||
" lr=config.get(\"learning_rate\", 2e-4),\n",
|
||||
" betas=(config.get(\"beta1\", 0.5), config.get(\"beta2\", 0.999)),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" self.writer = None\n",
|
||||
" self.best_val_loss = float(\"inf\")\n",
|
||||
" self.epochs_without_improvement = 0\n",
|
||||
"\n",
|
||||
" def train_epoch(self, epoch: int) -> dict:\n",
|
||||
" self.model.train()\n",
|
||||
" total_loss = 0\n",
|
||||
" total_samples = 0\n",
|
||||
"\n",
|
||||
" pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n",
|
||||
" for batch_idx, batch in enumerate(pbar):\n",
|
||||
" google_img = batch[\"google_img\"].to(self.device)\n",
|
||||
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
|
||||
" target = batch[\"homography\"].to(self.device)\n",
|
||||
"\n",
|
||||
" self.optimizer.zero_grad()\n",
|
||||
"\n",
|
||||
" output = self.model(google_img, yandex_img)\n",
|
||||
" loss = self.criterion(output, target)\n",
|
||||
"\n",
|
||||
" loss.backward()\n",
|
||||
" self.optimizer.step()\n",
|
||||
"\n",
|
||||
" total_loss += loss.item() * google_img.size(0)\n",
|
||||
" total_samples += google_img.size(0)\n",
|
||||
"\n",
|
||||
" if batch_idx % self.config.get(\"log_interval\", 10) == 0:\n",
|
||||
" pbar.set_postfix({\"loss\": loss.item()})\n",
|
||||
"\n",
|
||||
" if self.writer:\n",
|
||||
" self.writer.add_scalar(\n",
|
||||
" \"train/loss\",\n",
|
||||
" loss.item(),\n",
|
||||
" epoch * len(self.train_loader) + batch_idx,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" avg_loss = total_loss / total_samples\n",
|
||||
" return {\"loss\": avg_loss}\n",
|
||||
"\n",
|
||||
" def validate(self) -> dict:\n",
|
||||
" self.model.eval()\n",
|
||||
" total_loss = 0\n",
|
||||
" total_samples = 0\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for batch in tqdm(self.val_loader, desc=\"Validation\"):\n",
|
||||
" google_img = batch[\"google_img\"].to(self.device)\n",
|
||||
" yandex_img = batch[\"yandex_img\"].to(self.device)\n",
|
||||
" target = batch[\"homography\"].to(self.device)\n",
|
||||
"\n",
|
||||
" output = self.model(google_img, yandex_img)\n",
|
||||
" loss = self.criterion(output, target)\n",
|
||||
"\n",
|
||||
" total_loss += loss.item() * google_img.size(0)\n",
|
||||
" total_samples += google_img.size(0)\n",
|
||||
"\n",
|
||||
" avg_loss = total_loss / total_samples\n",
|
||||
" return {\"loss\": avg_loss}\n",
|
||||
"\n",
|
||||
" def train(self, num_epochs: int):\n",
|
||||
" log_dir = self.config.get(\"output_dir\", \"runs/homography\")\n",
|
||||
" os.makedirs(log_dir, exist_ok=True)\n",
|
||||
" self.writer = SummaryWriter(log_dir)\n",
|
||||
"\n",
|
||||
" print(f\"Starting training for {num_epochs} epochs\")\n",
|
||||
" print(f\"Logging to: {log_dir}\")\n",
|
||||
"\n",
|
||||
" for epoch in range(1, num_epochs + 1):\n",
|
||||
" print(f\"\\nEpoch {epoch}/{num_epochs}\")\n",
|
||||
"\n",
|
||||
" train_metrics = self.train_epoch(epoch)\n",
|
||||
" val_metrics = self.validate()\n",
|
||||
"\n",
|
||||
" print(f\"Train Loss: {train_metrics['loss']:.4f}\")\n",
|
||||
" print(f\"Val Loss: {val_metrics['loss']:.4f}\")\n",
|
||||
"\n",
|
||||
" if self.writer:\n",
|
||||
" self.writer.add_scalar(\"epoch/train_loss\", train_metrics[\"loss\"], epoch)\n",
|
||||
" self.writer.add_scalar(\"epoch/val_loss\", val_metrics[\"loss\"], epoch)\n",
|
||||
"\n",
|
||||
" if val_metrics[\"loss\"] < self.best_val_loss:\n",
|
||||
" self.best_val_loss = val_metrics[\"loss\"]\n",
|
||||
" self.epochs_without_improvement = 0\n",
|
||||
" self.save_checkpoint(epoch, val_metrics[\"loss\"], is_best=True)\n",
|
||||
" print(f\"New best model saved with val loss: {val_metrics['loss']:.4f}\")\n",
|
||||
" else:\n",
|
||||
" self.epochs_without_improvement += 1\n",
|
||||
" self.save_checkpoint(epoch, val_metrics[\"loss\"], is_best=False)\n",
|
||||
"\n",
|
||||
" patience = self.config.get(\"early_stopping_patience\", 20)\n",
|
||||
" if self.epochs_without_improvement >= patience:\n",
|
||||
" print(f\"Early stopping triggered after {patience} epochs without improvement\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" self.writer.close()\n",
|
||||
"\n",
|
||||
" def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):\n",
|
||||
" checkpoint_dir = os.path.join(\n",
|
||||
" self.config.get(\"output_dir\", \"runs/homography\"), \"checkpoints\"\n",
|
||||
" )\n",
|
||||
" os.makedirs(checkpoint_dir, exist_ok=True)\n",
|
||||
"\n",
|
||||
" checkpoint = {\n",
|
||||
" \"epoch\": epoch,\n",
|
||||
" \"model_state_dict\": self.model.state_dict(),\n",
|
||||
" \"optimizer_state_dict\": self.optimizer.state_dict(),\n",
|
||||
" \"val_loss\": val_loss,\n",
|
||||
" \"config\": self.config,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" checkpoint_path = os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch}.pt\")\n",
|
||||
" torch.save(checkpoint, checkpoint_path)\n",
|
||||
"\n",
|
||||
" if is_best:\n",
|
||||
" best_path = os.path.join(checkpoint_dir, \"best_model.pt\")\n",
|
||||
" torch.save(checkpoint, best_path)\n",
|
||||
"\n",
|
||||
" def load_checkpoint(self, checkpoint_path: str):\n",
|
||||
" checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
|
||||
" self.model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
|
||||
" self.optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
|
||||
" return checkpoint[\"epoch\"], checkpoint[\"val_loss\"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def main():\n",
|
||||
" config_dict = config.copy()\n",
|
||||
"\n",
|
||||
" if isinstance(config_dict.get(\"image_size\"), list):\n",
|
||||
" config_dict[\"image_size\"] = tuple(config_dict[\"image_size\"])\n",
|
||||
"\n",
|
||||
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
" print(f\"Using device: {device}\")\n",
|
||||
"\n",
|
||||
" print(\"Creating data loaders...\")\n",
|
||||
" train_loader, val_loader = create_data_loaders(\n",
|
||||
" root_dir=config_dict[\"data_dir\"],\n",
|
||||
" batch_size=config_dict[\"batch_size\"],\n",
|
||||
" train_split=config_dict[\"train_split\"],\n",
|
||||
" num_workers=config_dict[\"num_workers\"],\n",
|
||||
" image_size=config_dict[\"image_size\"],\n",
|
||||
" augment_train=True,\n",
|
||||
" augment_val=False,\n",
|
||||
" device=device,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" print(f\"Train batches: {len(train_loader)}\")\n",
|
||||
" print(f\"Val batches: {len(val_loader)}\")\n",
|
||||
"\n",
|
||||
" print(\"Creating model...\")\n",
|
||||
" model = create_homography_model(\n",
|
||||
" model_type=\"backbone\",\n",
|
||||
" input_channels=3,\n",
|
||||
" backbone_name=\"resnet18\",\n",
|
||||
" pretrained=True,\n",
|
||||
" dropout_rate=0.3,\n",
|
||||
" use_batch_norm=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
|
||||
"\n",
|
||||
" trainer = HomographyTrainer(\n",
|
||||
" model=model,\n",
|
||||
" train_loader=train_loader,\n",
|
||||
" val_loader=val_loader,\n",
|
||||
" device=device,\n",
|
||||
" config=config_dict,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" print(\"Starting training...\")\n",
|
||||
" trainer.train(config_dict[\"epochs\"])\n",
|
||||
"\n",
|
||||
" print(\"Training completed!\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" main()\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
212
models/SiaN/train.py
Normal file
212
models/SiaN/train.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from dataloader import config, create_data_loaders
|
||||
from model import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class HomographyTrainer:
    """Drive training, validation, logging and checkpointing of a homography model.

    The model is called as ``model(google_img, yandex_img)`` and its output is
    compared against the batch's ``"homography"`` entry with ``HomographyLoss``.
    Optimisation uses Adam; hyper-parameters are read from ``config`` with
    fallbacks (``learning_rate`` 2e-4, ``beta1`` 0.5, ``beta2`` 0.999).

    Config keys read by this class: ``learning_rate``, ``beta1``, ``beta2``,
    ``log_interval``, ``output_dir``, ``early_stopping_patience``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Store the collaborators and build the loss and optimizer.

        Args:
            model: Network to train; it is moved to ``device``.
            train_loader: Loader yielding dict batches with the keys
                ``"google_img"``, ``"yandex_img"`` and ``"homography"``.
            val_loader: Loader with the same batch layout, used by ``validate``.
            device: Device on which the model and batches are placed.
            config: Hyper-parameter dictionary (also stored in checkpoints).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        self.criterion = HomographyLoss()
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )

        # The TensorBoard writer is created lazily by ``train``.
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _batch_loss(self, batch):
        """Move one batch to the device, run the model, return ``(loss, batch_size)``."""
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target = batch["homography"].to(self.device)

        output = self.model(google_img, yandex_img)
        loss = self.criterion(output, target)
        return loss, google_img.size(0)

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation pass over ``train_loader``.

        Every ``log_interval`` batches (default 10) the progress bar is updated
        and, if a writer is attached, the batch loss is logged as ``train/loss``.

        Args:
            epoch: Epoch number, used for the progress-bar label and the
                TensorBoard global step.

        Returns:
            ``{"loss": <sample-weighted mean loss over the epoch>}``.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.model.train()
        total_loss = 0
        total_samples = 0
        log_interval = self.config.get("log_interval", 10)
        steps_per_epoch = len(self.train_loader)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            self.optimizer.zero_grad()

            loss, batch_size = self._batch_loss(batch)
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += loss_value * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                pbar.set_postfix({"loss": loss_value})

                if self.writer:
                    self.writer.add_scalar(
                        "train/loss",
                        loss_value,
                        epoch * steps_per_epoch + batch_idx,
                    )

        return {"loss": total_loss / total_samples}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without tracking gradients.

        Returns:
            ``{"loss": <sample-weighted mean loss over the validation set>}``.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0
        total_samples = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                loss, batch_size = self._batch_loss(batch)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

        return {"loss": total_loss / total_samples}

    def train(self, num_epochs: int):
        """Train for up to ``num_epochs`` epochs with early stopping.

        After each epoch the model is validated, epoch-level losses are logged
        to TensorBoard, and a checkpoint is written. An improvement in
        validation loss also refreshes ``best_model.pt``. Training stops early
        once ``early_stopping_patience`` (default 20) consecutive epochs pass
        without improvement.

        Args:
            num_epochs: Maximum number of epochs; epochs are numbered from 1.
        """
        log_dir = self.config.get("output_dir", "runs/homography")
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        patience = self.config.get("early_stopping_patience", 20)

        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        # try/finally so the event file is flushed and closed even when an
        # epoch raises or the run is interrupted (e.g. Ctrl-C).
        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")

                train_loss = self.train_epoch(epoch)["loss"]
                val_loss = self.validate()["loss"]

                print(f"Train Loss: {train_loss:.4f}")
                print(f"Val Loss: {val_loss:.4f}")

                self.writer.add_scalar("epoch/train_loss", train_loss, epoch)
                self.writer.add_scalar("epoch/val_loss", val_loss, epoch)

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_loss, is_best=True)
                    print(f"New best model saved with val loss: {val_loss:.4f}")
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_loss, is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(f"Early stopping triggered after {patience} epochs without improvement")
                    break
        finally:
            self.writer.close()

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state to ``<output_dir>/checkpoints``.

        Args:
            epoch: Epoch number; used in the file name ``checkpoint_epoch_<epoch>.pt``.
            val_loss: Validation loss stored alongside the weights.
            is_best: When true, the same payload is also written to ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/homography"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from a file written by ``save_checkpoint``.

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.

        Returns:
            Tuple ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]
|
||||
|
||||
|
||||
def main():
    """Build the data loaders, model and trainer from ``config`` and run training.

    Works on a shallow copy of the imported ``config`` mapping. A list-valued
    ``image_size`` is converted to a tuple before use. The device is CUDA when
    available, otherwise CPU.
    """
    settings = config.copy()

    size = settings.get("image_size")
    if isinstance(size, list):
        settings["image_size"] = tuple(size)

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"Using device: {device}")

    print("Creating data loaders...")
    loader_kwargs = {
        "root_dir": settings["data_dir"],
        "batch_size": settings["batch_size"],
        "train_split": settings["train_split"],
        "num_workers": settings["num_workers"],
        "image_size": settings["image_size"],
        "augment_train": True,
        "augment_val": False,
        "device": device,
    }
    train_loader, val_loader = create_data_loaders(**loader_kwargs)

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    model_kwargs = {
        "model_type": "backbone",
        "input_channels": 3,
        "backbone_name": "resnet18",
        "pretrained": True,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
    }
    model = create_homography_model(**model_kwargs)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=settings,
    )

    print("Starting training...")
    trainer.train(settings["epochs"])

    print("Training completed!")


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user