feat: add SiaN model
This commit is contained in:
43
datasets/ya_go_maps/generate_dataset.py
Normal file
43
datasets/ya_go_maps/generate_dataset.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from ...google_map import GoogleMap
|
||||
from ...simulator import Simulator
|
||||
from ...yandex_map import YandexMap
|
||||
|
||||
# Latitude bounds: create_new_asset draws a uniformly random latitude between
# LAT_MIN and LAT_MAX for every new screenshot pair.
LAT_MIN, LAT_MAX = 44.960236, 54.967830
|
||||
# Longitude bounds: create_new_asset draws a uniformly random longitude between
# LON_MIN and LON_MAX for every new screenshot pair.
LON_MIN, LON_MAX = 53.084167, 58.677977
|
||||
|
||||
def create_new_asset(yandex_map, google_map):
    """Capture one Yandex/Google screenshot pair at a random location.

    The pair is written into ``dataset_ya_go_maps`` as ``NNNN_google.png`` and
    ``NNNN_yandex.png``, where ``NNNN`` is the lowest zero-padded index for
    which no Google file exists yet. The chosen index is printed.

    Args:
        yandex_map: Map wrapper exposing ``open(lat, lon, level)`` and
            ``make_screenshot()``.
        google_map: Second map wrapper exposing the same two methods.
    """
    folder = Path('dataset_ya_go_maps')

    # Probe upwards until an unused index is found. Only the Google file is
    # checked; the Yandex file of the same index is assumed to share its fate.
    asset_id = 0
    while (folder / f"{asset_id:04d}_google.png").exists():
        asset_id += 1
    # Report the index actually used (printing before the search always gave 0).
    print(asset_id)

    google_file = folder / f"{asset_id:04d}_google.png"
    yandex_file = folder / f"{asset_id:04d}_yandex.png"

    # Uniform random point inside the [LAT_MIN, LAT_MAX) x [LON_MIN, LON_MAX) box.
    lat = np.random.rand() * (LAT_MAX - LAT_MIN) + LAT_MIN
    lon = np.random.rand() * (LON_MAX - LON_MIN) + LON_MIN

    # Point both providers at the same spot; 18 is presumably the zoom level
    # (main() constructs the maps with ``initial_zoom``) - confirm.
    yandex_map.open(lat, lon, 18)
    google_map.open(lat, lon, 18)

    # NOTE(review): this relies on Simulator's underscore-prefixed
    # ``_apply_perspective_transform``; a public entry point would be safer.
    simulator = Simulator()
    simulator._apply_perspective_transform(yandex_map.make_screenshot()).save(yandex_file)
    simulator._apply_perspective_transform(google_map.make_screenshot()).save(google_file)
|
||||
|
||||
def main(num_assets=4):
    """Create the output folder and generate ``num_assets`` screenshot pairs.

    Args:
        num_assets: How many Yandex/Google pairs to capture in this run.
            Defaults to 4, the count that used to be hard-coded.
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    Path('dataset_ya_go_maps').mkdir(parents=True, exist_ok=True)

    # Both map wrappers are created once and reused for every pair.
    yandex_map = YandexMap(initial_zoom=15)
    google_map = GoogleMap(initial_zoom=15)

    for _ in range(num_assets):
        create_new_asset(yandex_map, google_map)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
0
models/SiaN/.gitignore
vendored
Normal file
0
models/SiaN/.gitignore
vendored
Normal file
295
models/SiaN/README_homography.md
Normal file
295
models/SiaN/README_homography.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# Homography Estimation System
|
||||
|
||||
This system provides a complete pipeline for estimating homography matrices between Google and Yandex map images using deep learning.
|
||||
|
||||
## Overview
|
||||
|
||||
Homography estimation is crucial for aligning images from different sources (Google Maps and Yandex Maps in this case). The system includes:
|
||||
|
||||
1. **Dataset handling** - Loading and preprocessing image pairs
|
||||
2. **Data augmentation** - Homography-based augmentation for robust training
|
||||
3. **CNN model** - Deep learning model for homography estimation
|
||||
4. **Training pipeline** - Complete training and evaluation workflow
|
||||
5. **Inference tools** - Tools for using trained models on new data
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
- Python 3.8+
|
||||
- PyTorch 1.9+
|
||||
- OpenCV
|
||||
- PIL/Pillow
|
||||
- NumPy
|
||||
|
||||
### Install dependencies
|
||||
```bash
|
||||
pip install torch torchvision opencv-python pillow numpy matplotlib tqdm tensorboard
|
||||
```
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
The system expects image pairs in the following format:
|
||||
```
|
||||
dataset/
|
||||
├── 0000_google.png
|
||||
├── 0000_yandex.png
|
||||
├── 0001_google.png
|
||||
├── 0001_yandex.png
|
||||
└── ...
|
||||
```
|
||||
|
||||
Each pair consists of:
|
||||
- `{idx:04d}_google.png` - Google map image
|
||||
- `{idx:04d}_yandex.png` - Yandex map image
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Explore the dataset
|
||||
```python
|
||||
from models.homography import HomographyDataset
|
||||
|
||||
dataset = HomographyDataset(
|
||||
root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||
augment=True,
|
||||
image_size=(256, 256)
|
||||
)
|
||||
|
||||
print(f"Found {len(dataset)} image pairs")
|
||||
sample = dataset[0]
|
||||
print(f"Sample homography matrix:\n{sample['homography'].numpy()}")
|
||||
```
|
||||
|
||||
### 2. Train a model
|
||||
```bash
|
||||
python models/train_homography.py \
|
||||
--data_dir "C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images" \
|
||||
--epochs 50 \
|
||||
--batch_size 16 \
|
||||
--lr 1e-3 \
|
||||
--output_dir "runs/my_experiment"
|
||||
```
|
||||
|
||||
### 3. Perform inference
|
||||
```bash
|
||||
python models/infer_homography.py \
|
||||
--model_path "runs/my_experiment/checkpoint_best.pth" \
|
||||
--mode single \
|
||||
--google_path "path/to/google.png" \
|
||||
--yandex_path "path/to/yandex.png" \
|
||||
--output_vis "alignment_result.png"
|
||||
```
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
models/
|
||||
├── homography.py # Dataset class and data loaders
|
||||
├── homography_cnn.py # CNN model architecture
|
||||
├── train_homography.py # Training script
|
||||
├── infer_homography.py # Inference script
|
||||
├── example_homography.py # Example usage
|
||||
├── homography.ipynb # Jupyter notebook
|
||||
└── README_homography.md # This file
|
||||
```
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The homography estimation model (`HomographyCNN`) consists of:
|
||||
|
||||
1. **Dual encoders** - Separate feature extraction for Google and Yandex images
|
||||
2. **Residual blocks** - For deep feature learning
|
||||
3. **Fusion layers** - Combine features from both images
|
||||
4. **Regression head** - Predict 3x3 homography matrix
|
||||
|
||||
### Key Features:
|
||||
- Residual connections for stable training
|
||||
- Batch normalization options
|
||||
- Dropout for regularization
|
||||
- Geometric consistency loss
|
||||
|
||||
## Training Configuration
|
||||
|
||||
Default training parameters:
|
||||
- **Optimizer**: Adam with learning rate 1e-3
|
||||
- **Loss function**: Combined matrix + geometric + regularization loss
|
||||
- **Batch size**: 32
|
||||
- **Image size**: 256x256
|
||||
- **Train/val split**: 80/20
|
||||
|
||||
## Inference Modes
|
||||
|
||||
The inference script supports three modes:
|
||||
|
||||
### 1. Single image pair
|
||||
```bash
|
||||
python infer_homography.py --mode single \
|
||||
--google_path google.png --yandex_path yandex.png
|
||||
```
|
||||
|
||||
### 2. Dataset evaluation
|
||||
```bash
|
||||
python infer_homography.py --mode dataset \
|
||||
--dataset_dir path/to/dataset --num_samples 100
|
||||
```
|
||||
|
||||
### 3. Batch processing
|
||||
```bash
|
||||
python infer_homography.py --mode batch \
|
||||
--input_dir path/to/input --output_dir path/to/output
|
||||
```
|
||||
|
||||
## Evaluation Metrics
|
||||
|
||||
The system computes several metrics:
|
||||
- **Matrix MSE**: Mean squared error of homography matrix elements
|
||||
- **Corner error**: Average pixel error at image corners
|
||||
- **Geometric consistency**: Warping error across grid points
|
||||
|
||||
## Data Augmentation
|
||||
|
||||
The dataset applies homography-based augmentation:
|
||||
- Random rotation (-30° to 30°)
|
||||
- Random scaling (0.8x to 1.2x)
|
||||
- Random translation (-50 to 50 pixels)
|
||||
- Small perspective distortion
|
||||
|
||||
## Integration with Autopilot System
|
||||
|
||||
The homography estimation can be integrated into the autopilot system:
|
||||
|
||||
```python
|
||||
from models.infer_homography import HomographyInference
|
||||
|
||||
# Initialize inference
|
||||
inference = HomographyInference(model_path="path/to/model.pth")
|
||||
|
||||
# During flight loop
|
||||
homography = inference.predict(google_img, yandex_img)
|
||||
# Use homography to update drone position
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Out of memory**
|
||||
- Reduce batch size
|
||||
- Use smaller image size
|
||||
- Enable gradient checkpointing
|
||||
|
||||
2. **Poor convergence**
|
||||
- Adjust learning rate
|
||||
- Increase model capacity
|
||||
- Add more data augmentation
|
||||
|
||||
3. **Inference errors**
|
||||
- Check image formats (must be RGB)
|
||||
- Verify model was trained with same image size
|
||||
- Ensure proper normalization
|
||||
|
||||
### Debugging Tips
|
||||
|
||||
- Use `example_homography.py` to test components
|
||||
- Enable TensorBoard for training visualization
|
||||
- Check homography matrix normalization (last element should be ~1)
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Model Architecture
|
||||
```python
|
||||
from models.homography_cnn import create_homography_model
|
||||
|
||||
model = create_homography_model(
|
||||
model_type="cnn",
|
||||
input_size=(512, 512),
|
||||
hidden_channels=128,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.5
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Loss Function
|
||||
```python
|
||||
from models.homography_cnn import HomographyLoss
|
||||
|
||||
loss_fn = HomographyLoss(
|
||||
matrix_weight=0.7,
|
||||
geometric_weight=0.3,
|
||||
reg_weight=0.05,
|
||||
grid_size=16
|
||||
)
|
||||
```
|
||||
|
||||
### Transfer Learning
|
||||
```python
|
||||
# Load pretrained model
|
||||
checkpoint = torch.load("pretrained.pth")
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
# Fine-tune on new data
|
||||
for param in model.google_encoder.parameters():
|
||||
param.requires_grad = False # Freeze encoder
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Training speed**
|
||||
- Use multiple DataLoader worker processes (`num_workers`) for data loading
|
||||
- Enable mixed precision training
|
||||
- Use gradient accumulation for larger effective batch sizes
|
||||
|
||||
2. **Memory efficiency**
|
||||
- Use gradient checkpointing
|
||||
- Implement progressive resizing
|
||||
- Use memory-efficient optimizers
|
||||
|
||||
3. **Inference speed**
|
||||
- Use TensorRT or ONNX for deployment
|
||||
- Implement model quantization
|
||||
- Use batch inference when possible
|
||||
|
||||
## Future Improvements
|
||||
|
||||
1. **Model enhancements**
|
||||
- Transformer-based architecture
|
||||
- Multi-scale feature fusion
|
||||
- Uncertainty estimation
|
||||
|
||||
2. **Training improvements**
|
||||
- Self-supervised pre-training
|
||||
- Curriculum learning
|
||||
- Adversarial training
|
||||
|
||||
3. **Deployment features**
|
||||
- Real-time inference optimization
|
||||
- Mobile deployment
|
||||
- Web API
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this system in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@software{homography_estimation_2024,
|
||||
title = {Homography Estimation for Map Alignment},
|
||||
author = {Autopilot Team},
|
||||
year = {2024},
|
||||
url = {https://github.com/your-repo/homography-estimation}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the LICENSE file for details.
|
||||
|
||||
## Support
|
||||
|
||||
For questions and issues:
|
||||
1. Check the troubleshooting section
|
||||
2. Review the example scripts
|
||||
3. Open an issue on GitHub
|
||||
4. Contact the development team
|
||||
|
||||
---
|
||||
|
||||
*Last updated: October 2024*
|
||||
345
models/SiaN/example_homography.py
Normal file
345
models/SiaN/example_homography.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Example script demonstrating the complete homography estimation workflow.
|
||||
|
||||
This script shows how to:
|
||||
1. Load the ya_go_maps dataset
|
||||
2. Create and train a homography estimation model
|
||||
3. Perform inference on new image pairs
|
||||
4. Visualize results
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from models.homography import HomographyDataset, create_data_loaders
|
||||
from models.homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from models.infer_homography import HomographyInference
|
||||
from models.train_homography import HomographyTrainer
|
||||
|
||||
|
||||
def example_dataset_loading(
    dataset_path=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
):
    """Example 1: Loading and exploring the dataset.

    Builds an augmenting ``HomographyDataset`` at 256x256, prints the size and
    the shapes of the first sample, and shows the first image pair (loaded
    without augmentation) side by side with matplotlib.

    Args:
        dataset_path: Directory holding the image pairs. Defaults to the
            location that used to be hard-coded in this function.

    Returns:
        The constructed ``HomographyDataset``.
    """
    print("=" * 60)
    print("Example 1: Loading and exploring the dataset")
    print("=" * 60)

    # Create dataset
    dataset = HomographyDataset(
        root_dir=dataset_path,
        augment=True,
        image_size=(256, 256),
        cache_homographies=True,
    )

    print(f"Dataset size: {len(dataset)} image pairs")

    # Get a sample (goes through the augmenting __getitem__ path)
    sample = dataset[0]
    print(f"\nSample keys: {list(sample.keys())}")
    print(f"Google image shape: {sample['google_img'].shape}")
    print(f"Yandex image shape: {sample['yandex_img'].shape}")
    print(f"Homography shape: {sample['homography'].shape}")

    # Show sample homography matrix
    print("\nSample homography matrix:")
    print(sample["homography"].numpy())

    # Get sample without augmentation for visualization
    raw_sample = dataset.get_sample_without_augmentation(0)

    # Visualize sample: one panel per provider, axes hidden
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (("google_img", "Google Map"), ("yandex_img", "Yandex Map"))
    for axis, (key, title) in zip(axes, panels):
        axis.imshow(raw_sample[key])
        axis.set_title(title)
        axis.axis("off")

    plt.suptitle("Sample Image Pair (without augmentation)")
    plt.tight_layout()
    plt.show()

    return dataset
|
||||
|
||||
|
||||
def example_model_creation():
    """Example 2: Creating and testing the model.

    Instantiates a ``HomographyCNN``, runs a random two-sample batch through
    it, then evaluates ``HomographyLoss`` and its metrics against identity
    matrices as targets.

    Returns:
        Tuple ``(model, loss_fn)``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Example 2: Creating and testing the model")
    print(rule)

    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Build the network with the example hyper-parameters.
    model = HomographyCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model created with {param_count:,} parameters")

    # Random stand-ins for a batch of image pairs.
    batch_size = 2
    height, width = 256, 256
    dummy_shape = (batch_size, 3, height, width)
    google_img = torch.randn(*dummy_shape).to(device)
    yandex_img = torch.randn(*dummy_shape).to(device)

    # Forward pass, asking the model for matrices.
    print("\nTesting forward pass...")
    output = model(google_img, yandex_img, return_matrix=True)
    print(f"Output shape: {output.shape}")  # Should be (2, 3, 3)
    print("Sample output matrix:")
    first_matrix = output[0].cpu().detach().numpy()
    print(first_matrix)

    # Loss against identity targets, one 3x3 identity per batch element.
    print("\nTesting loss function...")
    identity = torch.eye(3).unsqueeze(0)
    target_homography = identity.expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss(
        matrix_weight=1.0,
        geometric_weight=0.5,
        reg_weight=0.1,
        grid_size=8,
    ).to(device)

    loss = loss_fn(output, target_homography, google_img, yandex_img)
    print(f"Loss value: {loss.item():.6f}")

    # Metrics reported by the loss object.
    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target_homography)
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value:.6f}")

    return model, loss_fn
