# File: autopilot/models/SiaN/example_homography.py
# Snapshot: 2026-02-16 19:07:31 +03:00 (346 lines, 10 KiB, Python)

"""
Example script demonstrating the complete homography estimation workflow.
This script shows how to:
1. Load the ya_go_maps dataset
2. Create and train a homography estimation model
3. Perform inference on new image pairs
4. Visualize results
"""
import os
import sys
from pathlib import Path
# Add the grandparent directory of this file (two levels up) to sys.path for imports
sys.path.append(str(Path(__file__).parent.parent))
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from models.homography import HomographyDataset, create_data_loaders
from models.homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
from models.infer_homography import HomographyInference
from models.train_homography import HomographyTrainer
def example_dataset_loading(
    dataset_path: str = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    show: bool = True,
):
    """Example 1: Loading and exploring the dataset.

    Builds a ``HomographyDataset`` rooted at *dataset_path*, prints the keys
    and shapes of the first sample, and plots the first image pair without
    augmentation.

    Args:
        dataset_path: Directory passed to ``HomographyDataset`` as
            ``root_dir``. Defaults to the original hard-coded location.
        show: If True, display the figure with ``plt.show()``; otherwise the
            figure is closed without being displayed (useful for headless runs).

    Returns:
        The constructed ``HomographyDataset`` (returned even when it is empty).
    """
    print("=" * 60)
    print("Example 1: Loading and exploring the dataset")
    print("=" * 60)

    dataset = HomographyDataset(
        root_dir=dataset_path,
        augment=True,
        image_size=(256, 256),
        cache_homographies=True,
    )
    num_pairs = len(dataset)
    print(f"Dataset size: {num_pairs} image pairs")
    if num_pairs == 0:
        # dataset[0] below would otherwise fail with an unhelpful index error.
        print(f"No image pairs found in {dataset_path}; skipping sample inspection.")
        return dataset

    # Inspect the first (augmented) sample.
    sample = dataset[0]
    print(f"\nSample keys: {list(sample.keys())}")
    print(f"Google image shape: {sample['google_img'].shape}")
    print(f"Yandex image shape: {sample['yandex_img'].shape}")
    print(f"Homography shape: {sample['homography'].shape}")
    print("\nSample homography matrix:")
    print(sample["homography"].numpy())

    # The un-augmented sample is used for display so the images look natural.
    raw_sample = dataset.get_sample_without_augmentation(0)
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (("google_img", "Google Map"), ("yandex_img", "Yandex Map"))
    for ax, (key, title) in zip(axes, panels):
        ax.imshow(raw_sample[key])
        ax.set_title(title)
        ax.axis("off")
    plt.suptitle("Sample Image Pair (without augmentation)")
    plt.tight_layout()
    if show:
        plt.show()
    else:
        # Avoid accumulating open figures when nothing is displayed.
        plt.close(fig)
    return dataset
def example_model_creation():
    """Example 2: Creating and testing the model.

    Builds a ``HomographyCNN`` on CUDA when available (CPU otherwise), runs one
    forward pass on random inputs, then evaluates ``HomographyLoss`` and its
    metrics against an identity-matrix target. No backward pass is performed,
    so the whole smoke test runs without building an autograd graph.

    Returns:
        ``(model, loss_fn)``.
    """
    print("\n" + "=" * 60)
    print("Example 2: Creating and testing the model")
    print("=" * 60)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = HomographyCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model created with {num_params:,} parameters")

    # Dummy input, allocated directly on the target device (no host->device copy).
    batch_size = 2
    height, width = 256, 256
    google_img = torch.randn(batch_size, 3, height, width, device=device)
    yandex_img = torch.randn(batch_size, 3, height, width, device=device)

    # Nothing below calls backward(), so skip autograd bookkeeping entirely.
    with torch.no_grad():
        print("\nTesting forward pass...")
        output = model(google_img, yandex_img, return_matrix=True)
        print(f"Output shape: {output.shape}")  # Should be (2, 3, 3)
        print("Sample output matrix:")
        print(output[0].cpu().detach().numpy())

        print("\nTesting loss function...")
        target_homography = (
            torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
        )
        loss_fn = HomographyLoss(
            matrix_weight=1.0,
            geometric_weight=0.5,
            reg_weight=0.1,
            grid_size=8,
        ).to(device)
        loss = loss_fn(output, target_homography, google_img, yandex_img)
        print(f"Loss value: {loss.item():.6f}")

        print("\nTesting metrics...")
        metrics = loss_fn.compute_metrics(output, target_homography)
        for key, value in metrics.items():
            print(f"{key}: {value:.6f}")

    return model, loss_fn