|
||||
|
||||
|
||||
def example_data_loaders(
    dataset_path=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
):
    """Example 3: Creating data loaders for training.

    Builds train/validation loaders with an 80/20 split, prints how many
    batches each holds, and inspects the first training batch.

    Args:
        dataset_path: Directory holding the image pairs. Defaults to the
            location that used to be hard-coded in this function.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    print("\n" + "=" * 60)
    print("Example 3: Creating data loaders for training")
    print("=" * 60)

    # Create data loaders
    train_loader, val_loader = create_data_loaders(
        root_dir=dataset_path,
        batch_size=16,
        train_split=0.8,
        num_workers=0,  # Use 0 for debugging, increase for training
        image_size=(256, 256),
        augment_train=True,
        augment_val=False,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Get a batch from train loader
    batch = next(iter(train_loader))
    google_batch = batch["google_img"]
    print(f"\nBatch keys: {list(batch.keys())}")
    print(f"Batch size: {google_batch.shape[0]}")
    print(f"Image shape: {google_batch.shape[1:]}")

    return train_loader, val_loader
|
||||
|
||||
|
||||
def example_training_config():
    """Example 4: Setting up training configuration.

    Assembles the example settings section by section (model, training, loss,
    data, output), prints every entry, and returns them as one flat dict whose
    key order follows that section order.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Example 4: Training configuration")
    print(rule)

    # Model config
    model_cfg = {
        "model_type": "cnn",
        "hidden_channels": 64,
        "num_blocks": 4,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
        "image_size": [256, 256],
    }
    # Training config
    training_cfg = {
        "epochs": 50,
        "batch_size": 16,
        "learning_rate": 1e-3,
        "weight_decay": 1e-4,
        "optimizer": "adam",
        "scheduler": "plateau",
        "grad_clip": 1.0,
    }
    # Loss config
    loss_cfg = {
        "matrix_weight": 1.0,
        "geometric_weight": 0.5,
        "reg_weight": 0.1,
        "grid_size": 8,
    }
    # Data config
    data_cfg = {
        "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        "train_split": 0.8,
        "num_workers": 0,
    }
    # Output config
    output_cfg = {
        "output_dir": "runs/example_training",
        "seed": 42,
    }

    # Later sections never repeat a key, so merging keeps every entry.
    config = {**model_cfg, **training_cfg, **loss_cfg, **data_cfg, **output_cfg}

    print("Training configuration:")
    for name in config:
        print(f" {name}: {config[name]}")

    return config
|
||||
|
||||
|
||||
def example_inference():
    """Example 5: Performing inference with a trained model.

    Only prints the outline of the inference workflow; nothing is loaded or
    predicted here because no trained checkpoint is assumed to exist.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Example 5: Inference example")
    print(rule)

    # Note: This example assumes you have a trained model.
    # For demonstration, only the code structure is shown.
    workflow = (
        "Inference workflow:",
        "1. Load trained model",
        "2. Preprocess input images",
        "3. Predict homography matrix",
        "4. Visualize alignment",
    )
    for line in workflow:
        print(line)

    # Example code structure (not executed since there is no trained model yet):
    #
    #     # Create inference object
    #     inference = HomographyInference(
    #         model_path="runs/homography/checkpoint_best.pth",
    #         device="cuda" if torch.cuda.is_available() else "cpu",
    #     )
    #
    #     # Load images
    #     google_img = Image.open("path/to/google.png")
    #     yandex_img = Image.open("path/to/yandex.png")
    #
    #     # Predict homography
    #     homography = inference.predict(google_img, yandex_img)
    #
    #     print(f"Predicted homography matrix:")
    #     print(homography.cpu().numpy())
    #
    #     # Visualize alignment
    #     inference.visualize_alignment(
    #         google_img,
    #         yandex_img,
    #         homography.cpu().numpy(),
    #         save_path="alignment_visualization.png",
    #         show=True,
    #     )

    note = "\nNote: To run actual inference, first train a model using train_homography.py"
    print(note)
|
||||
|
||||
|
||||
def example_quick_start():
    """Example 6: Quick start guide.

    Prints four copy-and-paste command recipes: exploring the dataset,
    training, single-pair inference, and whole-dataset evaluation.
    """
    rule = "=" * 60

    # Every entry is emitted with its own print() call, in this order.
    script = (
        "\n" + rule,
        "Example 6: Quick Start Guide",
        rule,
        "\nTo get started with homography estimation:",
        "\n1. First, explore the dataset:",
        (
            ' python -c "from models.homography import HomographyDataset; '
            "dataset = HomographyDataset(r'C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images'); "
            "print(f'Found {len(dataset)} image pairs')\""
        ),
        "\n2. Train a model:",
        " python models/train_homography.py \\",
        ' --data_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\',
        " --epochs 50 \\",
        " --batch_size 16 \\",
        ' --output_dir "runs/my_experiment"',
        "\n3. Perform inference on a single image pair:",
        " python models/infer_homography.py \\",
        ' --model_path "runs/my_experiment/checkpoint_best.pth" \\',
        " --mode single \\",
        ' --google_path "path/to/google.png" \\',
        ' --yandex_path "path/to/yandex.png" \\',
        ' --output_vis "alignment_result.png"',
        "\n4. Evaluate on the entire dataset:",
        " python models/infer_homography.py \\",
        ' --model_path "runs/my_experiment/checkpoint_best.pth" \\',
        " --mode dataset \\",
        ' --dataset_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\',
        ' --save_results "evaluation_results.json"',
    )
    for line in script:
        print(line)
|
||||
|
||||
|
||||
def main():
    """Run all examples.

    The six examples execute in order. Any exception raised by one of them
    stops the remaining examples; the error and its traceback are printed
    instead of being propagated to the caller.
    """
    print("Homography Estimation Workflow Examples")
    print("=" * 60)

    # Ordered list of examples; their return values are not needed here.
    examples = (
        example_dataset_loading,  # Example 1: Dataset loading
        example_model_creation,  # Example 2: Model creation
        example_data_loaders,  # Example 3: Data loaders
        example_training_config,  # Example 4: Training configuration
        example_inference,  # Example 5: Inference
        example_quick_start,  # Example 6: Quick start
    )

    try:
        for run_example in examples:
            run_example()

        print("\n" + "=" * 60)
        print("All examples completed successfully!")
        print("=" * 60)

        print("\nNext steps:")
        print("1. Run the training script to train a model")
        print("2. Use the inference script to test the trained model")
        print("3. Integrate the homography estimation into your autopilot system")

    except Exception as e:
        # Script-level boundary: report the failure with its traceback.
        print(f"\nError running examples: {e}")
        import traceback

        traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1679
models/SiaN/homography.ipynb
Normal file
1679
models/SiaN/homography.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
434
models/SiaN/homography.py
Normal file
434
models/SiaN/homography.py
Normal file
@@ -0,0 +1,434 @@
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
class HomographyDataset(Dataset):
    """
    Dataset of paired Google/Yandex map images with a 3x3 matrix label.

    Pairs are discovered by file name (``NNNN_google.png`` next to
    ``NNNN_yandex.png``). Each pair is associated with a randomly generated
    matrix that can be cached on disk so it stays the same between runs.

    NOTE(review): the matrix attached to a pair is random and is not derived
    from, or applied to, the stored images unless ``augment`` is enabled -
    confirm this is the intended training target.
    """

    def __init__(
        self,
        root_dir: str,
        transform=None,
        augment: bool = True,
        max_samples: Optional[int] = None,
        image_size: Tuple[int, int] = (700, 700),
        cache_homographies: bool = True,
    ):
        """
        Initialize the HomographyDataset.

        Args:
            root_dir: Directory containing image pairs
                (format: {idx:04d}_google.png, {idx:04d}_yandex.png)
            transform: Optional callable applied to each PIL image; when
                omitted, images become float CHW tensors scaled to [0, 1]
            augment: Whether to apply homography-based data augmentation
            max_samples: Maximum number of pairs to keep (None for all)
            image_size: Target size for images (height, width)
            cache_homographies: Whether to store generated matrices under
                ``<root_dir>/homography_cache`` and reuse them later
        """
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.cache_homographies = cache_homographies
        # Always defined so the attribute can be inspected with caching off.
        self.homography_cache_dir: Optional[str] = None

        # Find all image pairs, optionally keeping only the first few.
        self.image_pairs = self._discover_image_pairs()
        if max_samples is not None:
            self.image_pairs = self.image_pairs[:max_samples]

        print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")

        # Create directory for cached homographies if needed
        if cache_homographies:
            self.homography_cache_dir = os.path.join(root_dir, "homography_cache")
            os.makedirs(self.homography_cache_dir, exist_ok=True)

    def _discover_image_pairs(self) -> List[Dict[str, Any]]:
        """Discover all Google-Yandex image pairs in the dataset directory.

        Returns:
            One dict per pair with keys ``idx``, ``google_path`` and
            ``yandex_path``, ordered by sorted Google file name. Google files
            whose prefix is not an integer, or that have no Yandex
            counterpart, are skipped.
        """
        google_files = sorted(
            name for name in os.listdir(self.root_dir) if name.endswith("_google.png")
        )

        image_pairs = []
        for google_file in google_files:
            # The numeric prefix before the first underscore is the pair index.
            try:
                idx = int(google_file.split("_")[0])
            except ValueError:
                continue

            # The Yandex name is rebuilt from the parsed index (4-digit padding).
            yandex_path = os.path.join(self.root_dir, f"{idx:04d}_yandex.png")
            if not os.path.exists(yandex_path):
                continue

            image_pairs.append(
                {
                    "idx": idx,
                    "google_path": os.path.join(self.root_dir, google_file),
                    "yandex_path": yandex_path,
                }
            )

        return image_pairs

    def __len__(self) -> int:
        """Return the number of image pairs in the dataset."""
        return len(self.image_pairs)

    def _load_resized_pair(
        self, pair_info: Dict[str, Any]
    ) -> Tuple["Image.Image", "Image.Image"]:
        """Load both images of a pair as RGB, resized to ``image_size``."""
        # PIL's resize expects (width, height); image_size is (height, width).
        target = (self.image_size[1], self.image_size[0])

        resized = []
        for key in ("google_path", "yandex_path"):
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(pair_info[key]) as raw:
                resized.append(raw.convert("RGB").resize(target, Image.BILINEAR))

        return resized[0], resized[1]

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a sample from the dataset.

        Args:
            idx: Position within the discovered pairs.

        Returns a dictionary with:
            - 'google_img': Google map image tensor
            - 'yandex_img': Yandex map image tensor
            - 'homography': 3x3 float32 matrix for the pair (combined with
              the augmentation matrix when augmentation is enabled)
            - 'idx': Numeric index parsed from the file name (long tensor)
        """
        pair_info = self.image_pairs[idx]

        google_img, yandex_img = self._load_resized_pair(pair_info)

        # Get or generate homography matrix
        homography_matrix = self._get_homography_matrix(pair_info["idx"])

        # Apply data augmentation if enabled
        if self.augment:
            google_img, yandex_img, homography_matrix = self._apply_augmentation(
                google_img, yandex_img, homography_matrix
            )

        # Convert images to tensors
        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)
        else:
            # Default conversion: HWC uint8 -> CHW float scaled to [0, 1]
            google_img = (
                torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
            )
            yandex_img = (
                torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
            )

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": torch.from_numpy(homography_matrix).float(),
            "idx": torch.tensor(pair_info["idx"], dtype=torch.long),
        }

    def _get_homography_matrix(self, idx: int) -> np.ndarray:
        """
        Get homography matrix for a given pair index.

        With caching enabled, ``<cache dir>/{idx:04d}_homography.npy`` is
        loaded when present; otherwise a new matrix is generated and saved
        there. With caching disabled a fresh random matrix is returned on
        every call.
        """
        if not self.cache_homographies:
            return self.generate_random_homography()

        cache_path = os.path.join(
            self.homography_cache_dir, f"{idx:04d}_homography.npy"
        )
        if os.path.exists(cache_path):
            return np.load(cache_path)

        homography_matrix = self.generate_random_homography()
        np.save(cache_path, homography_matrix)
        return homography_matrix

    def generate_random_homography(self) -> np.ndarray:
        """
        Generate a random 3x3 transformation matrix.

        A similarity transform (rotation in [-30, 30] degrees, scale in
        [0.8, 1.2], translation in [-50, 50] on each axis) is built first and
        every entry of its top two rows is then shifted by a uniform value in
        [-0.001, 0.001].

        Returns:
            np.ndarray: 3x3 matrix whose last row is exactly [0, 0, 1].
        """
        # Random similarity parameters (drawn in this fixed order)
        angle = np.random.uniform(-30, 30)  # rotation in degrees
        scale = np.random.uniform(0.8, 1.2)  # scaling factor
        tx = np.random.uniform(-50, 50)  # translation in x
        ty = np.random.uniform(-50, 50)  # translation in y

        theta = np.radians(angle)
        cos_t, sin_t = np.cos(theta), np.sin(theta)

        affine_matrix = np.array(
            [
                [scale * cos_t, -scale * sin_t, tx],
                [scale * sin_t, scale * cos_t, ty],
                [0, 0, 1],
            ]
        )

        # NOTE(review): this was labelled "perspective distortion", but the
        # noise lands on the two affine rows and the projective (last) row is
        # left untouched, so the result is still an affine map - confirm
        # whether true perspective terms were intended.
        jitter = np.random.uniform(-0.001, 0.001, (2, 3))
        jitter = np.vstack([jitter, [0, 0, 0]])

        return affine_matrix + jitter

    def _apply_augmentation(
        self,
        google_img: "Image.Image",
        yandex_img: "Image.Image",
        base_homography: np.ndarray,
    ) -> Tuple["Image.Image", "Image.Image", np.ndarray]:
        """
        Apply homography-based data augmentation to image pair.

        Both images are warped with the same freshly generated matrix.

        Args:
            google_img: Google map image
            yandex_img: Yandex map image
            base_homography: Base homography matrix

        Returns:
            Tuple of (augmented_google_img, augmented_yandex_img,
            augmented_homography), where the returned matrix is
            ``aug @ base_homography``.
        """
        aug_homography = self.generate_random_homography()

        # NOTE(review): when both images are warped by the same matrix A, a
        # Google->Yandex mapping H would become A @ H @ inv(A), not A @ H -
        # confirm which convention the training target is meant to follow.
        combined_homography = aug_homography @ base_homography

        google_aug = self._apply_homography_to_image(google_img, aug_homography)
        yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography)

        return google_aug, yandex_aug, combined_homography

    def _apply_homography_to_image(
        self, img: "Image.Image", homography: np.ndarray
    ) -> "Image.Image":
        """
        Apply homography transformation to a single image.

        Args:
            img: PIL Image to transform
            homography: 3x3 homography matrix

        Returns:
            Transformed PIL Image with the same width and height as ``img``
        """
        img_np = np.array(img)
        h, w = img_np.shape[:2]

        # Output keeps the input size; pixels sampled from outside the source
        # are filled by reflecting the image at its border.
        transformed = cv2.warpPerspective(
            img_np,
            homography,
            (w, h),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT,
        )

        return Image.fromarray(transformed)

    def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
        """
        Get a sample without data augmentation.

        Useful for visualization and evaluation. Images are returned as
        resized PIL images (no tensor conversion, no ``transform``), the
        matrix as a NumPy array, together with the pair index and both paths.
        """
        pair_info = self.image_pairs[idx]

        google_img, yandex_img = self._load_resized_pair(pair_info)

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": self._get_homography_matrix(pair_info["idx"]),
            "idx": pair_info["idx"],
            "google_path": pair_info["google_path"],
            "yandex_path": pair_info["yandex_path"],
        }