def example_data_loaders(
    dataset_path: str = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
):
    """Example 3: Creating data loaders for training.

    Calls ``create_data_loaders`` with an 80/20 train/validation split, prints
    the number of batches in each loader, and inspects one training batch.

    Args:
        dataset_path: Directory passed to ``create_data_loaders`` as
            ``root_dir``. Defaults to the original hard-coded location.

    Returns:
        ``(train_loader, val_loader)`` as returned by ``create_data_loaders``.
    """
    print("\n" + "=" * 60)
    print("Example 3: Creating data loaders for training")
    print("=" * 60)

    train_loader, val_loader = create_data_loaders(
        root_dir=dataset_path,
        batch_size=16,
        train_split=0.8,
        num_workers=0,  # Use 0 for debugging, increase for training
        image_size=(256, 256),
        augment_train=True,
        augment_val=False,
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # next(iter(...)) on an empty loader raises a bare StopIteration, which
    # surfaces as an error with an empty message; handle that case explicitly.
    batch = next(iter(train_loader), None)
    if batch is None:
        print("\nTraining loader is empty; no batch to inspect.")
        return train_loader, val_loader

    print(f"\nBatch keys: {list(batch.keys())}")
    print(f"Batch size: {batch['google_img'].shape[0]}")
    print(f"Image shape: {batch['google_img'].shape[1:]}")
    return train_loader, val_loader
def example_training_config(
    dataset_path: str = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
):
    """Example 4: Setting up training configuration.

    Builds a flat configuration dict covering model, training, loss, data and
    output settings, prints every entry, and returns it.

    Args:
        dataset_path: Value stored under the ``"data_dir"`` key. Defaults to
            the original hard-coded dataset location.

    Returns:
        A new ``dict`` with the training configuration.
    """
    print("\n" + "=" * 60)
    print("Example 4: Training configuration")
    print("=" * 60)

    config = {
        # Model config
        "model_type": "cnn",
        "hidden_channels": 64,
        "num_blocks": 4,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
        "image_size": [256, 256],
        # Training config
        "epochs": 50,
        "batch_size": 16,
        "learning_rate": 1e-3,
        "weight_decay": 1e-4,
        "optimizer": "adam",
        "scheduler": "plateau",
        "grad_clip": 1.0,
        # Loss config
        "matrix_weight": 1.0,
        "geometric_weight": 0.5,
        "reg_weight": 0.1,
        "grid_size": 8,
        # Data config
        "data_dir": dataset_path,
        "train_split": 0.8,
        "num_workers": 0,
        # Output config
        "output_dir": "runs/example_training",
        "seed": 42,
    }

    print("Training configuration:")
    for key, value in config.items():
        print(f" {key}: {value}")
    return config
def example_inference():
    """Example 5: Outline the inference workflow for a trained model.

    No model is loaded here. The function only prints the steps of the
    workflow and a reminder to train a model first. A reference snippet for
    ``HomographyInference`` is kept below as comments for when a trained
    checkpoint is available.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Example 5: Inference example")
    print(banner)

    # This example assumes a trained model exists; until then it only prints
    # the structure of the workflow.
    workflow_steps = (
        "Load trained model",
        "Preprocess input images",
        "Predict homography matrix",
        "Visualize alignment",
    )
    print("Inference workflow:")
    for step_no, step in enumerate(workflow_steps, start=1):
        print(f"{step_no}. {step}")

    # Reference usage once a trained checkpoint exists:
    #
    #   inference = HomographyInference(
    #       model_path="runs/homography/checkpoint_best.pth",
    #       device="cuda" if torch.cuda.is_available() else "cpu",
    #   )
    #
    #   google_img = Image.open("path/to/google.png")
    #   yandex_img = Image.open("path/to/yandex.png")
    #
    #   homography = inference.predict(google_img, yandex_img)
    #   print("Predicted homography matrix:")
    #   print(homography.cpu().numpy())
    #
    #   inference.visualize_alignment(
    #       google_img,
    #       yandex_img,
    #       homography.cpu().numpy(),
    #       save_path="alignment_visualization.png",
    #       show=True,
    #   )
    print(
        "\nNote: To run actual inference, first train a model using train_homography.py"
    )
def example_quick_start(
    dataset_path: str = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
):
    """Example 6: Quick start guide.

    Prints ready-to-copy commands for exploring the dataset, training a model,
    running single-pair inference and evaluating on the whole dataset.

    Args:
        dataset_path: Dataset directory substituted into every printed command
            that needs it. Defaults to the original hard-coded location, which
            reproduces the original output exactly.
    """
    print("\n" + "=" * 60)
    print("Example 6: Quick Start Guide")
    print("=" * 60)
    print("\nTo get started with homography estimation:")

    print("\n1. First, explore the dataset:")
    # Only the middle piece is an f-string; the last piece is a plain string,
    # so its {len(dataset)} is printed literally for the user's own f-string.
    print(
        ' python -c "from models.homography import HomographyDataset; '
        f"dataset = HomographyDataset(r'{dataset_path}'); "
        "print(f'Found {len(dataset)} image pairs')\""
    )

    print("\n2. Train a model:")
    print(" python models/train_homography.py \\")
    print(f' --data_dir "{dataset_path}" \\')
    print(" --epochs 50 \\")
    print(" --batch_size 16 \\")
    print(' --output_dir "runs/my_experiment"')

    print("\n3. Perform inference on a single image pair:")
    print(" python models/infer_homography.py \\")
    print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\')
    print(" --mode single \\")
    print(' --google_path "path/to/google.png" \\')
    print(' --yandex_path "path/to/yandex.png" \\')
    print(' --output_vis "alignment_result.png"')

    print("\n4. Evaluate on the entire dataset:")
    print(" python models/infer_homography.py \\")
    print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\')
    print(" --mode dataset \\")
    print(f' --dataset_dir "{dataset_path}" \\')
    print(' --save_results "evaluation_results.json"')
def main():
    """Run all examples and return a process exit status.

    The examples do not consume each other's return values, so each one runs
    in isolation. An exception in one example is reported (message plus
    traceback) and the remaining examples still run, instead of the first
    failure silently aborting the rest.

    Returns:
        0 if every example completed without raising, 1 otherwise.
    """
    import traceback

    print("Homography Estimation Workflow Examples")
    print("=" * 60)

    examples = (
        example_dataset_loading,  # Example 1: dataset loading
        example_model_creation,  # Example 2: model creation
        example_data_loaders,  # Example 3: data loaders
        example_training_config,  # Example 4: training configuration
        example_inference,  # Example 5: inference
        example_quick_start,  # Example 6: quick start
    )
    failed = []
    for example in examples:
        try:
            example()
        except Exception as e:  # top-level boundary: report, then keep going
            print(f"\nError running {example.__name__}: {e}")
            traceback.print_exc()
            failed.append(example.__name__)

    print("\n" + "=" * 60)
    if failed:
        print(f"{len(failed)} example(s) failed: {', '.join(failed)}")
        print("=" * 60)
        return 1

    print("All examples completed successfully!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Run the training script to train a model")
    print("2. Use the inference script to test the trained model")
    print("3. Integrate the homography estimation into your autopilot system")
    return 0
if __name__ == "__main__":
    # Forward main()'s status so failing examples yield a non-zero exit code;
    # sys.exit(None) still exits with 0.
    sys.exit(main())