|
||||
|
||||
|
||||
def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 4,
    image_size: Tuple[int, int] = (256, 256),
    augment_train: bool = True,
    augment_val: bool = False,
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders for homography estimation.

    Augmentation is delegated to ``HomographyDataset`` through its
    ``augment`` flag, so that the images and the homography label of a
    sample are transformed together. (The previous implementation only
    multiplied a random homography into the label and left the images
    untouched, which produced targets that did not match the inputs, and it
    ignored ``augment_val`` entirely.)
    NOTE(review): this assumes ``HomographyDataset(augment=True)`` warps both
    images and updates the label consistently -- confirm against the dataset
    class.

    Args:
        root_dir: Directory containing image pairs
        batch_size: Batch size for data loaders
        train_split: Fraction of data to use for training
        num_workers: Number of worker processes for data loading
        image_size: Target image size (height, width)
        augment_train: Whether to augment training data
        augment_val: Whether to augment validation data

    Returns:
        Tuple of (train_loader, val_loader)

    Raises:
        ValueError: If no image pairs are found under ``root_dir``.
        RuntimeError: If the train and validation dataset views do not list
            the same pairs in the same order.
    """
    from torch.utils.data import Subset
    from torchvision import transforms

    # ImageNet mean/std normalisation applied after conversion to a tensor.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def build_dataset(augment: bool):
        return HomographyDataset(
            root_dir=root_dir,
            transform=transform,
            augment=augment,
            image_size=image_size,
            cache_homographies=True,
        )

    # One dataset object per augmentation setting; both are indexed with the
    # same shuffled split so train and validation never share a pair.
    train_source = build_dataset(augment_train)
    if augment_val == augment_train:
        val_source = train_source
    else:
        val_source = build_dataset(augment_val)

    dataset_size = len(train_source)
    if dataset_size == 0:
        raise ValueError(f"No image pairs found in {root_dir}")

    if val_source is not train_source:
        train_ids = [pair["idx"] for pair in train_source.image_pairs]
        val_ids = [pair["idx"] for pair in val_source.image_pairs]
        if train_ids != val_ids:
            raise RuntimeError(
                "Train and validation datasets enumerate different image pairs; "
                "cannot build a disjoint split"
            )

    # Random split (uses the global `random` state, as before).
    train_size = int(train_split * dataset_size)
    indices = list(range(dataset_size))
    random.shuffle(indices)

    train_dataset = Subset(train_source, indices[:train_size])
    val_dataset = Subset(val_source, indices[train_size:])

    # Pinned memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just triggers a warning.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build the dataset, inspect one sample, then build
    # the train/validation loaders from the same directory.
    data_root = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images"

    dataset = HomographyDataset(root_dir=data_root, augment=True, image_size=(256, 256))
    print(f"Dataset size: {len(dataset)}")

    sample = dataset[0]
    print(f"Sample keys: {list(sample.keys())}")
    print(f"Google image shape: {sample['google_img'].shape}")
    print(f"Yandex image shape: {sample['yandex_img'].shape}")
    print(f"Homography shape: {sample['homography'].shape}")

    train_loader, val_loader = create_data_loaders(
        root_dir=data_root, batch_size=16, train_split=0.8
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
|
||||
551
models/SiaN/homography_cnn.py
Normal file
551
models/SiaN/homography_cnn.py
Normal file
@@ -0,0 +1,551 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class HomographyCNN(nn.Module):
    """
    CNN model for homography estimation between two images.

    Each image (Google and Yandex map) goes through its own convolutional
    encoder. The two feature maps are concatenated along the channel axis,
    fused by two convolutions, globally average-pooled and regressed to
    ``output_size`` values, which ``forward`` can reshape to a 3x3 matrix
    normalised by its bottom-right element.
    """

    def __init__(
        self,
        input_channels: int = 3,
        hidden_channels: int = 64,
        num_blocks: int = 4,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
        output_size: int = 9,  # Flattened 3x3 homography matrix
    ):
        """
        Initialize the HomographyCNN model.

        Args:
            input_channels: Number of input channels per image (3 for RGB)
            hidden_channels: Base number of channels in the network
            num_blocks: Number of residual blocks in each encoder
            dropout_rate: Dropout rate for regularization
            use_batch_norm: Whether to use batch normalization
            output_size: Size of output vector (9 for flattened 3x3 matrix)
        """
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # Feature extraction for each image separately
        self.google_encoder = self._build_encoder()
        self.yandex_encoder = self._build_encoder()

        # Fusion layers to combine features from both images
        self.fusion_layers = self._build_fusion_layers()

        # Regression head for homography estimation
        self.regression_head = self._build_regression_head(output_size)

        self._initialize_weights()

    def _encoder_out_channels(self) -> int:
        """Return the channel count of the feature map from ``_build_encoder``."""
        # Only the first two residual blocks double the width, so the
        # encoder ends at hidden * 2**min(num_blocks, 2).
        doublings = min(max(self.num_blocks, 0), 2)
        return self.hidden_channels * 2**doublings

    def _build_encoder(self) -> nn.Module:
        """Build the encoder network for a single image."""
        layers = []
        out_channels = self.hidden_channels

        # Stem: 7x7 stride-2 convolution followed by a stride-2 max pool.
        layers.append(
            nn.Conv2d(
                self.input_channels, out_channels, kernel_size=7, stride=2, padding=3
            )
        )
        if self.use_batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Residual blocks: the first keeps resolution, later ones halve it;
        # the first two double the channel count, the rest keep it.
        for i in range(self.num_blocks):
            block_in_channels = out_channels
            block_out_channels = out_channels * 2 if i < 2 else out_channels

            layers.append(
                ResidualBlock(
                    in_channels=block_in_channels,
                    out_channels=block_out_channels,
                    stride=1 if i == 0 else 2,
                    dropout_rate=self.dropout_rate,
                    use_batch_norm=self.use_batch_norm,
                )
            )

            if i < 2:
                out_channels = block_out_channels

        return nn.Sequential(*layers)

    def _build_fusion_layers(self) -> nn.Module:
        """Build layers to fuse features from both images."""
        # Both encoder outputs are concatenated channel-wise. Deriving the
        # width from the encoder (rather than assuming hidden * 8) keeps the
        # model usable when num_blocks < 2; for num_blocks >= 2 the value is
        # hidden * 8 as before, so existing checkpoints still fit.
        fused_channels = 2 * self._encoder_out_channels()

        layers = [
            # Reduce dimensionality
            nn.Conv2d(
                fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
            ),
            nn.BatchNorm2d(self.hidden_channels * 4)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Further processing
            nn.Conv2d(
                self.hidden_channels * 4,
                self.hidden_channels * 2,
                kernel_size=3,
                padding=1,
            ),
            nn.BatchNorm2d(self.hidden_channels * 2)
            if self.use_batch_norm
            else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),
            # Global pooling
            nn.AdaptiveAvgPool2d((1, 1)),
        ]

        return nn.Sequential(*layers)

    def _build_regression_head(self, output_size: int) -> nn.Module:
        """Build the regression head for homography estimation."""
        # Width of the pooled feature vector leaving the fusion layers.
        input_features = self.hidden_channels * 2

        layers = [
            nn.Flatten(),
            nn.Linear(input_features, 512),
            nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(self.dropout_rate),
            nn.Linear(128, output_size),
        ]

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialize model weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _normalize_by_last(matrix: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Divide each (3, 3) matrix in the batch by its bottom-right element."""
        scale = matrix[:, 2, 2].view(-1, 1, 1)
        # Keep the sign of the divisor and only bound its magnitude away
        # from zero. Adding eps instead would turn -eps into exactly 0.
        sign = torch.where(scale < 0, -torch.ones_like(scale), torch.ones_like(scale))
        return matrix / (sign * scale.abs().clamp_min(eps))

    def forward(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        return_matrix: bool = True,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            google_img: Google map image tensor of shape (B, C, H, W)
            yandex_img: Yandex map image tensor of shape (B, C, H, W)
            return_matrix: If True, return 3x3 matrix; if False, return flattened vector

        Returns:
            Homography matrix tensor of shape (B, 3, 3), normalised so its
            bottom-right element is 1, or the raw flattened vector of shape
            (B, output_size).
        """
        google_features = self.google_encoder(google_img)
        yandex_features = self.yandex_encoder(yandex_img)

        # Concatenate features along channel dimension
        combined_features = torch.cat([google_features, yandex_features], dim=1)

        fused_features = self.fusion_layers(combined_features)
        homography_flat = self.regression_head(fused_features)

        if not return_matrix:
            return homography_flat

        batch_size = homography_flat.shape[0]
        homography_matrix = homography_flat.view(batch_size, 3, 3)
        return self._normalize_by_last(homography_matrix)

    def predict_homography(
        self,
        google_img: torch.Tensor,
        yandex_img: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix with optional normalization.

        Note: this switches the module to evaluation mode (``self.eval()``)
        and leaves it there; call ``train()`` again before resuming training.

        Args:
            google_img: Google map image tensor
            yandex_img: Yandex map image tensor
            normalize: Whether to normalize the homography matrix

        Returns:
            Predicted homography matrix of shape (B, 3, 3)
        """
        self.eval()
        with torch.no_grad():
            homography = self.forward(google_img, yandex_img, return_matrix=True)

            if normalize:
                # forward() already normalises; this is kept for callers
                # that rely on the flag and is a no-op up to rounding.
                homography = self._normalize_by_last(homography)

        return homography
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """
    Two 3x3 convolutions with a skip connection.

    The first convolution applies ``stride``; when the stride is not 1 or
    the channel count changes, the skip path is a 1x1 strided convolution
    (plus optional batch norm) so both branches have matching shapes.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        def make_norm() -> nn.Module:
            return nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()

        def make_dropout() -> nn.Module:
            return nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()

        # Main branch, first half (carries the stride).
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = make_norm()
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = make_dropout()

        # Main branch, second half (shape preserving).
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = make_norm()
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = make_dropout()

        # Skip branch: an empty Sequential acts as the identity.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                make_norm(),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(relu(main(x) + shortcut(x)))``."""
        residual = self.shortcut(x)

        features = self.dropout1(self.relu1(self.bn1(self.conv1(x))))
        features = self.bn2(self.conv2(features))

        features += residual
        return self.dropout2(self.relu2(features))
|
||||
|
||||
|
||||
class HomographyLoss(nn.Module):
    """
    Custom loss function for homography estimation.

    Combines multiple loss terms:
    1. Matrix element-wise L2 loss
    2. Geometric consistency loss: distance between a fixed grid of points
       in [-1, 1] x [-1, 1] warped by the predicted and by the target matrix
    3. Determinant regularization (to prevent degenerate matrices)
    """

    def __init__(
        self,
        matrix_weight: float = 1.0,
        geometric_weight: float = 0.5,
        reg_weight: float = 0.1,
        grid_size: int = 8,
    ):
        """
        Args:
            matrix_weight: Weight of the element-wise matrix MSE term
            geometric_weight: Weight of the warped-grid MSE term
            reg_weight: Weight of the determinant regularizer
            grid_size: Number of grid points per axis for the geometric term
        """
        super().__init__()
        self.matrix_weight = matrix_weight
        self.geometric_weight = geometric_weight
        self.reg_weight = reg_weight
        self.grid_size = grid_size

        # Non-persistent buffer: follows .to(device) but is not saved in
        # the state dict.
        self.register_buffer(
            "grid_points",
            self._create_grid_points(grid_size),
            persistent=False,
        )

    def _create_grid_points(self, grid_size: int) -> torch.Tensor:
        """Create a (3, grid_size**2) grid of homogeneous points in [-1, 1]."""
        x = torch.linspace(-1, 1, grid_size)
        y = torch.linspace(-1, 1, grid_size)
        grid_y, grid_x = torch.meshgrid(y, x, indexing="ij")
        grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)
        # Add homogeneous coordinate
        ones = torch.ones(grid_points.shape[0], 1)
        grid_points = torch.cat([grid_points, ones], dim=1)
        return grid_points.T  # Shape: (3, grid_size*grid_size)

    @staticmethod
    def _safe_divide(
        numerator: torch.Tensor, denominator: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Divide while keeping the divisor's sign and bounding |divisor| >= eps."""
        # Adding eps to the divisor (the previous approach) maps -eps to an
        # exact zero; clamping the magnitude does not.
        sign = torch.where(
            denominator < 0,
            -torch.ones_like(denominator),
            torch.ones_like(denominator),
        )
        return numerator / (sign * denominator.abs().clamp_min(eps))

    def _warp(self, homography: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
        """Apply (B, 3, 3) homographies to (B, 3, N) points; return (B, 2, N)."""
        warped = torch.bmm(homography, points)
        return self._safe_divide(warped, warped[:, 2:3, :])[:, :2, :]

    def forward(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        google_img: Optional[torch.Tensor] = None,
        yandex_img: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute homography loss.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3)
            google_img: Google images (optional). Only their presence is
                checked: the geometric term is enabled when both images
                are given; the pixel data is not read.
            yandex_img: Yandex images (optional), see ``google_img``.

        Returns:
            Combined loss value (scalar tensor)
        """
        batch_size = pred_homography.shape[0]

        # 1. Matrix element-wise L2 loss
        matrix_loss = F.mse_loss(pred_homography, target_homography)

        # 2. Geometric consistency loss (only when both images are passed)
        geometric_loss = torch.tensor(0.0, device=pred_homography.device)
        if google_img is not None and yandex_img is not None:
            grid_points = (
                self.grid_points.to(pred_homography.dtype)
                .unsqueeze(0)
                .expand(batch_size, -1, -1)
            )
            geometric_loss = F.mse_loss(
                self._warp(pred_homography, grid_points),
                self._warp(target_homography, grid_points),
            )

        # 3. Regularization loss: encourage determinant to be close to 1
        pred_det = torch.det(pred_homography)
        reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))

        return (
            self.matrix_weight * matrix_loss
            + self.geometric_weight * geometric_loss
            + self.reg_weight * reg_loss
        )

    def compute_metrics(
        self,
        pred_homography: torch.Tensor,
        target_homography: torch.Tensor,
        image_size: Tuple[int, int] = (256, 256),
    ) -> dict:
        """
        Compute evaluation metrics for homography estimation.

        Args:
            pred_homography: Predicted homography matrices (B, 3, 3)
            target_homography: Target homography matrices (B, 3, 3)
            image_size: Image size (height, width) used to convert the
                corner error from normalized [-1, 1] units to pixels.
                Defaults to 256x256, the value previously hard-coded.

        Returns:
            Dictionary with ``matrix_mse``, ``corner_error`` (normalized
            units) and ``corner_error_px`` (pixels), each a Python float
            averaged over the batch.
        """
        with torch.no_grad():
            # Normalize both matrices by their bottom-right element.
            pred_norm = self._safe_divide(
                pred_homography, pred_homography[:, 2, 2].view(-1, 1, 1)
            )
            target_norm = self._safe_divide(
                target_homography, target_homography[:, 2, 2].view(-1, 1, 1)
            )

            # Matrix L2 error per sample
            matrix_error = F.mse_loss(pred_norm, target_norm, reduction="none").mean(
                dim=(1, 2)
            )

            # Corner error: warp the 4 corners of the normalized image square.
            corners = torch.tensor(
                [[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],
                dtype=pred_homography.dtype,
                device=pred_homography.device,
            ).T  # Shape: (3, 4)
            corners = corners.unsqueeze(0).expand(pred_homography.shape[0], -1, -1)

            offset = self._warp(pred_norm, corners) - self._warp(target_norm, corners)
            corner_error = torch.norm(offset, dim=1).mean(dim=1)

            # [-1, 1] spans the full image, so one normalized unit equals
            # half the width along x and half the height along y.
            height, width = image_size
            half_extent = torch.tensor(
                [width / 2.0, height / 2.0],
                dtype=offset.dtype,
                device=offset.device,
            ).view(1, 2, 1)
            corner_error_px = torch.norm(offset * half_extent, dim=1).mean(dim=1)

            return {
                "matrix_mse": matrix_error.mean().item(),
                "corner_error": corner_error.mean().item(),
                "corner_error_px": corner_error_px.mean().item(),
            }
|
||||
|
||||
|
||||
def create_homography_model(
    model_type: str = "cnn",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Factory function to create homography estimation model.

    Args:
        model_type: Type of model to create; only 'cnn' is recognised
        input_size: Input image size (height, width). Accepted for interface
            compatibility; it is not forwarded to the model.
        **kwargs: Additional arguments passed to ``HomographyCNN``

    Returns:
        Homography estimation model

    Raises:
        ValueError: If ``model_type`` is not 'cnn'.
    """
    if model_type != "cnn":
        raise ValueError(f"Unknown model type: {model_type}")
    return HomographyCNN(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test on random tensors: forward pass, prediction helper, loss,
    # metrics and the model factory.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = HomographyCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    print(
        f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    # Dummy inputs (created on the CPU, then moved to the target device).
    batch_size = 4
    height, width = 256, 256
    google_img = torch.randn(batch_size, 3, height, width).to(device)
    yandex_img = torch.randn(batch_size, 3, height, width).to(device)

    print("\nTesting forward pass...")
    output = model(google_img, yandex_img, return_matrix=True)
    print(f"Output shape: {output.shape}")  # Should be (4, 3, 3)
    print(f"Sample output:\n{output[0]}")

    print("\nTesting prediction...")
    pred = model.predict_homography(google_img, yandex_img)
    print(f"Prediction shape: {pred.shape}")
    print(f"Last element (should be ~1): {pred[0, 2, 2]:.6f}")

    print("\nTesting loss function...")
    target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)
    loss_fn = HomographyLoss(
        matrix_weight=1.0, geometric_weight=0.5, reg_weight=0.1, grid_size=8
    ).to(device)
    loss = loss_fn(output, target_homography, google_img, yandex_img)
    print(f"Loss value: {loss.item():.6f}")

    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target_homography)
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")

    print("\nTesting model factory...")
    model2 = create_homography_model(
        model_type="cnn",
        input_size=(256, 256),
        input_channels=3,
        hidden_channels=32,
        num_blocks=3,
    ).to(device)
    print(
        f"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters"
    )

    print("\nAll tests completed successfully!")
|
||||
553
models/SiaN/infer_homography.py
Normal file
553
models/SiaN/infer_homography.py
Normal file
@@ -0,0 +1,553 @@
|
||||
"""
|
||||
Inference script for homography estimation between Google and Yandex map images.
|
||||
|
||||
This script loads a trained homography estimation model and performs inference
|
||||
on new image pairs or the test dataset.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from homography import HomographyDataset
|
||||
from homography_cnn import HomographyCNN, create_homography_model
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class HomographyInference:
|
||||
"""Class for performing inference with homography estimation model."""
|
||||
|
||||
def __init__(
    self,
    model_path: str,
    config_path: Optional[str] = None,
    device: Optional[str] = None,
):
    """
    Initialize the inference class.

    Builds the model from a JSON configuration (or built-in defaults),
    loads the checkpoint weights, creates the preprocessing transform and
    puts the model in evaluation mode.

    Args:
        model_path: Path to trained model checkpoint
        config_path: Path to model configuration file (optional). When
            omitted, ``config.json`` next to the checkpoint is tried.
        device: Device to run inference on ('cuda' or 'cpu'); chosen
            automatically when omitted.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device)
    print(f"Using device: {self.device}")

    # Remember whether the caller asked for a specific file so a missing
    # one can be reported instead of being silently replaced by defaults.
    explicit_config = config_path is not None
    if config_path is None:
        config_path = Path(model_path).parent / "config.json"

    if os.path.exists(config_path):
        # JSON is UTF-8 by specification; do not depend on the locale.
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)
        print(f"Loaded configuration from {config_path}")
    else:
        if explicit_config:
            print(f"Warning: configuration file not found: {config_path}")
        self.config = {
            "image_size": [256, 256],
            "hidden_channels": 64,
            "num_blocks": 4,
            "dropout_rate": 0.3,
            "use_batch_norm": True,
        }
        print("Using default configuration")

    # Create model and load its weights
    self.model = self._create_model()
    self._load_model(model_path)

    # Set up transforms
    self.transform = self._create_transforms()

    # Inference only: disable dropout / use running batch-norm statistics.
    self.model.eval()
|
||||
|
||||
def _create_model(self) -> HomographyCNN:
    """Instantiate the CNN from ``self.config``, with defaults for missing keys."""
    cfg = self.config
    return create_homography_model(
        model_type="cnn",
        input_size=tuple(cfg.get("image_size", [256, 256])),
        input_channels=3,
        hidden_channels=cfg.get("hidden_channels", 64),
        num_blocks=cfg.get("num_blocks", 4),
        dropout_rate=cfg.get("dropout_rate", 0.3),
        use_batch_norm=cfg.get("use_batch_norm", True),
    )
|
||||
|
||||
def _load_model(self, model_path: str):
|
||||
"""Load model weights from checkpoint."""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
# Load checkpoint
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
|
||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||
# Trainer checkpoint format
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
else:
|
||||
# Raw model weights format
|
||||
self.model.load_state_dict(checkpoint)
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
print(f"Loaded model from {model_path}")
|
||||
|
||||
def _create_transforms(self):
    """Build the inference pipeline: resize, tensor conversion, normalisation."""
    target_size = tuple(self.config.get("image_size", [256, 256]))
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        # Same mean/std constants as used by the training data loaders.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
|
||||
|
||||
def preprocess_images(
    self, google_img: Image.Image, yandex_img: Image.Image
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Turn a pair of PIL images into batched tensors for the model.

    Args:
        google_img: Google map image (PIL Image)
        yandex_img: Yandex map image (PIL Image)

    Returns:
        Tuple ``(google_tensor, yandex_tensor)``; each is the output of
        ``self.transform`` with a leading batch dimension of 1 added.
    """

    def to_batch(picture: Image.Image) -> torch.Tensor:
        # The transform expects 3-channel input, so convert other modes.
        rgb = picture if picture.mode == "RGB" else picture.convert("RGB")
        return self.transform(rgb).unsqueeze(0)

    google_tensor = to_batch(google_img)
    yandex_tensor = to_batch(yandex_img)
    return google_tensor, yandex_tensor
|
||||
|
||||
def predict(
    self,
    google_img: Image.Image,
    yandex_img: Image.Image,
    return_matrix: bool = True,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Predict homography matrix for image pair.

    Args:
        google_img: Google map image (PIL Image)
        yandex_img: Yandex map image (PIL Image)
        return_matrix: If True, return 3x3 matrix; if False, return flattened vector
        normalize: Whether to divide the matrix by its bottom-right
            element (only applied when ``return_matrix`` is True)

    Returns:
        Predicted homography matrix or vector, without the batch dimension
    """
    google_batch, yandex_batch = (
        tensor.to(self.device)
        for tensor in self.preprocess_images(google_img, yandex_img)
    )

    with torch.no_grad():
        prediction = self.model(
            google_batch, yandex_batch, return_matrix=return_matrix
        )

    if not (return_matrix and normalize):
        return prediction.squeeze(0)

    # Scale so that the last element is 1.
    scale = prediction[:, 2, 2].view(-1, 1, 1)
    return (prediction / scale).squeeze(0)
|
||||
|
||||
def predict_from_paths(
    self,
    google_path: str,
    yandex_path: str,
    return_matrix: bool = True,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Predict homography matrix from image file paths.

    Args:
        google_path: Path to Google map image
        yandex_path: Path to Yandex map image
        return_matrix: If True, return 3x3 matrix; if False, return flattened vector
        normalize: Whether to normalize the homography matrix

    Returns:
        Predicted homography matrix or vector
    """
    # Image.open keeps the file handle open until the image is closed; the
    # context managers release both handles once prediction (which reads
    # the pixel data) has finished, also when it raises.
    with Image.open(google_path) as google_img, Image.open(yandex_path) as yandex_img:
        return self.predict(google_img, yandex_img, return_matrix, normalize)
|
||||
|
||||
def warp_image(
    self,
    img: Image.Image,
    homography: np.ndarray,
    output_size: Optional[Tuple[int, int]] = None,
) -> Image.Image:
    """
    Warp image using homography matrix.

    Args:
        img: Input image (PIL Image)
        homography: 3x3 homography matrix (numpy array)
        output_size: Output image size (width, height). If None, uses input size.

    Returns:
        Warped image (PIL Image)
    """
    source = np.array(img)

    # Default canvas: same (width, height) as the input array.
    canvas_wh = (
        (source.shape[1], source.shape[0]) if output_size is None else output_size
    )

    # Bilinear sampling; pixels outside the source reflect the border.
    warped = cv2.warpPerspective(
        source,
        homography,
        canvas_wh,
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT,
    )
    return Image.fromarray(warped)
|
||||
|
||||
def visualize_alignment(
    self,
    google_img: Image.Image,
    yandex_img: Image.Image,
    homography: np.ndarray,
    save_path: Optional[str] = None,
    show: bool = True,
):
    """
    Visualize alignment between images using homography.

    Draws a 2x2 figure: both originals, the warped Yandex image and a
    50/50 blend of the Google image with the warped Yandex image.

    Args:
        google_img: Google map image
        yandex_img: Yandex map image
        homography: Homography matrix
        save_path: Path to save visualization (optional)
        show: Whether to display the visualization; when False the figure
            is closed after optional saving
    """
    # Warp the Yandex image onto the Google image's canvas so that the
    # blend below works even when the two inputs differ in size.
    # PIL's .size is (width, height), which is what warp_image expects.
    yandex_warped = self.warp_image(
        yandex_img, homography, output_size=google_img.size
    )

    # cv2.addWeighted needs equal shapes and channel counts, so the blend
    # is computed on RGB copies; the other panels show the images as given.
    overlay = cv2.addWeighted(
        np.array(google_img.convert("RGB")),
        0.5,
        np.array(yandex_warped.convert("RGB")),
        0.5,
        0,
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    panels = [
        (axes[0, 0], np.array(google_img), "Google Map (Original)"),
        (axes[0, 1], np.array(yandex_img), "Yandex Map (Original)"),
        (axes[1, 0], np.array(yandex_warped), "Yandex Map (Warped)"),
        (axes[1, 1], overlay, "Overlay (Google + Warped Yandex)"),
    ]
    for axis, pixels, title in panels:
        axis.imshow(pixels)
        axis.set_title(title)
        axis.axis("off")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Visualization saved to {save_path}")

    if show:
        plt.show()
    else:
        plt.close()
|
||||
|
||||
def evaluate_on_dataset(
    self,
    dataset_dir: str,
    num_samples: Optional[int] = None,
    save_dir: Optional[str] = None,
) -> Dict[str, float]:
    """
    Evaluate model on a dataset.

    For every selected sample the predicted homography is compared with the
    sample's target homography in two ways: mean squared error over the
    matrix entries, and mean Euclidean distance between the projections of
    the four corners of the [-1, 1] x [-1, 1] square.

    Args:
        dataset_dir: Directory containing image pairs
        num_samples: Number of samples to evaluate (None for all)
        save_dir: Directory to save visualizations (optional)

    Returns:
        Dictionary of evaluation metrics

    Raises:
        ValueError: If no samples are selected (empty dataset or
            ``num_samples`` <= 0); the statistics are undefined then.
    """

    def _as_numpy(value):
        # Accept both torch tensors (anything exposing ``detach``) and
        # array-likes without importing torch by name here.
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    def _project(matrix, points):
        # Apply the homography and dehomogenise; the epsilon mirrors the
        # original guard against a zero w coordinate.
        projected = matrix @ points
        return projected / (projected[2:3, :] + 1e-8)

    # Create dataset
    dataset = HomographyDataset(
        root_dir=dataset_dir,
        transform=None,  # We'll handle transforms manually
        augment=False,
        image_size=tuple(self.config.get("image_size", [256, 256])),
        cache_homographies=False,
    )

    total = len(dataset)
    selected = total if num_samples is None else min(num_samples, total)
    indices = list(range(selected))
    if not indices:
        raise ValueError(
            f"No samples to evaluate in {dataset_dir!r} "
            f"(dataset size: {total}, num_samples: {num_samples})"
        )

    errors = []
    corner_errors = []

    print(f"Evaluating on {len(indices)} samples...")

    # Loop invariants: output directory and the homogeneous corner matrix
    # (one corner per column).
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    corners = np.array(
        [
            [-1, -1, 1],
            [1, -1, 1],
            [1, 1, 1],
            [-1, 1, 1],
        ],
        dtype=np.float32,
    ).T

    for idx in indices:
        # Get sample without augmentation
        sample = dataset.get_sample_without_augmentation(idx)

        google_img = sample["google_img"]
        yandex_img = sample["yandex_img"]
        target_homography = sample["homography"]

        # Predict homography
        pred_homography = self.predict(
            google_img, yandex_img, return_matrix=True, normalize=True
        )

        # Convert to numpy
        pred_homography_np = _as_numpy(pred_homography)
        target_homography_np = _as_numpy(target_homography)

        # Compute matrix error
        matrix_error = np.mean((pred_homography_np - target_homography_np) ** 2)
        errors.append(matrix_error)

        # Compute corner error
        pred_corners = _project(pred_homography_np, corners)
        target_corners = _project(target_homography_np, corners)
        corner_error = np.mean(
            np.linalg.norm(pred_corners[:2, :] - target_corners[:2, :], axis=0)
        )
        corner_errors.append(corner_error)

        # Save visualization if requested
        if save_dir:
            vis_path = os.path.join(save_dir, f"sample_{idx:04d}.png")
            self.visualize_alignment(
                google_img,
                yandex_img,
                pred_homography_np,
                save_path=vis_path,
                show=False,
            )

    # Compute metrics
    metrics = {
        "mean_matrix_error": float(np.mean(errors)),
        "std_matrix_error": float(np.std(errors)),
        "mean_corner_error": float(np.mean(corner_errors)),
        "std_corner_error": float(np.std(corner_errors)),
        "median_corner_error": float(np.median(corner_errors)),
        "max_corner_error": float(np.max(corner_errors)),
        "min_corner_error": float(np.min(corner_errors)),
    }

    print("\nEvaluation Results:")
    for key, value in metrics.items():
        print(f" {key}: {value:.6f}")

    return metrics
|
||||
|
||||
|
||||
def main():
    """Main inference function.

    Parses the command line, loads the model through ``HomographyInference``
    and runs one of three modes: ``single`` (one image pair, optional
    visualization), ``dataset`` (metrics over a directory of pairs) or
    ``batch`` (placeholder, not implemented).

    Raises:
        ValueError: If ``single`` mode is requested without both image paths.
    """
    parser = argparse.ArgumentParser(description="Inference for homography estimation")

    # Model arguments
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to trained model checkpoint",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to model configuration file (optional)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cuda", "cpu"],
        help="Device to run inference on",
    )

    # Inference mode
    parser.add_argument(
        "--mode",
        type=str,
        default="single",
        choices=["single", "dataset", "batch"],
        help="Inference mode",
    )

    # Single image mode
    parser.add_argument(
        "--google_path",
        type=str,
        help="Path to Google map image (single mode)",
    )
    parser.add_argument(
        "--yandex_path",
        type=str,
        help="Path to Yandex map image (single mode)",
    )
    parser.add_argument(
        "--output_vis",
        type=str,
        help="Path to save visualization (single mode)",
    )

    # Dataset mode
    # NOTE(review): the default below is a machine-specific absolute path;
    # kept for backward compatibility, but callers should pass --dataset_dir.
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        help="Directory containing image pairs (dataset mode)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        help="Number of samples to evaluate (dataset mode)",
    )
    parser.add_argument(
        "--save_vis_dir",
        type=str,
        help="Directory to save visualizations (dataset mode)",
    )
    parser.add_argument(
        "--save_results",
        type=str,
        help="Path to save evaluation results (dataset mode)",
    )

    args = parser.parse_args()

    # Fail fast: validate mode-specific arguments before paying for the
    # checkpoint load below.
    if args.mode == "single" and (not args.google_path or not args.yandex_path):
        raise ValueError(
            "Both --google_path and --yandex_path are required for single mode"
        )

    # Create inference object
    inference = HomographyInference(
        model_path=args.model_path,
        config_path=args.config_path,
        device=args.device,
    )

    if args.mode == "single":
        # Single image pair inference
        print("Processing single image pair:")
        print(f" Google: {args.google_path}")
        print(f" Yandex: {args.yandex_path}")

        # Predict homography
        homography = inference.predict_from_paths(args.google_path, args.yandex_path)
        homography_np = homography.cpu().numpy()

        print("\nPredicted homography matrix:")
        print(homography_np)

        # Visualize alignment
        if args.output_vis:
            # Context managers release the underlying file handles once the
            # figure has been rendered.
            with Image.open(args.google_path) as google_img, Image.open(
                args.yandex_path
            ) as yandex_img:
                inference.visualize_alignment(
                    google_img,
                    yandex_img,
                    homography_np,
                    save_path=args.output_vis,
                    show=True,
                )

    elif args.mode == "dataset":
        # Evaluate on dataset
        metrics = inference.evaluate_on_dataset(
            dataset_dir=args.dataset_dir,
            num_samples=args.num_samples,
            save_dir=args.save_vis_dir,
        )

        # Save results if requested
        if args.save_results:
            with open(args.save_results, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2)
            print(f"\nResults saved to {args.save_results}")

    elif args.mode == "batch":
        # Batch processing (placeholder for future implementation)
        print("Batch mode not yet implemented")
        # Could implement processing multiple image pairs from a directory

    else:
        # Defensive: argparse ``choices`` already rejects other values.
        raise ValueError(f"Unknown mode: {args.mode}")


if __name__ == "__main__":
    main()
|
||||
611
models/SiaN/train_homography.py
Normal file
611
models/SiaN/train_homography.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
Training script for homography estimation between Google and Yandex map images.
|
||||
|
||||
This script trains a CNN model to estimate homography matrices that align
|
||||
Google map images with Yandex map images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from homography import HomographyDataset, create_data_loaders
|
||||
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class HomographyTrainer:
    """Trainer class for homography estimation model.

    Bundles the model, loss, optimizer, learning-rate scheduler, TensorBoard
    writer and checkpointing for one run whose artifacts are written under
    ``config["output_dir"]``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: Dict,
    ):
        """
        Initialize the trainer.

        Args:
            model: Homography estimation model
            train_loader: Training data loader
            val_loader: Validation data loader
            device: Device to run training on
            config: Training configuration dictionary

        Raises:
            ValueError: If ``config["optimizer"]`` names an unknown optimizer.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Loss function
        self.criterion = HomographyLoss(
            matrix_weight=config.get("matrix_weight", 1.0),
            geometric_weight=config.get("geometric_weight", 0.5),
            reg_weight=config.get("reg_weight", 0.1),
            grid_size=config.get("grid_size", 8),
        ).to(device)

        self.optimizer = self._build_optimizer(config)
        self.scheduler = self._build_scheduler(config)

        # Training state
        self.current_epoch = 0
        # Index of the next epoch to run; advanced after every epoch and by
        # load_checkpoint so resumed runs continue the epoch numbering.
        self.start_epoch = 0
        self.best_val_loss = float("inf")
        self.train_losses: List[float] = []
        self.val_losses: List[float] = []
        self.val_metrics: List[Dict] = []

        # Create output directory
        self.output_dir = Path(config.get("output_dir", "runs/homography"))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")

        # Save configuration
        config_path = self.output_dir / "config.json"
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)

        print(f"Training configuration saved to {config_path}")
        print(
            f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
        )

    def _build_optimizer(self, config: Dict):
        """Create the optimizer named by ``config["optimizer"]``."""
        optimizer_name = config.get("optimizer", "adam").lower()
        lr = config.get("learning_rate", 1e-3)
        weight_decay = config.get("weight_decay", 1e-4)

        if optimizer_name == "adam":
            return optim.Adam(
                self.model.parameters(), lr=lr, weight_decay=weight_decay
            )
        if optimizer_name == "adamw":
            return optim.AdamW(
                self.model.parameters(), lr=lr, weight_decay=weight_decay
            )
        if optimizer_name == "sgd":
            return optim.SGD(
                self.model.parameters(),
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=weight_decay,
            )
        raise ValueError(f"Unknown optimizer: {optimizer_name}")

    def _build_scheduler(self, config: Dict):
        """Create the LR scheduler named by ``config["scheduler"]`` (or None)."""
        scheduler_name = config.get("scheduler", "plateau").lower()
        if scheduler_name == "plateau":
            # The ``verbose`` kwarg is deprecated in recent PyTorch releases;
            # train() reports learning-rate reductions explicitly instead.
            return optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=config.get("scheduler_factor", 0.5),
                patience=config.get("scheduler_patience", 5),
            )
        if scheduler_name == "cosine":
            return optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config.get("epochs", 100),
                eta_min=config.get("min_lr", 1e-6),
            )
        if scheduler_name == "step":
            return optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=config.get("step_size", 30),
                gamma=config.get("gamma", 0.1),
            )
        # Any other name (e.g. "none") disables scheduling.
        return None

    @staticmethod
    def _average_metrics(batch_metrics: List[Dict]) -> Dict[str, float]:
        """Average per-batch metric dicts key-wise into plain Python floats.

        Plain floats (rather than numpy scalars) keep the result JSON
        serializable and keep checkpoints free of numpy objects.
        """
        if not batch_metrics:
            return {}
        return {
            key: float(np.mean([m[key] for m in batch_metrics]))
            for key in batch_metrics[0]
        }

    def train_epoch(self) -> float:
        """
        Train for one epoch.

        Returns:
            Average training loss for the epoch

        Raises:
            ValueError: If the training loader yields no batches.
        """
        num_batches = len(self.train_loader)
        if num_batches == 0:
            raise ValueError("Training data loader is empty")

        self.model.train()
        total_loss = 0.0
        grad_clip = self.config.get("grad_clip", 1.0)

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        for batch_idx, batch in enumerate(progress_bar):
            # Move data to device
            google_img = batch["google_img"].to(self.device)
            yandex_img = batch["yandex_img"].to(self.device)
            target_homography = batch["homography"].to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            pred_homography = self.model(google_img, yandex_img, return_matrix=True)

            # Compute loss
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )

            # Backward pass
            loss.backward()

            # Gradient clipping (disabled when grad_clip is None or <= 0)
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            # Optimizer step
            self.optimizer.step()

            # Read the scalar once: every .item() forces a device sync.
            loss_value = loss.item()
            total_loss += loss_value

            # Update progress bar
            progress_bar.set_postfix({"loss": loss_value})

            # Log batch loss to TensorBoard
            global_step = self.current_epoch * num_batches + batch_idx
            self.writer.add_scalar("train/batch_loss", loss_value, global_step)

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)

        return avg_loss

    @torch.no_grad()
    def validate(self) -> Tuple[float, Dict]:
        """
        Validate the model.

        Returns:
            Tuple of (average validation loss, validation metrics)

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        all_metrics = []

        progress_bar = tqdm(self.val_loader, desc="Validation")
        for batch in progress_bar:
            # Move data to device
            google_img = batch["google_img"].to(self.device)
            yandex_img = batch["yandex_img"].to(self.device)
            target_homography = batch["homography"].to(self.device)

            # Forward pass
            pred_homography = self.model(google_img, yandex_img, return_matrix=True)

            # Compute loss
            loss = self.criterion(
                pred_homography,
                target_homography,
                google_img,
                yandex_img,
            )

            # Compute metrics
            metrics = self.criterion.compute_metrics(pred_homography, target_homography)

            # Update statistics
            loss_value = loss.item()
            total_loss += loss_value
            all_metrics.append(metrics)

            # Update progress bar
            progress_bar.set_postfix({"loss": loss_value})

        if not all_metrics:
            raise ValueError("Validation data loader is empty")

        avg_loss = total_loss / len(all_metrics)
        self.val_losses.append(avg_loss)

        # Aggregate metrics
        avg_metrics = self._average_metrics(all_metrics)
        self.val_metrics.append(avg_metrics)

        return avg_loss, avg_metrics

    def save_checkpoint(self, is_best: bool = False):
        """Save model checkpoint.

        Always writes ``checkpoint_latest.pth``; additionally writes
        ``checkpoint_best.pth`` when ``is_best`` is true.
        """
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "val_metrics": self.val_metrics,
            "best_val_loss": self.best_val_loss,
            "config": self.config,
        }

        if self.scheduler is not None:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()

        # Save latest checkpoint
        checkpoint_path = self.output_dir / "checkpoint_latest.pth"
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.output_dir / "checkpoint_best.pth"
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint and restore the training state.

        After loading, the next call to train() continues with the epoch
        following the one stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        self.current_epoch = checkpoint["epoch"]
        # The stored epoch was completed, so resume with the following one.
        self.start_epoch = self.current_epoch + 1
        self.train_losses = checkpoint["train_losses"]
        self.val_losses = checkpoint["val_losses"]
        self.val_metrics = checkpoint["val_metrics"]
        self.best_val_loss = checkpoint["best_val_loss"]

        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(self, num_epochs: int):
        """
        Train the model for specified number of epochs.

        Epoch numbering continues from ``self.start_epoch`` (0 for a fresh
        trainer, or the epoch after a loaded checkpoint), so TensorBoard
        steps and early stopping stay consistent when resuming.

        Args:
            num_epochs: Number of epochs to train
        """
        print(f"Starting training for {num_epochs} epochs...")
        start_time = time.time()

        first_epoch = self.start_epoch
        last_epoch = first_epoch + num_epochs
        patience = self.config.get("early_stopping_patience", 0)

        try:
            for epoch in range(first_epoch, last_epoch):
                self.current_epoch = epoch

                # Train for one epoch
                train_loss = self.train_epoch()

                # Validate
                val_loss, val_metrics = self.validate()

                # Update learning rate scheduler
                if self.scheduler is not None:
                    if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                        previous_lr = self.optimizer.param_groups[0]["lr"]
                        self.scheduler.step(val_loss)
                        current_lr = self.optimizer.param_groups[0]["lr"]
                        # Stand-in for the scheduler's deprecated verbose output.
                        if current_lr != previous_lr:
                            print(
                                f"Reducing learning rate: "
                                f"{previous_lr:.4e} -> {current_lr:.4e}"
                            )
                    else:
                        self.scheduler.step()

                # Log to TensorBoard
                self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
                self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
                for metric_name, metric_value in val_metrics.items():
                    self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)

                # Print epoch summary
                print(f"\nEpoch {epoch + 1}/{last_epoch}:")
                print(f" Train Loss: {train_loss:.6f}")
                print(f" Val Loss: {val_loss:.6f}")
                print(" Val Metrics:")
                for metric_name, metric_value in val_metrics.items():
                    print(f" {metric_name}: {metric_value:.6f}")

                # Save checkpoint
                is_best = val_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_loss

                self.save_checkpoint(is_best=is_best)
                self.start_epoch = epoch + 1

                # Early stopping: count epochs since the best validation
                # loss using the loss history itself, which stays correct
                # when the history was restored from a checkpoint.
                if patience and patience > 0:
                    epochs_since_best = (
                        len(self.val_losses) - 1 - int(np.argmin(self.val_losses))
                    )
                    if epochs_since_best >= patience:
                        print(f"Early stopping at epoch {epoch + 1}")
                        break
        finally:
            # Close TensorBoard writer even if training aborts, so buffered
            # events are flushed to disk.
            self.writer.close()

        # Training completed
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation loss: {self.best_val_loss:.6f}")

        # Save final model
        final_model_path = self.output_dir / "model_final.pth"
        torch.save(self.model.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")

    def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
        """
        Evaluate the model on test data.

        Args:
            test_loader: Test data loader (uses validation loader if None)

        Returns:
            Dictionary of evaluation metrics

        Raises:
            ValueError: If the loader yields no batches.
        """
        if test_loader is None:
            test_loader = self.val_loader

        self.model.eval()
        all_metrics = []

        print("Evaluating model...")
        with torch.no_grad():
            for batch in tqdm(test_loader):
                # Move data to device
                google_img = batch["google_img"].to(self.device)
                yandex_img = batch["yandex_img"].to(self.device)
                target_homography = batch["homography"].to(self.device)

                # Forward pass
                pred_homography = self.model(google_img, yandex_img, return_matrix=True)

                # Compute metrics
                metrics = self.criterion.compute_metrics(
                    pred_homography, target_homography
                )
                all_metrics.append(metrics)

        if not all_metrics:
            raise ValueError("Evaluation data loader is empty")

        # Aggregate metrics
        avg_metrics = self._average_metrics(all_metrics)

        # Print evaluation results
        print("\nEvaluation Results:")
        for metric_name, metric_value in avg_metrics.items():
            print(f" {metric_name}: {metric_value:.6f}")

        # Save evaluation results
        eval_path = self.output_dir / "evaluation_results.json"
        with open(eval_path, "w") as f:
            json.dump(avg_metrics, f, indent=2)
        print(f"Evaluation results saved to {eval_path}")

        return avg_metrics
|
||||
|
||||
|
||||
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is the previous behavior; passing a list
            allows programmatic use and testing.

    Returns:
        Namespace with the data, model, training, loss and output options.
    """
    parser = argparse.ArgumentParser(description="Train homography estimation model")

    # Data arguments
    # NOTE(review): the default below is a machine-specific absolute path;
    # kept for backward compatibility, but callers should pass --data_dir.
    parser.add_argument(
        "--data_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        help="Directory containing image pairs",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        nargs=2,
        default=[256, 256],
        help="Image size (height width)",
    )
    parser.add_argument(
        "--train_split", type=float, default=0.8, help="Train/validation split ratio"
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loader workers"
    )

    # Model arguments
    parser.add_argument(
        "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
    )
    parser.add_argument(
        "--hidden_channels", type=int, default=64, help="Number of hidden channels"
    )
    parser.add_argument(
        "--num_blocks", type=int, default=4, help="Number of convolutional blocks"
    )
    parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
    parser.add_argument(
        "--use_batch_norm", action="store_true", help="Use batch normalization"
    )

    # Training arguments
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["adam", "adamw", "sgd"],
        help="Optimizer",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        default="plateau",
        choices=["plateau", "cosine", "step", "none"],
        help="Learning rate scheduler",
    )
    parser.add_argument(
        "--grad_clip", type=float, default=1.0, help="Gradient clipping value"
    )

    # Loss arguments
    parser.add_argument(
        "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
    )
    parser.add_argument(
        "--geometric_weight",
        type=float,
        default=0.5,
        help="Weight for geometric loss",
    )
    parser.add_argument(
        "--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
    )

    # Other arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="runs/homography",
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Path to checkpoint to resume training from",
    )
    parser.add_argument(
        "--eval_only",
        action="store_true",
        help="Only evaluate the model (no training)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main():
    """Main training function.

    Parses the command line, seeds the random number generators, builds the
    data loaders, model and trainer, optionally resumes from a checkpoint,
    and then either evaluates only or trains and runs a final evaluation.
    """
    # Local import: only this entry point needs Python's RNG.
    import random

    args = parse_args()

    # Set random seeds for reproducibility (Python, NumPy and PyTorch)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create data loaders
    print("Creating data loaders...")
    image_size = tuple(args.image_size)
    train_loader, val_loader = create_data_loaders(
        root_dir=args.data_dir,
        batch_size=args.batch_size,
        train_split=args.train_split,
        num_workers=args.num_workers,
        image_size=image_size,
        augment_train=True,
        augment_val=False,
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Create model
    print("Creating model...")
    model = create_homography_model(
        model_type=args.model_type,
        input_size=image_size,
        input_channels=3,
        hidden_channels=args.hidden_channels,
        num_blocks=args.num_blocks,
        dropout_rate=args.dropout_rate,
        use_batch_norm=args.use_batch_norm,
    )

    # Create trainer configuration
    config = {
        # Model config
        "model_type": args.model_type,
        "hidden_channels": args.hidden_channels,
        "num_blocks": args.num_blocks,
        "dropout_rate": args.dropout_rate,
        "use_batch_norm": args.use_batch_norm,
        "image_size": args.image_size,
        # Training config
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "optimizer": args.optimizer,
        "scheduler": args.scheduler,
        "grad_clip": args.grad_clip,
        # Loss config
        "matrix_weight": args.matrix_weight,
        "geometric_weight": args.geometric_weight,
        "reg_weight": args.reg_weight,
        "grid_size": 8,
        # Data config
        "data_dir": args.data_dir,
        "train_split": args.train_split,
        "num_workers": args.num_workers,
        # Output config
        "output_dir": args.output_dir,
        "seed": args.seed,
    }

    # Create trainer
    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    # Resume from checkpoint if specified
    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        trainer.load_checkpoint(args.resume)

    # Evaluate only mode
    if args.eval_only:
        if not args.resume:
            # Without a checkpoint the metrics describe random weights.
            print(
                "Warning: --eval_only without --resume evaluates a randomly "
                "initialized model"
            )
        trainer.evaluate()
        return

    # Train the model
    trainer.train(num_epochs=args.epochs)

    # Final evaluation
    print("\nPerforming final evaluation...")
    trainer.evaluate()

    print("\nTraining completed successfully!")
    print(f"Results saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
|
||||
611
models/SiaN/train_homography_.py
Normal file
611
models/SiaN/train_homography_.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
Training script for homography estimation between Google and Yandex map images.
|
||||
|
||||
This script trains a CNN model to estimate homography matrices that align
|
||||
Google map images with Yandex map images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from homography import HomographyDataset, create_data_loaders
|
||||
from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class HomographyTrainer:
|
||||
"""Trainer class for homography estimation model."""
|
||||
|
||||
def __init__(
    self,
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    config: Dict,
):
    """
    Initialize the trainer.

    Builds the loss, optimizer and learning-rate scheduler from ``config``,
    creates the output directory and TensorBoard writer, and stores the
    configuration as ``config.json`` in the output directory.

    Args:
        model: Homography estimation model
        train_loader: Training data loader
        val_loader: Validation data loader
        device: Device to run training on
        config: Training configuration dictionary

    Raises:
        ValueError: If ``config["optimizer"]`` names an unknown optimizer.
    """
    self.model = model.to(device)
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.device = device
    self.config = config

    # Loss function
    self.criterion = HomographyLoss(
        matrix_weight=config.get("matrix_weight", 1.0),
        geometric_weight=config.get("geometric_weight", 0.5),
        reg_weight=config.get("reg_weight", 0.1),
        grid_size=config.get("grid_size", 8),
    ).to(device)

    # Optimizer
    optimizer_name = config.get("optimizer", "adam").lower()
    lr = config.get("learning_rate", 1e-3)
    weight_decay = config.get("weight_decay", 1e-4)
    params = self.model.parameters()

    if optimizer_name == "adam":
        self.optimizer = optim.Adam(params, lr=lr, weight_decay=weight_decay)
    elif optimizer_name == "adamw":
        self.optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    elif optimizer_name == "sgd":
        self.optimizer = optim.SGD(
            params,
            lr=lr,
            momentum=config.get("momentum", 0.9),
            weight_decay=weight_decay,
        )
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")

    # Learning rate scheduler
    scheduler_name = config.get("scheduler", "plateau").lower()
    if scheduler_name == "plateau":
        # ``verbose=True`` is intentionally not passed: the kwarg is
        # deprecated in recent PyTorch releases. The current learning rate
        # can be read from ``self.optimizer.param_groups[0]["lr"]``.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=config.get("scheduler_factor", 0.5),
            patience=config.get("scheduler_patience", 5),
        )
    elif scheduler_name == "cosine":
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.get("epochs", 100),
            eta_min=config.get("min_lr", 1e-6),
        )
    elif scheduler_name == "step":
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=config.get("step_size", 30),
            gamma=config.get("gamma", 0.1),
        )
    else:
        # Any other name (e.g. "none") disables scheduling.
        self.scheduler = None

    # Training state
    self.current_epoch = 0
    self.best_val_loss = float("inf")
    self.train_losses: List[float] = []
    self.val_losses: List[float] = []
    self.val_metrics: List[Dict] = []

    # Create output directory
    self.output_dir = Path(config.get("output_dir", "runs/homography"))
    self.output_dir.mkdir(parents=True, exist_ok=True)

    # TensorBoard writer
    self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")

    # Save configuration
    config_path = self.output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    print(f"Training configuration saved to {config_path}")
    print(
        f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters"
    )
|
||||
|
||||
def train_epoch(self) -> float:
    """
    Run one optimisation pass over the training loader.

    Every batch is moved to ``self.device``, passed through the model with
    ``return_matrix=True`` and scored by ``self.criterion``. The loss is
    back-propagated, gradients are norm-clipped when the ``grad_clip`` config
    entry (default 1.0) is positive, and the optimizer takes a step. Each
    batch loss is logged to TensorBoard as ``train/batch_loss`` and the epoch
    mean is appended to ``self.train_losses``.

    Returns:
        Mean of the per-batch losses over the epoch
    """
    self.model.train()
    batches_per_epoch = len(self.train_loader)
    running_loss = 0.0

    loader_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
    for step, sample in enumerate(loader_bar):
        # Transfer the image pair and its target homography to the device
        google_img, yandex_img, target_homography = (
            sample[key].to(self.device)
            for key in ("google_img", "yandex_img", "homography")
        )

        # Reset gradients, then predict and score the homography
        self.optimizer.zero_grad()
        pred_homography = self.model(google_img, yandex_img, return_matrix=True)
        loss = self.criterion(
            pred_homography, target_homography, google_img, yandex_img
        )
        loss.backward()

        # Clip the gradient norm unless clipping is disabled (value <= 0)
        max_grad_norm = self.config.get("grad_clip", 1.0)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

        self.optimizer.step()

        # Read the scalar once and reuse it for stats, progress bar and logging
        loss_value = loss.item()
        running_loss += loss_value
        loader_bar.set_postfix({"loss": loss_value})
        self.writer.add_scalar(
            "train/batch_loss",
            loss_value,
            self.current_epoch * batches_per_epoch + step,
        )

    epoch_mean = running_loss / batches_per_epoch
    self.train_losses.append(epoch_mean)
    return epoch_mean
|
||||
|
||||
@torch.no_grad()
def validate(self) -> Tuple[float, Dict]:
    """
    Validate the model.

    Runs the model in eval mode over ``self.val_loader`` with gradients
    disabled, computing the criterion loss and
    ``self.criterion.compute_metrics`` for every batch. The mean loss is
    appended to ``self.val_losses`` and the per-key mean of the batch metrics
    is appended to ``self.val_metrics``.

    The loader must yield at least one batch: averaging divides by
    ``len(self.val_loader)`` and the metric keys are taken from the first
    batch.

    Returns:
        Tuple of (average validation loss, validation metrics); the metric
        values are built-in ``float`` objects
    """
    self.model.eval()
    total_loss = 0.0
    all_metrics = []

    progress_bar = tqdm(self.val_loader, desc="Validation")
    for batch in progress_bar:
        # Move data to device
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target_homography = batch["homography"].to(self.device)

        # Forward pass
        pred_homography = self.model(google_img, yandex_img, return_matrix=True)

        # Compute loss
        loss = self.criterion(
            pred_homography,
            target_homography,
            google_img,
            yandex_img,
        )

        # Compute metrics
        all_metrics.append(
            self.criterion.compute_metrics(pred_homography, target_homography)
        )

        # Read the scalar once; it feeds both the running sum and the progress bar
        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({"loss": batch_loss})

    avg_loss = total_loss / len(self.val_loader)
    self.val_losses.append(avg_loss)

    # np.mean returns a NumPy scalar. Casting to a built-in float keeps
    # self.val_metrics (which save_checkpoint pickles into every checkpoint)
    # free of NumPy objects, so the history stays loadable with torch.load's
    # restricted unpickler and serialisable with json.
    avg_metrics = {
        key: float(np.mean([m[key] for m in all_metrics]))
        for key in all_metrics[0]
    }
    self.val_metrics.append(avg_metrics)

    return avg_loss, avg_metrics
|
||||
|
||||
def save_checkpoint(self, is_best: bool = False):
    """
    Persist the current trainer state under ``self.output_dir``.

    The snapshot holds the epoch index, model and optimizer state dicts, the
    loss and metric histories, the best validation loss and the config. The
    scheduler state is included only when a scheduler is configured. The
    snapshot is always written to ``checkpoint_latest.pth``.

    Args:
        is_best: When true, the same snapshot is also written to
            ``checkpoint_best.pth``
    """
    snapshot = dict(
        epoch=self.current_epoch,
        model_state_dict=self.model.state_dict(),
        optimizer_state_dict=self.optimizer.state_dict(),
        train_losses=self.train_losses,
        val_losses=self.val_losses,
        val_metrics=self.val_metrics,
        best_val_loss=self.best_val_loss,
        config=self.config,
    )
    if self.scheduler is not None:
        snapshot["scheduler_state_dict"] = self.scheduler.state_dict()

    # The rolling "latest" file is overwritten on every call
    torch.save(snapshot, self.output_dir / "checkpoint_latest.pth")

    if not is_best:
        return

    # Keep a separate copy of the best-scoring state
    best_file = self.output_dir / "checkpoint_best.pth"
    torch.save(snapshot, best_file)
    print(f"Best model saved to {best_file}")
|
||||
|
||||
def load_checkpoint(self, checkpoint_path: str):
    """
    Restore trainer state from a checkpoint file.

    Model and optimizer weights are always restored. The scheduler is
    restored only when one is configured and the file contains its state.
    The epoch index, loss and metric histories and best validation loss are
    copied from the file onto the trainer.

    Args:
        checkpoint_path: Path of the checkpoint file; tensors are mapped to
            ``self.device`` while loading
    """
    state = torch.load(checkpoint_path, map_location=self.device)

    self.model.load_state_dict(state["model_state_dict"])
    self.optimizer.load_state_dict(state["optimizer_state_dict"])

    scheduler_saved = "scheduler_state_dict" in state
    if self.scheduler is not None and scheduler_saved:
        self.scheduler.load_state_dict(state["scheduler_state_dict"])

    # Bookkeeping needed to continue where the saved run stopped
    self.current_epoch = state["epoch"]
    self.train_losses = state["train_losses"]
    self.val_losses = state["val_losses"]
    self.val_metrics = state["val_metrics"]
    self.best_val_loss = state["best_val_loss"]

    print(f"Loaded checkpoint from epoch {self.current_epoch}")
|
||||
|
||||
def train(self, num_epochs: int):
    """
    Train the model for specified number of epochs.

    Epoch numbering continues from the history already held by the trainer:
    the first epoch index is ``len(self.train_losses)``. A fresh trainer
    therefore starts at epoch 0, while a trainer restored with
    ``load_checkpoint`` (or trained earlier in the same process) carries on
    after its last completed epoch instead of restarting at 0 and reusing
    TensorBoard steps.

    After each epoch the scheduler is stepped (``ReduceLROnPlateau`` receives
    the validation loss), scalars are logged, a checkpoint is saved, and,
    when the ``early_stopping_patience`` config entry is positive, training
    stops once that many epochs have passed since the lowest recorded
    validation loss. On normal completion the final weights are written to
    ``model_final.pth``. The TensorBoard writer is closed even when training
    is interrupted by an exception.

    Args:
        num_epochs: Number of epochs to run in this call
    """
    print(f"Starting training for {num_epochs} epochs...")
    start_time = time.time()

    # Resume numbering after the epochs already recorded in the history
    first_epoch = len(self.train_losses)
    last_epoch = first_epoch + num_epochs
    patience = self.config.get("early_stopping_patience", 0)

    try:
        for epoch in range(first_epoch, last_epoch):
            self.current_epoch = epoch

            # Train for one epoch
            train_loss = self.train_epoch()

            # Validate
            val_loss, val_metrics = self.validate()

            # Update learning rate scheduler
            if self.scheduler is not None:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            # Log to TensorBoard
            self.writer.add_scalar("train/epoch_loss", train_loss, epoch)
            self.writer.add_scalar("val/epoch_loss", val_loss, epoch)
            for metric_name, metric_value in val_metrics.items():
                self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch)

            # Print epoch summary
            print(f"\nEpoch {epoch + 1}/{last_epoch}:")
            print(f"  Train Loss: {train_loss:.6f}")
            print(f"  Val Loss: {val_loss:.6f}")
            print("  Val Metrics:")
            for metric_name, metric_value in val_metrics.items():
                print(f"    {metric_name}: {metric_value:.6f}")

            # Save checkpoint
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss

            self.save_checkpoint(is_best=is_best)

            # Early stopping: count epochs since the best entry by position in
            # the history, so the check stays correct when the history was
            # restored from a checkpoint.
            if patience > 0:
                epochs_since_best = (
                    len(self.val_losses) - 1 - int(np.argmin(self.val_losses))
                )
                if epochs_since_best >= patience:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        # Training completed
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation loss: {self.best_val_loss:.6f}")

        # Save final model
        final_model_path = self.output_dir / "model_final.pth"
        torch.save(self.model.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")
    finally:
        # Close the TensorBoard writer even if training was interrupted
        self.writer.close()
|
||||
|
||||
def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:
    """
    Evaluate the model on test data.

    Runs the model in eval mode with gradients disabled, collects
    ``self.criterion.compute_metrics`` for every batch, averages each metric
    over the batches, prints the result and writes it to
    ``evaluation_results.json`` in ``self.output_dir``.

    The loader must yield at least one batch, because the metric keys are
    taken from the first batch.

    Args:
        test_loader: Test data loader (uses validation loader if None)

    Returns:
        Dictionary of evaluation metrics, with built-in ``float`` values
    """
    if test_loader is None:
        test_loader = self.val_loader

    self.model.eval()
    all_metrics = []

    print("Evaluating model...")
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # Move data to device
            google_img = batch["google_img"].to(self.device)
            yandex_img = batch["yandex_img"].to(self.device)
            target_homography = batch["homography"].to(self.device)

            # Forward pass
            pred_homography = self.model(google_img, yandex_img, return_matrix=True)

            # Compute metrics
            all_metrics.append(
                self.criterion.compute_metrics(pred_homography, target_homography)
            )

    # np.mean returns a NumPy scalar, and json.dump rejects NumPy scalars that
    # are not float subclasses (for example float32). Casting to a built-in
    # float makes the dump below safe for any numeric metric type.
    avg_metrics = {
        key: float(np.mean([m[key] for m in all_metrics]))
        for key in all_metrics[0]
    }

    # Print evaluation results
    print("\nEvaluation Results:")
    for metric_name, metric_value in avg_metrics.items():
        print(f"  {metric_name}: {metric_value:.6f}")

    # Save evaluation results
    eval_path = self.output_dir / "evaluation_results.json"
    with open(eval_path, "w") as f:
        json.dump(avg_metrics, f, indent=2)
    print(f"Evaluation results saved to {eval_path}")

    return avg_metrics
|
||||
|
||||
|
||||
def parse_args(argv: Optional[List[str]] = None):
    """
    Parse command line arguments.

    Args:
        argv: Argument strings to parse. When None (the default), argparse
            reads ``sys.argv[1:]``, which is the behaviour of calling
            ``parse_args()`` with no arguments. Passing an explicit list
            allows programmatic use without touching ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser(description="Train homography estimation model")

    # Data arguments
    parser.add_argument(
        "--data_dir",
        type=str,
        # NOTE(review): machine-specific absolute path; most users have to
        # override it. Consider a repository-relative default.
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        help="Directory containing image pairs",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        nargs=2,
        default=[256, 256],
        help="Image size (height width)",
    )
    parser.add_argument(
        "--train_split", type=float, default=0.8, help="Train/validation split ratio"
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loader workers"
    )

    # Model arguments
    parser.add_argument(
        "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type"
    )
    parser.add_argument(
        "--hidden_channels", type=int, default=64, help="Number of hidden channels"
    )
    parser.add_argument(
        "--num_blocks", type=int, default=4, help="Number of convolutional blocks"
    )
    parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")
    parser.add_argument(
        "--use_batch_norm", action="store_true", help="Use batch normalization"
    )

    # Training arguments
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["adam", "adamw", "sgd"],
        help="Optimizer",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        default="plateau",
        choices=["plateau", "cosine", "step", "none"],
        help="Learning rate scheduler",
    )
    parser.add_argument(
        "--grad_clip", type=float, default=1.0, help="Gradient clipping value"
    )

    # Loss arguments
    parser.add_argument(
        "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss"
    )
    parser.add_argument(
        "--geometric_weight",
        type=float,
        default=0.5,
        help="Weight for geometric loss",
    )
    parser.add_argument(
        "--reg_weight", type=float, default=0.1, help="Weight for regularization loss"
    )

    # Other arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="runs/homography",
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Path to checkpoint to resume training from",
    )
    parser.add_argument(
        "--eval_only",
        action="store_true",
        help="Only evaluate the model (no training)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point.

    Parses the arguments, seeds the random number generators, builds the
    data loaders, model and trainer, optionally restores a checkpoint, then
    either only evaluates (``--eval_only``) or trains and evaluates.
    """
    args = parse_args()
    cuda_ok = torch.cuda.is_available()

    # Seed every RNG in use so runs are repeatable
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if cuda_ok:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Pick the compute device
    device = torch.device("cuda" if cuda_ok else "cpu")
    print(f"Using device: {device}")

    image_hw = tuple(args.image_size)

    # Data pipeline: augmentation on the training split only
    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        root_dir=args.data_dir,
        batch_size=args.batch_size,
        train_split=args.train_split,
        num_workers=args.num_workers,
        image_size=image_hw,
        augment_train=True,
        augment_val=False,
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Network
    print("Creating model...")
    model = create_homography_model(
        model_type=args.model_type,
        input_size=image_hw,
        input_channels=3,
        hidden_channels=args.hidden_channels,
        num_blocks=args.num_blocks,
        dropout_rate=args.dropout_rate,
        use_batch_norm=args.use_batch_norm,
    )

    # Trainer configuration, assembled section by section. The key order is
    # kept because the trainer dumps this dict to config.json.
    model_cfg = {
        "model_type": args.model_type,
        "hidden_channels": args.hidden_channels,
        "num_blocks": args.num_blocks,
        "dropout_rate": args.dropout_rate,
        "use_batch_norm": args.use_batch_norm,
        "image_size": args.image_size,
    }
    training_cfg = {
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "optimizer": args.optimizer,
        "scheduler": args.scheduler,
        "grad_clip": args.grad_clip,
    }
    loss_cfg = {
        "matrix_weight": args.matrix_weight,
        "geometric_weight": args.geometric_weight,
        "reg_weight": args.reg_weight,
        "grid_size": 8,
    }
    data_cfg = {
        "data_dir": args.data_dir,
        "train_split": args.train_split,
        "num_workers": args.num_workers,
    }
    output_cfg = {
        "output_dir": args.output_dir,
        "seed": args.seed,
    }
    config = {**model_cfg, **training_cfg, **loss_cfg, **data_cfg, **output_cfg}

    trainer = HomographyTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )

    # Restore a previous run when requested
    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        trainer.load_checkpoint(args.resume)

    # Evaluation-only mode skips training entirely
    if args.eval_only:
        trainer.evaluate()
        return

    trainer.train(num_epochs=args.epochs)

    # Score the trained model once more and report where the outputs are
    print("\nPerforming final evaluation...")
    trainer.evaluate()

    print("\nTraining completed successfully!")
    print(f"Results saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user