diff --git a/datasets/ya_go_maps/generate_dataset.py b/datasets/ya_go_maps/generate_dataset.py new file mode 100644 index 0000000..7a63e87 --- /dev/null +++ b/datasets/ya_go_maps/generate_dataset.py @@ -0,0 +1,43 @@ +import numpy as np +from pathlib import Path +from ...google_map import GoogleMap +from ...simulator import Simulator +from ...yandex_map import YandexMap + +LAT_MIN, LAT_MAX = 44.960236, 54.967830 +LON_MIN, LON_MAX = 53.084167, 58.677977 + +def create_new_asset(yandex_map, google_map): + folder = Path('dataset_ya_go_maps') + + id = 0 + print(id) + while (folder / f"{id:0{4}}_google.png").exists(): + id += 1 + + google_file = folder / f"{id:0{4}}_google.png" + yandex_file = folder / f"{id:0{4}}_yandex.png" + + lat = np.random.rand() * (LAT_MAX - LAT_MIN) + LAT_MIN + lon = np.random.rand() * (LON_MAX - LON_MIN) + LON_MIN + + yandex_map.open(lat, lon, 18) + google_map.open(lat, lon, 18) + + simulator = Simulator() + simulator._apply_perspective_transform(yandex_map.make_screenshot()).save(yandex_file) + simulator._apply_perspective_transform(google_map.make_screenshot()).save(google_file) + +def main(): + folder = Path('dataset_ya_go_maps') + if not folder.exists(): + folder.mkdir() + + yandex_map = YandexMap(initial_zoom=15) + google_map = GoogleMap(initial_zoom=15) + + for i in range(4): + create_new_asset(yandex_map, google_map) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/models/SiaN/.gitignore b/models/SiaN/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/models/SiaN/README_homography.md b/models/SiaN/README_homography.md new file mode 100644 index 0000000..9862e24 --- /dev/null +++ b/models/SiaN/README_homography.md @@ -0,0 +1,295 @@ +# Homography Estimation System + +This system provides a complete pipeline for estimating homography matrices between Google and Yandex map images using deep learning. + +## Overview + +Homography estimation is crucial for aligning images from different sources (Google Maps and Yandex Maps in this case). The system includes: + +1. **Dataset handling** - Loading and preprocessing image pairs +2. **Data augmentation** - Homography-based augmentation for robust training +3. **CNN model** - Deep learning model for homography estimation +4. **Training pipeline** - Complete training and evaluation workflow +5. **Inference tools** - Tools for using trained models on new data + +## Installation + +### Prerequisites +- Python 3.8+ +- PyTorch 1.9+ +- OpenCV +- PIL/Pillow +- NumPy + +### Install dependencies +```bash +pip install torch torchvision opencv-python pillow numpy matplotlib tqdm tensorboard +``` + +## Dataset Structure + +The system expects image pairs in the following format: +``` +dataset/ +├── 0000_google.png +├── 0000_yandex.png +├── 0001_google.png +├── 0001_yandex.png +└── ... +``` + +Each pair consists of: +- `{idx:04d}_google.png` - Google map image +- `{idx:04d}_yandex.png` - Yandex map image + +## Quick Start + +### 1. Explore the dataset +```python +from models.homography import HomographyDataset + +dataset = HomographyDataset( + root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + augment=True, + image_size=(256, 256) +) + +print(f"Found {len(dataset)} image pairs") +sample = dataset[0] +print(f"Sample homography matrix:\n{sample['homography'].numpy()}") +``` + +### 2. Train a model +```bash +python models/train_homography.py \ + --data_dir "C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images" \ + --epochs 50 \ + --batch_size 16 \ + --lr 1e-3 \ + --output_dir "runs/my_experiment" +``` + +### 3. Perform inference +```bash +python models/infer_homography.py \ + --model_path "runs/my_experiment/checkpoint_best.pth" \ + --mode single \ + --google_path "path/to/google.png" \ + --yandex_path "path/to/yandex.png" \ + --output_vis "alignment_result.png" +``` + +## File Structure + +``` +models/ +├── homography.py # Dataset class and data loaders +├── homography_cnn.py # CNN model architecture +├── train_homography.py # Training script +├── infer_homography.py # Inference script +├── example_homography.py # Example usage +├── homography.ipynb # Jupyter notebook (empty) +└── README_homography.md # This file +``` + +## Model Architecture + +The homography estimation model (`HomographyCNN`) consists of: + +1. **Dual encoders** - Separate feature extraction for Google and Yandex images +2. **Residual blocks** - For deep feature learning +3. **Fusion layers** - Combine features from both images +4. **Regression head** - Predict 3x3 homography matrix + +### Key Features: +- Residual connections for stable training +- Batch normalization options +- Dropout for regularization +- Geometric consistency loss + +## Training Configuration + +Default training parameters: +- **Optimizer**: Adam with learning rate 1e-3 +- **Loss function**: Combined matrix + geometric + regularization loss +- **Batch size**: 32 +- **Image size**: 256x256 +- **Train/val split**: 80/20 + +## Inference Modes + +The inference script supports three modes: + +### 1. Single image pair +```bash +python infer_homography.py --mode single \ + --google_path google.png --yandex_path yandex.png +``` + +### 2. Dataset evaluation +```bash +python infer_homography.py --mode dataset \ + --dataset_dir path/to/dataset --num_samples 100 +``` + +### 3. Batch processing +```bash +python infer_homography.py --mode batch \ + --input_dir path/to/input --output_dir path/to/output +``` + +## Evaluation Metrics + +The system computes several metrics: +- **Matrix MSE**: Mean squared error of homography matrix elements +- **Corner error**: Average pixel error at image corners +- **Geometric consistency**: Warping error across grid points + +## Data Augmentation + +The dataset applies homography-based augmentation: +- Random rotation (-30° to 30°) +- Random scaling (0.8x to 1.2x) +- Random translation (-50 to 50 pixels) +- Small perspective distortion + +## Integration with Autopilot System + +The homography estimation can be integrated into the autopilot system: + +```python +from models.infer_homography import HomographyInference + +# Initialize inference +inference = HomographyInference(model_path="path/to/model.pth") + +# During flight loop +homography = inference.predict(google_img, yandex_img) +# Use homography to update drone position +``` + +## Troubleshooting + +### Common Issues + +1. **Out of memory** + - Reduce batch size + - Use smaller image size + - Enable gradient checkpointing + +2. **Poor convergence** + - Adjust learning rate + - Increase model capacity + - Add more data augmentation + +3. **Inference errors** + - Check image formats (must be RGB) + - Verify model was trained with same image size + - Ensure proper normalization + +### Debugging Tips + +- Use `example_homography.py` to test components +- Enable TensorBoard for training visualization +- Check homography matrix normalization (last element should be ~1) + +## Advanced Usage + +### Custom Model Architecture +```python +from models.homography_cnn import create_homography_model + +model = create_homography_model( + model_type="cnn", + input_size=(512, 512), + hidden_channels=128, + num_blocks=6, + dropout_rate=0.5 +) +``` + +### Custom Loss Function +```python +from models.homography_cnn import HomographyLoss + +loss_fn = HomographyLoss( + matrix_weight=0.7, + geometric_weight=0.3, + reg_weight=0.05, + grid_size=16 +) +``` + +### Transfer Learning +```python +# Load pretrained model +checkpoint = torch.load("pretrained.pth") +model.load_state_dict(checkpoint["model_state_dict"]) + +# Fine-tune on new data +for param in model.google_encoder.parameters(): + param.requires_grad = False # Freeze encoder +``` + +## Performance Tips + +1. **Training speed** + - Use multiple GPU workers for data loading + - Enable mixed precision training + - Use gradient accumulation for larger effective batch sizes + +2. **Memory efficiency** + - Use gradient checkpointing + - Implement progressive resizing + - Use memory-efficient optimizers + +3. **Inference speed** + - Use TensorRT or ONNX for deployment + - Implement model quantization + - Use batch inference when possible + +## Future Improvements + +1. **Model enhancements** + - Transformer-based architecture + - Multi-scale feature fusion + - Uncertainty estimation + +2. **Training improvements** + - Self-supervised pre-training + - Curriculum learning + - Adversarial training + +3. **Deployment features** + - Real-time inference optimization + - Mobile deployment + - Web API + +## Citation + +If you use this system in your research, please cite: + +```bibtex +@software{homography_estimation_2024, + title = {Homography Estimation for Map Alignment}, + author = {Autopilot Team}, + year = {2024}, + url = {https://github.com/your-repo/homography-estimation} +} +``` + +## License + +This project is licensed under the MIT License - see the LICENSE file for details. + +## Support + +For questions and issues: +1. Check the troubleshooting section +2. Review the example scripts +3. Open an issue on GitHub +4. Contact the development team + +--- + +*Last updated: October 2024* \ No newline at end of file diff --git a/models/SiaN/example_homography.py b/models/SiaN/example_homography.py new file mode 100644 index 0000000..079bf40 --- /dev/null +++ b/models/SiaN/example_homography.py @@ -0,0 +1,345 @@ +""" +Example script demonstrating the complete homography estimation workflow. + +This script shows how to: +1. Load the ya_go_maps dataset +2. Create and train a homography estimation model +3. Perform inference on new image pairs +4. Visualize results +""" + +import os +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.append(str(Path(__file__).parent.parent)) + +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image + +from models.homography import HomographyDataset, create_data_loaders +from models.homography_cnn import HomographyCNN, HomographyLoss, create_homography_model +from models.infer_homography import HomographyInference +from models.train_homography import HomographyTrainer + + +def example_dataset_loading(): + """Example 1: Loading and exploring the dataset.""" + print("=" * 60) + print("Example 1: Loading and exploring the dataset") + print("=" * 60) + + # Path to the dataset + dataset_path = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images" + + # Create dataset + dataset = HomographyDataset( + root_dir=dataset_path, + augment=True, + image_size=(256, 256), + cache_homographies=True, + ) + + print(f"Dataset size: {len(dataset)} image pairs") + + # Get a sample + sample = dataset[0] + print(f"\nSample keys: {list(sample.keys())}") + print(f"Google image shape: {sample['google_img'].shape}") + print(f"Yandex image shape: {sample['yandex_img'].shape}") + print(f"Homography shape: {sample['homography'].shape}") + + # Show sample homography matrix + print(f"\nSample homography matrix:") + print(sample["homography"].numpy()) + + # Get sample without augmentation for visualization + raw_sample = dataset.get_sample_without_augmentation(0) + + # Visualize sample + fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + + axes[0].imshow(raw_sample["google_img"]) + axes[0].set_title("Google Map") + axes[0].axis("off") + + axes[1].imshow(raw_sample["yandex_img"]) + axes[1].set_title("Yandex Map") + axes[1].axis("off") + + plt.suptitle("Sample Image Pair (without augmentation)") + plt.tight_layout() + plt.show() + + return dataset + + +def example_model_creation(): + """Example 2: Creating and testing the model.""" + print("\n" + "=" * 60) + print("Example 2: Creating and testing the model") + print("=" * 60) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Create model + model = HomographyCNN( + input_channels=3, + hidden_channels=64, + num_blocks=4, + dropout_rate=0.3, + use_batch_norm=True, + ).to(device) + + print( + f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters" + ) + + # Create dummy input + batch_size = 2 + height, width = 256, 256 + + google_img = torch.randn(batch_size, 3, height, width).to(device) + yandex_img = torch.randn(batch_size, 3, height, width).to(device) + + # Test forward pass + print("\nTesting forward pass...") + output = model(google_img, yandex_img, return_matrix=True) + print(f"Output shape: {output.shape}") # Should be (2, 3, 3) + print(f"Sample output matrix:") + print(output[0].cpu().detach().numpy()) + + # Test loss function + print("\nTesting loss function...") + target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device) + loss_fn = HomographyLoss( + matrix_weight=1.0, + geometric_weight=0.5, + reg_weight=0.1, + grid_size=8, + ).to(device) + + loss = loss_fn(output, target_homography, google_img, yandex_img) + print(f"Loss value: {loss.item():.6f}") + + # Test metrics + print("\nTesting metrics...") + metrics = loss_fn.compute_metrics(output, target_homography) + for key, value in metrics.items(): + print(f"{key}: {value:.6f}") + + return model, loss_fn + + +def example_data_loaders(): + """Example 3: Creating data loaders for training.""" + print("\n" + "=" * 60) + print("Example 3: Creating data loaders for training") + print("=" * 60) + + # Path to the dataset + dataset_path = r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images" + + # Create data loaders + train_loader, val_loader = create_data_loaders( + root_dir=dataset_path, + batch_size=16, + train_split=0.8, + num_workers=0, # Use 0 for debugging, increase for training + image_size=(256, 256), + augment_train=True, + augment_val=False, + ) + + print(f"Train batches: {len(train_loader)}") + print(f"Val batches: {len(val_loader)}") + + # Get a batch from train loader + batch = next(iter(train_loader)) + print(f"\nBatch keys: {list(batch.keys())}") + print(f"Batch size: {batch['google_img'].shape[0]}") + print(f"Image shape: {batch['google_img'].shape[1:]}") + + return train_loader, val_loader + + +def example_training_config(): + """Example 4: Setting up training configuration.""" + print("\n" + "=" * 60) + print("Example 4: Training configuration") + print("=" * 60) + + # Training configuration + config = { + # Model config + "model_type": "cnn", + "hidden_channels": 64, + "num_blocks": 4, + "dropout_rate": 0.3, + "use_batch_norm": True, + "image_size": [256, 256], + # Training config + "epochs": 50, + "batch_size": 16, + "learning_rate": 1e-3, + "weight_decay": 1e-4, + "optimizer": "adam", + "scheduler": "plateau", + "grad_clip": 1.0, + # Loss config + "matrix_weight": 1.0, + "geometric_weight": 0.5, + "reg_weight": 0.1, + "grid_size": 8, + # Data config + "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + "train_split": 0.8, + "num_workers": 0, + # Output config + "output_dir": "runs/example_training", + "seed": 42, + } + + print("Training configuration:") + for key, value in config.items(): + print(f" {key}: {value}") + + return config + + +def example_inference(): + """Example 5: Performing inference with a trained model.""" + print("\n" + "=" * 60) + print("Example 5: Inference example") + print("=" * 60) + + # Note: This example assumes you have a trained model + # For demonstration, we'll show the code structure + + print("Inference workflow:") + print("1. Load trained model") + print("2. Preprocess input images") + print("3. Predict homography matrix") + print("4. Visualize alignment") + + # Example code structure (commented out since we don't have a trained model yet) + """ + # Create inference object + inference = HomographyInference( + model_path="runs/homography/checkpoint_best.pth", + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + # Load images + google_img = Image.open("path/to/google.png") + yandex_img = Image.open("path/to/yandex.png") + + # Predict homography + homography = inference.predict(google_img, yandex_img) + + print(f"Predicted homography matrix:") + print(homography.cpu().numpy()) + + # Visualize alignment + inference.visualize_alignment( + google_img, + yandex_img, + homography.cpu().numpy(), + save_path="alignment_visualization.png", + show=True, + ) + """ + + print( + "\nNote: To run actual inference, first train a model using train_homography.py" + ) + + +def example_quick_start(): + """Example 6: Quick start guide.""" + print("\n" + "=" * 60) + print("Example 6: Quick Start Guide") + print("=" * 60) + + print("\nTo get started with homography estimation:") + print("\n1. First, explore the dataset:") + print( + ' python -c "from models.homography import HomographyDataset; ' + "dataset = HomographyDataset(r'C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images'); " + "print(f'Found {len(dataset)} image pairs')\"" + ) + + print("\n2. Train a model:") + print(" python models/train_homography.py \\") + print( + ' --data_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\' + ) + print(" --epochs 50 \\") + print(" --batch_size 16 \\") + print(' --output_dir "runs/my_experiment"') + + print("\n3. Perform inference on a single image pair:") + print(" python models/infer_homography.py \\") + print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\') + print(" --mode single \\") + print(' --google_path "path/to/google.png" \\') + print(' --yandex_path "path/to/yandex.png" \\') + print(' --output_vis "alignment_result.png"') + + print("\n4. Evaluate on the entire dataset:") + print(" python models/infer_homography.py \\") + print(' --model_path "runs/my_experiment/checkpoint_best.pth" \\') + print(" --mode dataset \\") + print( + ' --dataset_dir "C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images" \\' + ) + print(' --save_results "evaluation_results.json"') + + +def main(): + """Run all examples.""" + print("Homography Estimation Workflow Examples") + print("=" * 60) + + try: + # Example 1: Dataset loading + dataset = example_dataset_loading() + + # Example 2: Model creation + model, loss_fn = example_model_creation() + + # Example 3: Data loaders + train_loader, val_loader = example_data_loaders() + + # Example 4: Training configuration + config = example_training_config() + + # Example 5: Inference + example_inference() + + # Example 6: Quick start + example_quick_start() + + print("\n" + "=" * 60) + print("All examples completed successfully!") + print("=" * 60) + + print("\nNext steps:") + print("1. Run the training script to train a model") + print("2. Use the inference script to test the trained model") + print("3. Integrate the homography estimation into your autopilot system") + + except Exception as e: + print(f"\nError running examples: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/models/SiaN/homography.ipynb b/models/SiaN/homography.ipynb new file mode 100644 index 0000000..1f4a657 --- /dev/null +++ b/models/SiaN/homography.ipynb @@ -0,0 +1,1679 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 19, + "id": "92144cc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 327 image pairs in C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\n", + "Dataset size: 327\n", + "Sample keys: ['google_img', 'yandex_img', 'homography', 'idx']\n", + "Google image shape: torch.Size([3, 700, 700])\n", + "Yandex image shape: torch.Size([3, 700, 700])\n", + "Homography shape: torch.Size([3, 3])\n", + "Found 327 image pairs in C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\n", + "Train batches: 17\n", + "Val batches: 5\n" + ] + } + ], + "source": [ + "import os\n", + "import random\n", + "from typing import Any, Dict, List, Optional, Tuple\n", + "\n", + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "from PIL import Image\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "\n", + "class HomographyDataset(Dataset):\n", + " \"\"\"\n", + " Dataset for homography estimation between Yandex and Google map image pairs.\n", + "\n", + " This dataset loads pairs of images (Yandex and Google maps) and provides\n", + " homography matrices for data augmentation and training.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " root_dir: str,\n", + " transform=None,\n", + " augment: bool = True,\n", + " max_samples: Optional[int] = None,\n", + " image_size: Tuple[int, int] = (700, 700),\n", + " cache_homographies: bool = True,\n", + " ):\n", + " \"\"\"\n", + " Initialize the HomographyDataset.\n", + "\n", + " Args:\n", + " root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png)\n", + " transform: Optional torchvision transforms to apply\n", + " augment: Whether to apply homography-based data augmentation\n", + " max_samples: Maximum number of samples to load (None for all)\n", + " image_size: Target size for images (height, width)\n", + " cache_homographies: Whether to cache generated homography matrices to disk\n", + " \"\"\"\n", + " self.root_dir = root_dir\n", + " self.transform = transform\n", + " self.augment = augment\n", + " self.image_size = image_size\n", + " self.cache_homographies = cache_homographies\n", + "\n", + " # Find all image pairs\n", + " self.image_pairs = self._discover_image_pairs()\n", + "\n", + " if max_samples is not None:\n", + " self.image_pairs = self.image_pairs[:max_samples]\n", + "\n", + " print(f\"Found {len(self.image_pairs)} image pairs in {root_dir}\")\n", + "\n", + " # Create directory for cached homographies if needed\n", + " if cache_homographies:\n", + " self.homography_cache_dir = os.path.join(root_dir, \"homography_cache\")\n", + " os.makedirs(self.homography_cache_dir, exist_ok=True)\n", + "\n", + " def _discover_image_pairs(self) -> List[Dict[str, Any]]:\n", + " \"\"\"Discover all Google-Yandex image pairs in the dataset directory.\"\"\"\n", + " image_pairs = []\n", + "\n", + " # Get all Google images\n", + " google_files = [\n", + " f for f in os.listdir(self.root_dir) if f.endswith(\"_google.png\")\n", + " ]\n", + "\n", + " for google_file in sorted(google_files):\n", + " # Extract index from filename\n", + " idx_str = google_file.split(\"_\")[0]\n", + " try:\n", + " idx = int(idx_str)\n", + " except ValueError:\n", + " continue\n", + "\n", + " # Check if corresponding Yandex image exists\n", + " yandex_file = f\"{idx:04d}_yandex.png\"\n", + " yandex_path = os.path.join(self.root_dir, yandex_file)\n", + "\n", + " if os.path.exists(yandex_path):\n", + " image_pairs.append(\n", + " {\n", + " \"idx\": idx,\n", + " \"google_path\": os.path.join(self.root_dir, google_file),\n", + " \"yandex_path\": yandex_path,\n", + " }\n", + " )\n", + "\n", + " return image_pairs\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"Return the number of image pairs in the dataset.\"\"\"\n", + " return len(self.image_pairs)\n", + "\n", + " def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:\n", + " \"\"\"\n", + " Get a sample from the dataset.\n", + "\n", + " Returns a dictionary with:\n", + " - 'google_img': Google map image tensor\n", + " - 'yandex_img': Yandex map image tensor\n", + " - 'homography': Ground truth homography matrix (3x3)\n", + " - 'idx': Sample index\n", + " \"\"\"\n", + " pair_info = self.image_pairs[idx]\n", + "\n", + " # Load images\n", + " google_img = Image.open(pair_info[\"google_path\"]).convert(\"RGB\")\n", + " yandex_img = Image.open(pair_info[\"yandex_path\"]).convert(\"RGB\")\n", + "\n", + " # Resize images to target size\n", + " google_img = google_img.resize(\n", + " (self.image_size[1], self.image_size[0]), Image.BILINEAR\n", + " )\n", + " yandex_img = yandex_img.resize(\n", + " (self.image_size[1], self.image_size[0]), Image.BILINEAR\n", + " )\n", + "\n", + " # Get or generate homography matrix\n", + " homography_matrix = self._get_homography_matrix(pair_info[\"idx\"])\n", + "\n", + " # Apply data augmentation if enabled\n", + " if self.augment:\n", + " google_img, yandex_img, homography_matrix = self._apply_augmentation(\n", + " google_img, yandex_img, homography_matrix\n", + " )\n", + "\n", + " # Convert images to tensors\n", + " if self.transform:\n", + " google_img = self.transform(google_img)\n", + " yandex_img = self.transform(yandex_img)\n", + " else:\n", + " # Default conversion to tensor\n", + " google_img = (\n", + " torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0\n", + " )\n", + " yandex_img = (\n", + " torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0\n", + " )\n", + "\n", + " # Convert homography to tensor\n", + " homography_tensor = torch.from_numpy(homography_matrix).float()\n", + "\n", + " return {\n", + " \"google_img\": google_img,\n", + " \"yandex_img\": yandex_img,\n", + " \"homography\": homography_tensor,\n", + " \"idx\": torch.tensor(pair_info[\"idx\"], dtype=torch.long),\n", + " }\n", + "\n", + " def _get_homography_matrix(self, idx: int) -> np.ndarray:\n", + " \"\"\"\n", + " Get homography matrix for a given index.\n", + "\n", + " If cached homography exists, load it. Otherwise generate a new one.\n", + " \"\"\"\n", + " if self.cache_homographies:\n", + " cache_path = os.path.join(\n", + " self.homography_cache_dir, f\"{idx:04d}_homography.npy\"\n", + " )\n", + " if os.path.exists(cache_path):\n", + " return np.load(cache_path)\n", + "\n", + " # Generate new homography matrix\n", + " homography_matrix = self.generate_random_homography()\n", + "\n", + " # Cache if enabled\n", + " if self.cache_homographies:\n", + " np.save(cache_path, homography_matrix)\n", + "\n", + " return homography_matrix\n", + "\n", + " def generate_random_homography(self) -> np.ndarray:\n", + " \"\"\"\n", + " Generate a random homography matrix for data augmentation.\n", + "\n", + " Returns:\n", + " np.ndarray: 3x3 homography matrix.\n", + " \"\"\"\n", + " # Generate random affine transformation parameters\n", + " angle = np.random.uniform(-30, 30) # rotation in degrees\n", + " scale = np.random.uniform(0.8, 1.2) # scaling factor\n", + " tx = np.random.uniform(-50, 50) # translation in x\n", + " ty = np.random.uniform(-50, 50) # translation in y\n", + "\n", + " # Convert angle to radians\n", + " theta = np.radians(angle)\n", + "\n", + " # Create affine transformation matrix\n", + " affine_matrix = np.array(\n", + " [\n", + " [scale * np.cos(theta), -scale * np.sin(theta), tx],\n", + " [scale * np.sin(theta), scale * np.cos(theta), ty],\n", + " [0, 0, 1],\n", + " ]\n", + " )\n", + "\n", + " # Add small perspective distortion\n", + " perspective = np.random.uniform(-0.001, 0.001, (2, 3))\n", + " perspective = np.vstack([perspective, [0, 0, 0]])\n", + "\n", + " homography_matrix = affine_matrix + perspective\n", + "\n", + " return homography_matrix\n", + "\n", + " def _apply_augmentation(\n", + " self,\n", + " google_img: Image.Image,\n", + " yandex_img: Image.Image,\n", + " base_homography: np.ndarray,\n", + " ) -> Tuple[Image.Image, Image.Image, np.ndarray]:\n", + " \"\"\"\n", + " Apply homography-based data augmentation to image pair.\n", + "\n", + " Args:\n", + " google_img: Google map image\n", + " yandex_img: Yandex map image\n", + " base_homography: Base homography matrix\n", + "\n", + " Returns:\n", + " Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography)\n", + " \"\"\"\n", + " # Generate augmentation homography\n", + " aug_homography = self.generate_random_homography()\n", + "\n", + " # Combine with base homography\n", + " combined_homography = aug_homography @ base_homography\n", + "\n", + " # Apply augmentation to both images\n", + " google_aug = self._apply_homography_to_image(google_img, aug_homography)\n", + " yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography)\n", + "\n", + " return google_aug, yandex_aug, combined_homography\n", + "\n", + " def _apply_homography_to_image(\n", + " self, img: Image.Image, homography: np.ndarray\n", + " ) -> Image.Image:\n", + " \"\"\"\n", + " Apply homography transformation to a single image.\n", + "\n", + " Args:\n", + " img: PIL Image to transform\n", + " homography: 3x3 homography matrix\n", + "\n", + " Returns:\n", + " Transformed PIL Image\n", + " \"\"\"\n", + " # Convert to numpy array\n", + " img_np = np.array(img)\n", + "\n", + " # Get image dimensions\n", + " h, w = img_np.shape[:2]\n", + "\n", + " # Apply homography transformation\n", + " transformed = cv2.warpPerspective(\n", + " img_np,\n", + " homography,\n", + " (w, h),\n", + " flags=cv2.INTER_LINEAR,\n", + " borderMode=cv2.BORDER_REFLECT,\n", + " )\n", + "\n", + " # Convert back to PIL Image\n", + " return Image.fromarray(transformed)\n", + "\n", + " def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:\n", + " \"\"\"\n", + " Get a sample without data augmentation.\n", + "\n", + " Useful for visualization and evaluation.\n", + " \"\"\"\n", + " pair_info = self.image_pairs[idx]\n", + "\n", + " # Load images\n", + " google_img = Image.open(pair_info[\"google_path\"]).convert(\"RGB\")\n", + " yandex_img = Image.open(pair_info[\"yandex_path\"]).convert(\"RGB\")\n", + "\n", + " # Resize\n", + " google_img = google_img.resize(\n", + " (self.image_size[1], self.image_size[0]), Image.BILINEAR\n", + " )\n", + " yandex_img = yandex_img.resize(\n", + " (self.image_size[1], self.image_size[0]), Image.BILINEAR\n", + " )\n", + "\n", + " # Get homography matrix\n", + " homography_matrix = self._get_homography_matrix(pair_info[\"idx\"])\n", + "\n", + " return {\n", + " \"google_img\": google_img,\n", + " \"yandex_img\": yandex_img,\n", + " \"homography\": homography_matrix,\n", + " \"idx\": pair_info[\"idx\"],\n", + " \"google_path\": pair_info[\"google_path\"],\n", + " \"yandex_path\": pair_info[\"yandex_path\"],\n", + " }\n", + "\n", + "\n", + "def create_data_loaders(\n", + " root_dir: str,\n", + " batch_size: int = 32,\n", + " train_split: float = 0.8,\n", + " num_workers: int = 4,\n", + " image_size: Tuple[int, int] = (256, 256),\n", + " augment_train: bool = True,\n", + " augment_val: bool = False,\n", + ") -> Tuple[DataLoader, DataLoader]:\n", + " \"\"\"\n", + " Create train and validation data loaders for homography estimation.\n", + "\n", + " Args:\n", + " root_dir: Directory containing image pairs\n", + " batch_size: Batch size for data loaders\n", + " train_split: Fraction of data to use for training\n", + " num_workers: Number of worker processes for data loading\n", + " image_size: Target image size (height, width)\n", + " augment_train: Whether to augment training data\n", + " augment_val: Whether to augment validation data\n", + "\n", + " Returns:\n", + " Tuple of (train_loader, val_loader)\n", + " \"\"\"\n", + " from torchvision import transforms\n", + "\n", + " # Define transforms\n", + " transform = transforms.Compose(\n", + " [\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", + " ]\n", + " )\n", + "\n", + " # Create full dataset\n", + " full_dataset = HomographyDataset(\n", + " root_dir=root_dir,\n", + " transform=transform,\n", + " augment=False, # We'll handle augmentation separately\n", + " image_size=image_size,\n", + " cache_homographies=True,\n", + " )\n", + "\n", + " # Split dataset\n", + " dataset_size = len(full_dataset)\n", + " train_size = int(train_split * dataset_size)\n", + " val_size = dataset_size - train_size\n", + "\n", + " # Create indices for splitting\n", + " indices = list(range(dataset_size))\n", + " random.shuffle(indices)\n", + " train_indices = indices[:train_size]\n", + " val_indices = indices[train_size:]\n", + "\n", + " # Create subset samplers\n", + " from torch.utils.data import Subset\n", + "\n", + " train_dataset = Subset(full_dataset, train_indices)\n", + " val_dataset = Subset(full_dataset, val_indices)\n", + "\n", + " # Apply augmentation by overriding __getitem__ for train dataset\n", + " if augment_train:\n", + "\n", + " class AugmentedSubset(Subset):\n", + " def __getitem__(self, idx):\n", + " sample = self.dataset[self.indices[idx]]\n", + " # Apply augmentation\n", + " google_img = sample[\"google_img\"]\n", + " yandex_img = sample[\"yandex_img\"]\n", + " homography = sample[\"homography\"]\n", + "\n", + " # Generate augmentation homography\n", + " aug_homography = torch.from_numpy(\n", + " full_dataset.generate_random_homography()\n", + " ).float()\n", + "\n", + " # Combine homographies\n", + " combined_homography = aug_homography @ homography\n", + "\n", + " # Apply augmentation (simplified - in practice would warp images)\n", + " # For now, we just return the combined homography\n", + " return {\n", + " \"google_img\": google_img,\n", + " \"yandex_img\": yandex_img,\n", + " \"homography\": combined_homography,\n", + " \"idx\": sample[\"idx\"],\n", + " }\n", + "\n", + " train_dataset = AugmentedSubset(full_dataset, train_indices)\n", + "\n", + " # Create data loaders\n", + " train_loader = DataLoader(\n", + " train_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " num_workers=num_workers,\n", + " pin_memory=True,\n", + " )\n", + "\n", + " val_loader = DataLoader(\n", + " val_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=False,\n", + " num_workers=num_workers,\n", + " pin_memory=True,\n", + " )\n", + "\n", + " return train_loader, val_loader\n", + "\n", + "\n", + "# Example usage\n", + "dataset = HomographyDataset(\n", + " root_dir=r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", + " augment=True,\n", + " image_size=(700, 700),\n", + ")\n", + "\n", + "print(f\"Dataset size: {len(dataset)}\")\n", + "\n", + "# Get a sample\n", + "sample = dataset[0]\n", + "print(f\"Sample keys: {list(sample.keys())}\")\n", + "print(f\"Google image shape: {sample['google_img'].shape}\")\n", + "print(f\"Yandex image shape: {sample['yandex_img'].shape}\")\n", + "print(f\"Homography shape: {sample['homography'].shape}\")\n", + "\n", + "# Create data loaders\n", + "train_loader, val_loader = create_data_loaders(\n", + " root_dir=r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", + " batch_size=16,\n", + " train_split=0.8,\n", + ")\n", + "\n", + "print(f\"Train batches: {len(train_loader)}\")\n", + "print(f\"Val batches: {len(val_loader)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "bf3b0524", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n", + "Model created with 9,013,385 parameters\n", + "\n", + "Testing forward pass...\n", + "Output shape: torch.Size([4, 3, 3])\n", + "Sample output:\n", + "tensor([[-1.3744e+01, 9.4431e+00, 1.8618e+01],\n", + " [ 9.8099e+00, 5.7875e+00, -2.4102e+01],\n", + " [ 9.3618e-03, 3.3153e+00, 1.0000e+00]], grad_fn=)\n", + "\n", + "Testing prediction...\n", + "Prediction shape: torch.Size([4, 3, 3])\n", + "Last element (should be ~1): 1.000000\n", + "\n", + "Testing loss function...\n", + "Loss value: 582368034816.000000\n", + "\n", + "Testing metrics...\n", + "matrix_mse: 4946.436523\n", + "corner_error: 5.424056\n", + "corner_error_px: 694.279114\n", + "\n", + "Testing model factory...\n", + "Model2 created with 1,779,145 parameters\n", + "\n", + "All tests completed successfully!\n" + ] + } + ], + "source": [ + "from typing import Optional, Tuple\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "class HomographyCNN(nn.Module):\n", + " \"\"\"\n", + " CNN model for homography estimation between two images.\n", + "\n", + " This model takes two images (Google and Yandex maps) as input and\n", + " outputs a 3x3 homography matrix that transforms one image to align with the other.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " input_channels: int = 3,\n", + " hidden_channels: int = 64,\n", + " num_blocks: int = 4,\n", + " dropout_rate: float = 0.3,\n", + " use_batch_norm: bool = True,\n", + " output_size: int = 9, # Flattened 3x3 homography matrix\n", + " ):\n", + " \"\"\"\n", + " Initialize the HomographyCNN model.\n", + "\n", + " Args:\n", + " input_channels: Number of input channels per image (3 for RGB)\n", + " hidden_channels: Base number of channels in the network\n", + " num_blocks: Number of convolutional blocks\n", + " dropout_rate: Dropout rate for regularization\n", + " use_batch_norm: Whether to use batch normalization\n", + " output_size: Size of output vector (9 for flattened 3x3 matrix)\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.input_channels = input_channels\n", + " self.hidden_channels = hidden_channels\n", + " self.num_blocks = num_blocks\n", + " self.dropout_rate = dropout_rate\n", + " self.use_batch_norm = use_batch_norm\n", + "\n", + " # Feature extraction for each image separately\n", + " self.google_encoder = self._build_encoder()\n", + " self.yandex_encoder = self._build_encoder()\n", + "\n", + " # Fusion layers to combine features from both images\n", + " self.fusion_layers = self._build_fusion_layers()\n", + "\n", + " # Regression head for homography estimation\n", + " self.regression_head = self._build_regression_head(output_size)\n", + "\n", + " # Initialize weights\n", + " self._initialize_weights()\n", + "\n", + " def _build_encoder(self) -> nn.Module:\n", + " \"\"\"Build the encoder network for a single image.\"\"\"\n", + " layers = []\n", + " in_channels = self.input_channels\n", + " out_channels = self.hidden_channels\n", + "\n", + " # First convolutional block\n", + " layers.append(\n", + " nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)\n", + " )\n", + " if self.use_batch_norm:\n", + " layers.append(nn.BatchNorm2d(out_channels))\n", + " layers.append(nn.ReLU(inplace=True))\n", + " layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", + "\n", + " # Additional convolutional blocks\n", + " for i in range(self.num_blocks):\n", + " block_in_channels = out_channels\n", + " block_out_channels = out_channels * 2 if i < 2 else out_channels\n", + "\n", + " layers.append(\n", + " ResidualBlock(\n", + " in_channels=block_in_channels,\n", + " out_channels=block_out_channels,\n", + " stride=1 if i == 0 else 2,\n", + " dropout_rate=self.dropout_rate,\n", + " use_batch_norm=self.use_batch_norm,\n", + " )\n", + " )\n", + "\n", + " if i < 2:\n", + " out_channels = block_out_channels\n", + "\n", + " return nn.Sequential(*layers)\n", + "\n", + " def _build_fusion_layers(self) -> nn.Module:\n", + " \"\"\"Build layers to fuse features from both images.\"\"\"\n", + " # After encoding, each image has hidden_channels * 4 features\n", + " fused_channels = (\n", + " self.hidden_channels * 8\n", + " ) # Concatenated features from both images\n", + "\n", + " layers = [\n", + " # Reduce dimensionality\n", + " nn.Conv2d(\n", + " fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1\n", + " ),\n", + " nn.BatchNorm2d(self.hidden_channels * 4)\n", + " if self.use_batch_norm\n", + " else nn.Identity(),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout2d(self.dropout_rate),\n", + " # Further processing\n", + " nn.Conv2d(\n", + " self.hidden_channels * 4,\n", + " self.hidden_channels * 2,\n", + " kernel_size=3,\n", + " padding=1,\n", + " ),\n", + " nn.BatchNorm2d(self.hidden_channels * 2)\n", + " if self.use_batch_norm\n", + " else nn.Identity(),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout2d(self.dropout_rate),\n", + " # Global pooling\n", + " nn.AdaptiveAvgPool2d((1, 1)),\n", + " ]\n", + "\n", + " return nn.Sequential(*layers)\n", + "\n", + " def _build_regression_head(self, output_size: int) -> nn.Module:\n", + " \"\"\"Build the regression head for homography estimation.\"\"\"\n", + " # Input size after fusion and global pooling\n", + " input_features = self.hidden_channels * 2\n", + "\n", + " layers = [\n", + " nn.Flatten(),\n", + " nn.Linear(input_features, 512),\n", + " nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(self.dropout_rate),\n", + " nn.Linear(512, 256),\n", + " nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(self.dropout_rate),\n", + " nn.Linear(256, 128),\n", + " nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(self.dropout_rate),\n", + " nn.Linear(128, output_size),\n", + " ]\n", + "\n", + " return nn.Sequential(*layers)\n", + "\n", + " def _initialize_weights(self):\n", + " \"\"\"Initialize model weights.\"\"\"\n", + " for m in self.modules():\n", + " if isinstance(m, nn.Conv2d):\n", + " nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n", + " if m.bias is not None:\n", + " nn.init.constant_(m.bias, 0)\n", + " elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):\n", + " nn.init.constant_(m.weight, 1)\n", + " nn.init.constant_(m.bias, 0)\n", + " elif isinstance(m, nn.Linear):\n", + " nn.init.normal_(m.weight, 0, 0.01)\n", + " nn.init.constant_(m.bias, 0)\n", + "\n", + " def forward(\n", + " self,\n", + " google_img: torch.Tensor,\n", + " yandex_img: torch.Tensor,\n", + " return_matrix: bool = True,\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Forward pass of the model.\n", + "\n", + " Args:\n", + " google_img: Google map image tensor of shape (B, C, H, W)\n", + " yandex_img: Yandex map image tensor of shape (B, C, H, W)\n", + " return_matrix: If True, return 3x3 matrix; if False, return flattened vector\n", + "\n", + " Returns:\n", + " Homography matrix tensor of shape (B, 3, 3) or flattened vector of shape (B, 9)\n", + " \"\"\"\n", + " # Extract features from both images\n", + " google_features = self.google_encoder(google_img)\n", + " yandex_features = self.yandex_encoder(yandex_img)\n", + "\n", + " # Concatenate features along channel dimension\n", + " combined_features = torch.cat([google_features, yandex_features], dim=1)\n", + "\n", + " # Fuse features\n", + " fused_features = self.fusion_layers(combined_features)\n", + "\n", + " # Regression to get homography parameters\n", + " homography_flat = self.regression_head(fused_features)\n", + "\n", + " if return_matrix:\n", + " # Reshape to 3x3 matrix\n", + " batch_size = homography_flat.shape[0]\n", + " homography_matrix = homography_flat.view(batch_size, 3, 3)\n", + "\n", + " # Ensure the last element is 1 (homogeneous coordinate normalization)\n", + " # Add small epsilon to prevent division by zero\n", + " epsilon = 1e-8\n", + " homography_matrix = homography_matrix / (\n", + " homography_matrix[:, 2, 2].view(-1, 1, 1) + epsilon\n", + " )\n", + "\n", + " return homography_matrix\n", + " else:\n", + " return homography_flat\n", + "\n", + " def predict_homography(\n", + " self,\n", + " google_img: torch.Tensor,\n", + " yandex_img: torch.Tensor,\n", + " normalize: bool = True,\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Predict homography matrix with optional normalization.\n", + "\n", + " Args:\n", + " google_img: Google map image tensor\n", + " yandex_img: Yandex map image tensor\n", + " normalize: Whether to normalize the homography matrix\n", + "\n", + " Returns:\n", + " Predicted homography matrix\n", + " \"\"\"\n", + " self.eval()\n", + " with torch.no_grad():\n", + " homography = self.forward(google_img, yandex_img, return_matrix=True)\n", + "\n", + " if normalize:\n", + " # Normalize so that last element is 1\n", + " homography = homography / homography[:, 2, 2].view(-1, 1, 1)\n", + "\n", + " return homography\n", + "\n", + "\n", + "class ResidualBlock(nn.Module):\n", + " \"\"\"Residual block with optional downsampling.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " in_channels: int,\n", + " out_channels: int,\n", + " stride: int = 1,\n", + " dropout_rate: float = 0.3,\n", + " use_batch_norm: bool = True,\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.conv1 = nn.Conv2d(\n", + " in_channels,\n", + " out_channels,\n", + " kernel_size=3,\n", + " stride=stride,\n", + " padding=1,\n", + " bias=False,\n", + " )\n", + " self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()\n", + " self.relu1 = nn.ReLU(inplace=True)\n", + " self.dropout1 = (\n", + " nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()\n", + " )\n", + "\n", + " self.conv2 = nn.Conv2d(\n", + " out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False\n", + " )\n", + " self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()\n", + " self.relu2 = nn.ReLU(inplace=True)\n", + " self.dropout2 = (\n", + " nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()\n", + " )\n", + "\n", + " # Shortcut connection\n", + " self.shortcut = nn.Sequential()\n", + " if stride != 1 or in_channels != out_channels:\n", + " self.shortcut = nn.Sequential(\n", + " nn.Conv2d(\n", + " in_channels, out_channels, kernel_size=1, stride=stride, bias=False\n", + " ),\n", + " nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " identity = self.shortcut(x)\n", + "\n", + " out = self.conv1(x)\n", + " out = self.bn1(out)\n", + " out = self.relu1(out)\n", + " out = self.dropout1(out)\n", + "\n", + " out = self.conv2(out)\n", + " out = self.bn2(out)\n", + "\n", + " out += identity\n", + " out = self.relu2(out)\n", + " out = self.dropout2(out)\n", + "\n", + " return out\n", + "\n", + "\n", + "class HomographyLoss(nn.Module):\n", + " \"\"\"\n", + " Custom loss function for homography estimation.\n", + "\n", + " Combines multiple loss terms:\n", + " 1. Matrix element-wise L2 loss\n", + " 2. Geometric consistency loss (warping error)\n", + " 3. Determinant regularization (to prevent degenerate matrices)\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " matrix_weight: float = 1.0,\n", + " geometric_weight: float = 0.5,\n", + " reg_weight: float = 0.1,\n", + " grid_size: int = 8,\n", + " ):\n", + " super().__init__()\n", + " self.matrix_weight = matrix_weight\n", + " self.geometric_weight = geometric_weight\n", + " self.reg_weight = reg_weight\n", + " self.grid_size = grid_size\n", + "\n", + " # Create grid of points for geometric loss\n", + " self.register_buffer(\n", + " \"grid_points\",\n", + " self._create_grid_points(grid_size),\n", + " persistent=False,\n", + " )\n", + "\n", + " def _create_grid_points(self, grid_size: int) -> torch.Tensor:\n", + " \"\"\"Create a grid of points for geometric consistency loss.\"\"\"\n", + " x = torch.linspace(-1, 1, grid_size)\n", + " y = torch.linspace(-1, 1, grid_size)\n", + " grid_y, grid_x = torch.meshgrid(y, x, indexing=\"ij\")\n", + " grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)\n", + " # Add homogeneous coordinate\n", + " ones = torch.ones(grid_points.shape[0], 1)\n", + " grid_points = torch.cat([grid_points, ones], dim=1)\n", + " return grid_points.T # Shape: (3, grid_size*grid_size)\n", + "\n", + " def forward(\n", + " self,\n", + " pred_homography: torch.Tensor,\n", + " target_homography: torch.Tensor,\n", + " google_img: Optional[torch.Tensor] = None,\n", + " yandex_img: Optional[torch.Tensor] = None,\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Compute homography loss.\n", + "\n", + " Args:\n", + " pred_homography: Predicted homography matrices (B, 3, 3)\n", + " target_homography: Target homography matrices (B, 3, 3)\n", + " google_img: Google images (optional, for geometric loss)\n", + " yandex_img: Yandex images (optional, for geometric loss)\n", + "\n", + " Returns:\n", + " Combined loss value\n", + " \"\"\"\n", + " batch_size = pred_homography.shape[0]\n", + "\n", + " # 1. Matrix element-wise L2 loss\n", + " matrix_loss = F.mse_loss(pred_homography, target_homography)\n", + "\n", + " # 2. Geometric consistency loss (if images provided)\n", + " geometric_loss = torch.tensor(0.0, device=pred_homography.device)\n", + " if google_img is not None and yandex_img is not None:\n", + " # Warp grid points with predicted homography\n", + " grid_points = self.grid_points.unsqueeze(0).expand(batch_size, -1, -1)\n", + " warped_points = torch.bmm(pred_homography, grid_points)\n", + "\n", + " # Normalize homogeneous coordinates\n", + " warped_points = warped_points / (warped_points[:, 2:3, :] + 1e-8)\n", + "\n", + " # Warp with target homography for comparison\n", + " target_warped_points = torch.bmm(target_homography, grid_points)\n", + " target_warped_points = target_warped_points / (\n", + " target_warped_points[:, 2:3, :] + 1e-8\n", + " )\n", + "\n", + " # Compute point-wise distance\n", + " geometric_loss = F.mse_loss(\n", + " warped_points[:, :2, :], target_warped_points[:, :2, :]\n", + " )\n", + "\n", + " # 3. Regularization loss (prevent degenerate matrices)\n", + " # Encourage determinant to be close to 1\n", + " pred_det = torch.det(pred_homography)\n", + " reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det))\n", + "\n", + " # Combine losses\n", + " total_loss = (\n", + " self.matrix_weight * matrix_loss\n", + " + self.geometric_weight * geometric_loss\n", + " + self.reg_weight * reg_loss\n", + " )\n", + "\n", + " return total_loss\n", + "\n", + " def compute_metrics(\n", + " self,\n", + " pred_homography: torch.Tensor,\n", + " target_homography: torch.Tensor,\n", + " ) -> dict:\n", + " \"\"\"\n", + " Compute evaluation metrics for homography estimation.\n", + "\n", + " Args:\n", + " pred_homography: Predicted homography matrices\n", + " target_homography: Target homography matrices\n", + "\n", + " Returns:\n", + " Dictionary of metrics\n", + " \"\"\"\n", + " with torch.no_grad():\n", + " # Normalize matrices\n", + " pred_norm = pred_homography / pred_homography[:, 2, 2].view(-1, 1, 1)\n", + " target_norm = target_homography / target_homography[:, 2, 2].view(-1, 1, 1)\n", + "\n", + " # Matrix L2 error\n", + " matrix_error = F.mse_loss(pred_norm, target_norm, reduction=\"none\").mean(\n", + " dim=(1, 2)\n", + " )\n", + "\n", + " # Corner error (warp 4 corners of the image)\n", + " corners = torch.tensor(\n", + " [[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]],\n", + " dtype=torch.float32,\n", + " device=pred_homography.device,\n", + " ).T # Shape: (3, 4)\n", + "\n", + " corners = corners.unsqueeze(0).expand(pred_homography.shape[0], -1, -1)\n", + "\n", + " pred_corners = torch.bmm(pred_norm, corners)\n", + " pred_corners = pred_corners / (pred_corners[:, 2:3, :] + 1e-8)\n", + "\n", + " target_corners = torch.bmm(target_norm, corners)\n", + " target_corners = target_corners / (target_corners[:, 2:3, :] + 1e-8)\n", + "\n", + " corner_error = torch.mean(\n", + " torch.norm(pred_corners[:, :2, :] - target_corners[:, :2, :], dim=1),\n", + " dim=1,\n", + " )\n", + "\n", + " # Average corner error in pixels (assuming image coordinates in [-1, 1])\n", + " # Convert to pixel error if image size is known\n", + " avg_corner_error = corner_error.mean().item()\n", + "\n", + " return {\n", + " \"matrix_mse\": matrix_error.mean().item(),\n", + " \"corner_error\": avg_corner_error,\n", + " \"corner_error_px\": avg_corner_error * 128, # Assuming 256x256 images\n", + " }\n", + "\n", + "\n", + "def create_homography_model(\n", + " model_type: str = \"cnn\",\n", + " input_size: Tuple[int, int] = (256, 256),\n", + " **kwargs,\n", + ") -> nn.Module:\n", + " \"\"\"\n", + " Factory function to create homography estimation model.\n", + "\n", + " Args:\n", + " model_type: Type of model to create ('cnn' or 'resnet')\n", + " input_size: Input image size (height, width)\n", + " **kwargs: Additional arguments passed to model constructor\n", + "\n", + " Returns:\n", + " Homography estimation model\n", + " \"\"\"\n", + " if model_type == \"cnn\":\n", + " return HomographyCNN(**kwargs)\n", + " else:\n", + " raise ValueError(f\"Unknown model type: {model_type}\")\n", + "\n", + "\n", + "# Test the model\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Create model\n", + "model = HomographyCNN(\n", + " input_channels=3,\n", + " hidden_channels=64,\n", + " num_blocks=4,\n", + " dropout_rate=0.3,\n", + " use_batch_norm=True,\n", + ").to(device)\n", + "\n", + "print(\n", + " f\"Model created with {sum(p.numel() for p in model.parameters()):,} parameters\"\n", + ")\n", + "\n", + "# Create dummy input\n", + "batch_size = 4\n", + "height, width = 700, 700\n", + "\n", + "google_img = torch.randn(batch_size, 3, height, width).to(device)\n", + "yandex_img = torch.randn(batch_size, 3, height, width).to(device)\n", + "\n", + "# Test forward pass\n", + "print(\"\\nTesting forward pass...\")\n", + "output = model(google_img, yandex_img, return_matrix=True)\n", + "print(f\"Output shape: {output.shape}\") # Should be (4, 3, 3)\n", + "print(f\"Sample output:\\n{output[0]}\")\n", + "\n", + "# Test prediction\n", + "print(\"\\nTesting prediction...\")\n", + "pred = model.predict_homography(google_img, yandex_img)\n", + "print(f\"Prediction shape: {pred.shape}\")\n", + "print(f\"Last element (should be ~1): {pred[0, 2, 2]:.6f}\")\n", + "\n", + "# Test loss function\n", + "print(\"\\nTesting loss function...\")\n", + "target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)\n", + "loss_fn = HomographyLoss(\n", + " matrix_weight=1.0,\n", + " geometric_weight=0.5,\n", + " reg_weight=0.1,\n", + " grid_size=8,\n", + ").to(device)\n", + "\n", + "loss = loss_fn(output, target_homography, google_img, yandex_img)\n", + "print(f\"Loss value: {loss.item():.6f}\")\n", + "\n", + "# Test metrics\n", + "print(\"\\nTesting metrics...\")\n", + "metrics = loss_fn.compute_metrics(output, target_homography)\n", + "for key, value in metrics.items():\n", + " print(f\"{key}: {value:.6f}\")\n", + "\n", + "# Test model factory\n", + "print(\"\\nTesting model factory...\")\n", + "model2 = create_homography_model(\n", + " model_type=\"cnn\",\n", + " input_size=(256, 256),\n", + " input_channels=3,\n", + " hidden_channels=32,\n", + " num_blocks=3,\n", + ").to(device)\n", + "\n", + "print(\n", + " f\"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters\"\n", + ")\n", + "\n", + "print(\"\\nAll tests completed successfully!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d7979efa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n", + "Creating data loaders...\n", + "Found 327 image pairs in C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\n", + "Train batches: 9\n", + "Val batches: 3\n", + "Creating model...\n", + "Training configuration saved to runs\\homography\\config.json\n", + "Model has 8,999,817 parameters\n", + "Starting training for 100 epochs...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 1: 0%| | 0/9 [00:05 \u001b[39m\u001b[32m1310\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_data_queue\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1311\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\multiprocessing\\queues.py:114\u001b[39m, in \u001b[36mQueue.get\u001b[39m\u001b[34m(self, block, timeout)\u001b[39m\n\u001b[32m 113\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._poll(timeout):\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._poll():\n", + "\u001b[31mEmpty\u001b[39m: ", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 533\u001b[39m\n\u001b[32m 530\u001b[39m trainer.evaluate()\n\u001b[32m 531\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 532\u001b[39m \u001b[38;5;66;03m# Train the model\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m533\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m=\u001b[49m\u001b[43margs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 535\u001b[39m \u001b[38;5;66;03m# Final evaluation\u001b[39;00m\n\u001b[32m 536\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mPerforming final evaluation...\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 299\u001b[39m, in \u001b[36mHomographyTrainer.train\u001b[39m\u001b[34m(self, num_epochs)\u001b[39m\n\u001b[32m 296\u001b[39m \u001b[38;5;28mself\u001b[39m.current_epoch = epoch\n\u001b[32m 298\u001b[39m \u001b[38;5;66;03m# Train for one epoch\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m299\u001b[39m train_loss = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 301\u001b[39m \u001b[38;5;66;03m# Validate\u001b[39;00m\n\u001b[32m 302\u001b[39m val_loss, val_metrics = \u001b[38;5;28mself\u001b[39m.validate()\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 143\u001b[39m, in \u001b[36mHomographyTrainer.train_epoch\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 140\u001b[39m num_batches = \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m.train_loader)\n\u001b[32m 142\u001b[39m progress_bar = tqdm(\u001b[38;5;28mself\u001b[39m.train_loader, desc=\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.current_epoch\u001b[38;5;250m \u001b[39m+\u001b[38;5;250m \u001b[39m\u001b[32m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m143\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 144\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Move data to device\u001b[39;49;00m\n\u001b[32m 145\u001b[39m \u001b[43m \u001b[49m\u001b[43mgoogle_img\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mgoogle_img\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 146\u001b[39m \u001b[43m \u001b[49m\u001b[43myandex_img\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43myandex_img\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\tqdm\\std.py:1181\u001b[39m, in \u001b[36mtqdm.__iter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1178\u001b[39m time = \u001b[38;5;28mself\u001b[39m._time\n\u001b[32m 1180\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1181\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 1182\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\n\u001b[32m 1183\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Update and possibly print the progressbar.\u001b[39;49;00m\n\u001b[32m 1184\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;49;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:741\u001b[39m, in \u001b[36m_BaseDataLoaderIter.__next__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 738\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 739\u001b[39m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[32m 740\u001b[39m \u001b[38;5;28mself\u001b[39m._reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m741\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 742\u001b[39m \u001b[38;5;28mself\u001b[39m._num_yielded += \u001b[32m1\u001b[39m\n\u001b[32m 743\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 744\u001b[39m \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable\n\u001b[32m 745\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 746\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._num_yielded > \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called\n\u001b[32m 747\u001b[39m ):\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1524\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._next_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1520\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._shutdown \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._tasks_outstanding <= \u001b[32m0\u001b[39m:\n\u001b[32m 1521\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\n\u001b[32m 1522\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mInvalid iterator state: shutdown or no outstanding tasks when fetching next data\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1523\u001b[39m )\n\u001b[32m-> \u001b[39m\u001b[32m1524\u001b[39m idx, data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1525\u001b[39m \u001b[38;5;28mself\u001b[39m._tasks_outstanding -= \u001b[32m1\u001b[39m\n\u001b[32m 1526\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable:\n\u001b[32m 1527\u001b[39m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1483\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._get_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1479\u001b[39m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[32m 1480\u001b[39m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[32m 1481\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1482\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1483\u001b[39m success, data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1484\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[32m 1485\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\admin\\Projects\\autopilot\\.venv\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:1323\u001b[39m, in \u001b[36m_MultiProcessingDataLoaderIter._try_get_data\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 1321\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) > \u001b[32m0\u001b[39m:\n\u001b[32m 1322\u001b[39m pids_str = \u001b[33m\"\u001b[39m\u001b[33m, \u001b[39m\u001b[33m\"\u001b[39m.join(\u001b[38;5;28mstr\u001b[39m(w.pid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[32m-> \u001b[39m\u001b[32m1323\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[32m 1324\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m) exited unexpectedly\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1325\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m 1326\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue.Empty):\n\u001b[32m 1327\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n", + "\u001b[31mRuntimeError\u001b[39m: DataLoader worker (pid(s) 29616) exited unexpectedly" + ] + } + ], + "source": [ + "\"\"\"\n", + "Training script for homography estimation between Google and Yandex map images.\n", + "\n", + "This script trains a CNN model to estimate homography matrices that align\n", + "Google map images with Yandex map images.\n", + "\"\"\"\n", + "\n", + "import argparse\n", + "import json\n", + "import os\n", + "import time\n", + "from datetime import datetime\n", + "from pathlib import Path\n", + "from typing import Dict, List, Optional, Tuple\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "class HomographyTrainer:\n", + " \"\"\"Trainer class for homography estimation model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " model: nn.Module,\n", + " train_loader: DataLoader,\n", + " val_loader: DataLoader,\n", + " device: torch.device,\n", + " config: Dict,\n", + " ):\n", + " \"\"\"\n", + " Initialize the trainer.\n", + "\n", + " Args:\n", + " model: Homography estimation model\n", + " train_loader: Training data loader\n", + " val_loader: Validation data loader\n", + " device: Device to run training on\n", + " config: Training configuration dictionary\n", + " \"\"\"\n", + " self.model = model.to(device)\n", + " self.train_loader = train_loader\n", + " self.val_loader = val_loader\n", + " self.device = device\n", + " self.config = config\n", + "\n", + " # Loss function\n", + " self.criterion = HomographyLoss(\n", + " matrix_weight=config.get(\"matrix_weight\", 1.0),\n", + " geometric_weight=config.get(\"geometric_weight\", 0.5),\n", + " reg_weight=config.get(\"reg_weight\", 0.1),\n", + " grid_size=config.get(\"grid_size\", 8),\n", + " ).to(device)\n", + "\n", + " # Optimizer\n", + " optimizer_name = config.get(\"optimizer\", \"adam\").lower()\n", + " lr = config.get(\"learning_rate\", 1e-3)\n", + " weight_decay = config.get(\"weight_decay\", 1e-4)\n", + "\n", + " if optimizer_name == \"adam\":\n", + " self.optimizer = optim.Adam(\n", + " self.model.parameters(), lr=lr, weight_decay=weight_decay\n", + " )\n", + " elif optimizer_name == \"adamw\":\n", + " self.optimizer = optim.AdamW(\n", + " self.model.parameters(), lr=lr, weight_decay=weight_decay\n", + " )\n", + " elif optimizer_name == \"sgd\":\n", + " self.optimizer = optim.SGD(\n", + " self.model.parameters(),\n", + " lr=lr,\n", + " momentum=config.get(\"momentum\", 0.9),\n", + " weight_decay=weight_decay,\n", + " )\n", + " else:\n", + " raise ValueError(f\"Unknown optimizer: {optimizer_name}\")\n", + "\n", + " # Learning rate scheduler\n", + " scheduler_name = config.get(\"scheduler\", \"plateau\").lower()\n", + " if scheduler_name == \"plateau\":\n", + " self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n", + " self.optimizer,\n", + " mode=\"min\",\n", + " factor=config.get(\"scheduler_factor\", 0.5),\n", + " patience=config.get(\"scheduler_patience\", 5),\n", + " )\n", + " elif scheduler_name == \"cosine\":\n", + " self.scheduler = optim.lr_scheduler.CosineAnnealingLR(\n", + " self.optimizer,\n", + " T_max=config.get(\"epochs\", 100),\n", + " eta_min=config.get(\"min_lr\", 1e-6),\n", + " )\n", + " elif scheduler_name == \"step\":\n", + " self.scheduler = optim.lr_scheduler.StepLR(\n", + " self.optimizer,\n", + " step_size=config.get(\"step_size\", 30),\n", + " gamma=config.get(\"gamma\", 0.1),\n", + " )\n", + " else:\n", + " self.scheduler = None\n", + "\n", + " # Training state\n", + " self.current_epoch = 0\n", + " self.best_val_loss = float(\"inf\")\n", + " self.train_losses: List[float] = []\n", + " self.val_losses: List[float] = []\n", + " self.val_metrics: List[Dict] = []\n", + "\n", + " # Create output directory\n", + " self.output_dir = Path(config.get(\"output_dir\", \"runs/homography\"))\n", + " self.output_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + " # TensorBoard writer\n", + " self.writer = SummaryWriter(log_dir=self.output_dir / \"tensorboard\")\n", + "\n", + " # Save configuration\n", + " config_path = self.output_dir / \"config.json\"\n", + " with open(config_path, \"w\") as f:\n", + " json.dump(config, f, indent=2)\n", + "\n", + " print(f\"Training configuration saved to {config_path}\")\n", + " print(\n", + " f\"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters\"\n", + " )\n", + "\n", + " def train_epoch(self) -> float:\n", + " \"\"\"\n", + " Train for one epoch.\n", + "\n", + " Returns:\n", + " Average training loss for the epoch\n", + " \"\"\"\n", + " self.model.train()\n", + " total_loss = 0.0\n", + " num_batches = len(self.train_loader)\n", + "\n", + " progress_bar = tqdm(self.train_loader, desc=f\"Epoch {self.current_epoch + 1}\")\n", + " for batch_idx, batch in enumerate(progress_bar):\n", + " # Move data to device\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target_homography = batch[\"homography\"].to(self.device)\n", + "\n", + " # Forward pass\n", + " self.optimizer.zero_grad()\n", + " pred_homography = self.model(google_img, yandex_img, return_matrix=True)\n", + "\n", + " # Compute loss\n", + " loss = self.criterion(\n", + " pred_homography,\n", + " target_homography,\n", + " google_img,\n", + " yandex_img,\n", + " )\n", + "\n", + " # Backward pass\n", + " loss.backward()\n", + "\n", + " # Gradient clipping\n", + " if self.config.get(\"grad_clip\", 1.0) > 0:\n", + " torch.nn.utils.clip_grad_norm_(\n", + " self.model.parameters(),\n", + " self.config.get(\"grad_clip\", 1.0),\n", + " )\n", + "\n", + " # Optimizer step\n", + " self.optimizer.step()\n", + "\n", + " # Update statistics\n", + " total_loss += loss.item()\n", + "\n", + " # Update progress bar\n", + " progress_bar.set_postfix({\"loss\": loss.item()})\n", + "\n", + " # Log batch loss to TensorBoard\n", + " global_step = self.current_epoch * num_batches + batch_idx\n", + " self.writer.add_scalar(\"train/batch_loss\", loss.item(), global_step)\n", + "\n", + " avg_loss = total_loss / num_batches\n", + " self.train_losses.append(avg_loss)\n", + "\n", + " return avg_loss\n", + "\n", + " @torch.no_grad()\n", + " def validate(self) -> Tuple[float, Dict]:\n", + " \"\"\"\n", + " Validate the model.\n", + "\n", + " Returns:\n", + " Tuple of (average validation loss, validation metrics)\n", + " \"\"\"\n", + " self.model.eval()\n", + " total_loss = 0.0\n", + " all_metrics = []\n", + "\n", + " progress_bar = tqdm(self.val_loader, desc=\"Validation\")\n", + " for batch in progress_bar:\n", + " # Move data to device\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target_homography = batch[\"homography\"].to(self.device)\n", + "\n", + " # Forward pass\n", + " pred_homography = self.model(google_img, yandex_img, return_matrix=True)\n", + "\n", + " # Compute loss\n", + " loss = self.criterion(\n", + " pred_homography,\n", + " target_homography,\n", + " google_img,\n", + " yandex_img,\n", + " )\n", + "\n", + " # Compute metrics\n", + " metrics = self.criterion.compute_metrics(pred_homography, target_homography)\n", + "\n", + " # Update statistics\n", + " total_loss += loss.item()\n", + " all_metrics.append(metrics)\n", + "\n", + " # Update progress bar\n", + " progress_bar.set_postfix({\"loss\": loss.item()})\n", + "\n", + " avg_loss = total_loss / len(self.val_loader)\n", + " self.val_losses.append(avg_loss)\n", + "\n", + " # Aggregate metrics\n", + " avg_metrics = {}\n", + " for key in all_metrics[0].keys():\n", + " avg_metrics[key] = np.mean([m[key] for m in all_metrics])\n", + "\n", + " self.val_metrics.append(avg_metrics)\n", + "\n", + " return avg_loss, avg_metrics\n", + "\n", + " def save_checkpoint(self, is_best: bool = False):\n", + " \"\"\"Save model checkpoint.\"\"\"\n", + " checkpoint = {\n", + " \"epoch\": self.current_epoch,\n", + " \"model_state_dict\": self.model.state_dict(),\n", + " \"optimizer_state_dict\": self.optimizer.state_dict(),\n", + " \"train_losses\": self.train_losses,\n", + " \"val_losses\": self.val_losses,\n", + " \"val_metrics\": self.val_metrics,\n", + " \"best_val_loss\": self.best_val_loss,\n", + " \"config\": self.config,\n", + " }\n", + "\n", + " if self.scheduler is not None:\n", + " checkpoint[\"scheduler_state_dict\"] = self.scheduler.state_dict()\n", + "\n", + " # Save latest checkpoint\n", + " checkpoint_path = self.output_dir / \"checkpoint_latest.pth\"\n", + " torch.save(checkpoint, checkpoint_path)\n", + "\n", + " # Save best checkpoint\n", + " if is_best:\n", + " best_path = self.output_dir / \"checkpoint_best.pth\"\n", + " torch.save(checkpoint, best_path)\n", + " print(f\"Best model saved to {best_path}\")\n", + "\n", + " def load_checkpoint(self, checkpoint_path: str):\n", + " \"\"\"Load model checkpoint.\"\"\"\n", + " checkpoint = torch.load(checkpoint_path, map_location=self.device)\n", + "\n", + " self.model.load_state_dict(checkpoint[\"model_state_dict\"])\n", + " self.optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n", + "\n", + " if self.scheduler is not None and \"scheduler_state_dict\" in checkpoint:\n", + " self.scheduler.load_state_dict(checkpoint[\"scheduler_state_dict\"])\n", + "\n", + " self.current_epoch = checkpoint[\"epoch\"]\n", + " self.train_losses = checkpoint[\"train_losses\"]\n", + " self.val_losses = checkpoint[\"val_losses\"]\n", + " self.val_metrics = checkpoint[\"val_metrics\"]\n", + " self.best_val_loss = checkpoint[\"best_val_loss\"]\n", + "\n", + " print(f\"Loaded checkpoint from epoch {self.current_epoch}\")\n", + "\n", + " def train(self, num_epochs: int):\n", + " \"\"\"\n", + " Train the model for specified number of epochs.\n", + "\n", + " Args:\n", + " num_epochs: Number of epochs to train\n", + " \"\"\"\n", + " print(f\"Starting training for {num_epochs} epochs...\")\n", + " start_time = time.time()\n", + "\n", + " for epoch in range(num_epochs):\n", + " self.current_epoch = epoch\n", + "\n", + " # Train for one epoch\n", + " train_loss = self.train_epoch()\n", + "\n", + " # Validate\n", + " val_loss, val_metrics = self.validate()\n", + "\n", + " # Update learning rate scheduler\n", + " if self.scheduler is not None:\n", + " if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):\n", + " self.scheduler.step(val_loss)\n", + " else:\n", + " self.scheduler.step()\n", + "\n", + " # Log to TensorBoard\n", + " self.writer.add_scalar(\"train/epoch_loss\", train_loss, epoch)\n", + " self.writer.add_scalar(\"val/epoch_loss\", val_loss, epoch)\n", + " for metric_name, metric_value in val_metrics.items():\n", + " self.writer.add_scalar(f\"val/{metric_name}\", metric_value, epoch)\n", + "\n", + " # Print epoch summary\n", + " print(f\"\\nEpoch {epoch + 1}/{num_epochs}:\")\n", + " print(f\" Train Loss: {train_loss:.6f}\")\n", + " print(f\" Val Loss: {val_loss:.6f}\")\n", + " print(\" Val Metrics:\")\n", + " for metric_name, metric_value in val_metrics.items():\n", + " print(f\" {metric_name}: {metric_value:.6f}\")\n", + "\n", + " # Save checkpoint\n", + " is_best = val_loss < self.best_val_loss\n", + " if is_best:\n", + " self.best_val_loss = val_loss\n", + "\n", + " self.save_checkpoint(is_best=is_best)\n", + "\n", + " # Early stopping\n", + " if self.config.get(\"early_stopping_patience\", 0) > 0:\n", + " if (\n", + " epoch - np.argmin(self.val_losses)\n", + " >= self.config[\"early_stopping_patience\"]\n", + " ):\n", + " print(f\"Early stopping at epoch {epoch + 1}\")\n", + " break\n", + "\n", + " # Training completed\n", + " training_time = time.time() - start_time\n", + " print(f\"\\nTraining completed in {training_time:.2f} seconds\")\n", + " print(f\"Best validation loss: {self.best_val_loss:.6f}\")\n", + "\n", + " # Save final model\n", + " final_model_path = self.output_dir / \"model_final.pth\"\n", + " torch.save(self.model.state_dict(), final_model_path)\n", + " print(f\"Final model saved to {final_model_path}\")\n", + "\n", + " # Close TensorBoard writer\n", + " self.writer.close()\n", + "\n", + " def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict:\n", + " \"\"\"\n", + " Evaluate the model on test data.\n", + "\n", + " Args:\n", + " test_loader: Test data loader (uses validation loader if None)\n", + "\n", + " Returns:\n", + " Dictionary of evaluation metrics\n", + " \"\"\"\n", + " if test_loader is None:\n", + " test_loader = self.val_loader\n", + "\n", + " self.model.eval()\n", + " all_metrics = []\n", + "\n", + " print(\"Evaluating model...\")\n", + " with torch.no_grad():\n", + " for batch in tqdm(test_loader):\n", + " # Move data to device\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target_homography = batch[\"homography\"].to(self.device)\n", + "\n", + " # Forward pass\n", + " pred_homography = self.model(google_img, yandex_img, return_matrix=True)\n", + "\n", + " # Compute metrics\n", + " metrics = self.criterion.compute_metrics(\n", + " pred_homography, target_homography\n", + " )\n", + " all_metrics.append(metrics)\n", + "\n", + " # Aggregate metrics\n", + " avg_metrics = {}\n", + " for key in all_metrics[0].keys():\n", + " avg_metrics[key] = np.mean([m[key] for m in all_metrics])\n", + "\n", + " # Print evaluation results\n", + " print(\"\\nEvaluation Results:\")\n", + " for metric_name, metric_value in avg_metrics.items():\n", + " print(f\" {metric_name}: {metric_value:.6f}\")\n", + "\n", + " # Save evaluation results\n", + " eval_path = self.output_dir / \"evaluation_results.json\"\n", + " with open(eval_path, \"w\") as f:\n", + " json.dump(avg_metrics, f, indent=2)\n", + " print(f\"Evaluation results saved to {eval_path}\")\n", + "\n", + " return avg_metrics\n", + "\n", + "\n", + "from types import SimpleNamespace\n", + "\n", + "# Дефолтные значения параметров\n", + "args = SimpleNamespace(\n", + " # Data arguments\n", + " data_dir=r\"C:\\Users\\admin\\Projects\\autopilot\\datasets\\ya_go_maps\\images\",\n", + " batch_size=32,\n", + " image_size=[256, 256],\n", + " train_split=0.8,\n", + " num_workers=1,\n", + " \n", + " # Model arguments\n", + " model_type=\"cnn\",\n", + " hidden_channels=64,\n", + " num_blocks=4,\n", + " dropout_rate=0.3,\n", + " use_batch_norm=False,\n", + " \n", + " # Training arguments\n", + " epochs=100,\n", + " lr=1e-3,\n", + " weight_decay=1e-4,\n", + " optimizer=\"adam\",\n", + " scheduler=\"plateau\",\n", + " grad_clip=1.0,\n", + " \n", + " # Loss arguments\n", + " matrix_weight=1.0,\n", + " geometric_weight=0.5,\n", + " reg_weight=0.1,\n", + " \n", + " # Other arguments\n", + " output_dir=\"runs/homography\",\n", + " resume=None,\n", + " eval_only=False,\n", + " seed=42\n", + ")\n", + "\n", + "\n", + " # Set random seeds for reproducibility\n", + "torch.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed(args.seed)\n", + " torch.cuda.manual_seed_all(args.seed)\n", + "\n", + "# Set device\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Create data loaders\n", + "print(\"Creating data loaders...\")\n", + "train_loader, val_loader = create_data_loaders(\n", + " root_dir=args.data_dir,\n", + " batch_size=args.batch_size,\n", + " train_split=args.train_split,\n", + " num_workers=args.num_workers,\n", + " image_size=tuple(args.image_size),\n", + " augment_train=False,\n", + " augment_val=False,\n", + ")\n", + "\n", + "print(f\"Train batches: {len(train_loader)}\")\n", + "print(f\"Val batches: {len(val_loader)}\")\n", + "\n", + "# Create model\n", + "print(\"Creating model...\")\n", + "model = create_homography_model(\n", + " model_type=args.model_type,\n", + " input_size=tuple(args.image_size),\n", + " input_channels=3,\n", + " hidden_channels=args.hidden_channels,\n", + " num_blocks=args.num_blocks,\n", + " dropout_rate=args.dropout_rate,\n", + " use_batch_norm=args.use_batch_norm,\n", + ")\n", + "\n", + "# Create trainer configuration\n", + "config = {\n", + " # Model config\n", + " \"model_type\": args.model_type,\n", + " \"hidden_channels\": args.hidden_channels,\n", + " \"num_blocks\": args.num_blocks,\n", + " \"dropout_rate\": args.dropout_rate,\n", + " \"use_batch_norm\": args.use_batch_norm,\n", + " \"image_size\": args.image_size,\n", + " # Training config\n", + " \"epochs\": args.epochs,\n", + " \"batch_size\": args.batch_size,\n", + " \"learning_rate\": args.lr,\n", + " \"weight_decay\": args.weight_decay,\n", + " \"optimizer\": args.optimizer,\n", + " \"scheduler\": args.scheduler,\n", + " \"grad_clip\": args.grad_clip,\n", + " # Loss config\n", + " \"matrix_weight\": args.matrix_weight,\n", + " \"geometric_weight\": args.geometric_weight,\n", + " \"reg_weight\": args.reg_weight,\n", + " \"grid_size\": 8,\n", + " # Data config\n", + " \"data_dir\": args.data_dir,\n", + " \"train_split\": args.train_split,\n", + " \"num_workers\": args.num_workers,\n", + " # Output config\n", + " \"output_dir\": args.output_dir,\n", + " \"seed\": args.seed,\n", + "}\n", + "\n", + "# Create trainer\n", + "trainer = HomographyTrainer(\n", + " model=model,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", + " device=device,\n", + " config=config,\n", + ")\n", + "\n", + "# Resume from checkpoint if specified\n", + "if args.resume:\n", + " print(f\"Resuming from checkpoint: {args.resume}\")\n", + " trainer.load_checkpoint(args.resume)\n", + "\n", + "# Evaluate only mode\n", + "if args.eval_only:\n", + " trainer.evaluate()\n", + "else:\n", + " # Train the model\n", + " trainer.train(num_epochs=args.epochs)\n", + "\n", + " # Final evaluation\n", + " print(\"\\nPerforming final evaluation...\")\n", + " trainer.evaluate()\n", + "\n", + " print(\"\\nTraining completed successfully!\")\n", + " print(f\"Results saved to: {args.output_dir}\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1cd4bb8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/models/SiaN/homography.py b/models/SiaN/homography.py new file mode 100644 index 0000000..1e5464f --- /dev/null +++ b/models/SiaN/homography.py @@ -0,0 +1,434 @@ +import os +import random +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset + + +class HomographyDataset(Dataset): + """ + Dataset for homography estimation between Yandex and Google map image pairs. + + This dataset loads pairs of images (Yandex and Google maps) and provides + homography matrices for data augmentation and training. + """ + + def __init__( + self, + root_dir: str, + transform=None, + augment: bool = True, + max_samples: Optional[int] = None, + image_size: Tuple[int, int] = (700, 700), + cache_homographies: bool = True, + ): + """ + Initialize the HomographyDataset. + + Args: + root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png) + transform: Optional torchvision transforms to apply + augment: Whether to apply homography-based data augmentation + max_samples: Maximum number of samples to load (None for all) + image_size: Target size for images (height, width) + cache_homographies: Whether to cache generated homography matrices to disk + """ + self.root_dir = root_dir + self.transform = transform + self.augment = augment + self.image_size = image_size + self.cache_homographies = cache_homographies + + # Find all image pairs + self.image_pairs = self._discover_image_pairs() + + if max_samples is not None: + self.image_pairs = self.image_pairs[:max_samples] + + print(f"Found {len(self.image_pairs)} image pairs in {root_dir}") + + # Create directory for cached homographies if needed + if cache_homographies: + self.homography_cache_dir = os.path.join(root_dir, "homography_cache") + os.makedirs(self.homography_cache_dir, exist_ok=True) + + def _discover_image_pairs(self) -> List[Dict[str, Any]]: + """Discover all Google-Yandex image pairs in the dataset directory.""" + image_pairs = [] + + # Get all Google images + google_files = [ + f for f in os.listdir(self.root_dir) if f.endswith("_google.png") + ] + + for google_file in sorted(google_files): + # Extract index from filename + idx_str = google_file.split("_")[0] + try: + idx = int(idx_str) + except ValueError: + continue + + # Check if corresponding Yandex image exists + yandex_file = f"{idx:04d}_yandex.png" + yandex_path = os.path.join(self.root_dir, yandex_file) + + if os.path.exists(yandex_path): + image_pairs.append( + { + "idx": idx, + "google_path": os.path.join(self.root_dir, google_file), + "yandex_path": yandex_path, + } + ) + + return image_pairs + + def __len__(self) -> int: + """Return the number of image pairs in the dataset.""" + return len(self.image_pairs) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Get a sample from the dataset. + + Returns a dictionary with: + - 'google_img': Google map image tensor + - 'yandex_img': Yandex map image tensor + - 'homography': Ground truth homography matrix (3x3) + - 'idx': Sample index + """ + pair_info = self.image_pairs[idx] + + # Load images + google_img = Image.open(pair_info["google_path"]).convert("RGB") + yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB") + + # Resize images to target size + google_img = google_img.resize( + (self.image_size[1], self.image_size[0]), Image.BILINEAR + ) + yandex_img = yandex_img.resize( + (self.image_size[1], self.image_size[0]), Image.BILINEAR + ) + + # Get or generate homography matrix + homography_matrix = self._get_homography_matrix(pair_info["idx"]) + + # Apply data augmentation if enabled + if self.augment: + google_img, yandex_img, homography_matrix = self._apply_augmentation( + google_img, yandex_img, homography_matrix + ) + + # Convert images to tensors + if self.transform: + google_img = self.transform(google_img) + yandex_img = self.transform(yandex_img) + else: + # Default conversion to tensor + google_img = ( + torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0 + ) + yandex_img = ( + torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0 + ) + + # Convert homography to tensor + homography_tensor = torch.from_numpy(homography_matrix).float() + + return { + "google_img": google_img, + "yandex_img": yandex_img, + "homography": homography_tensor, + "idx": torch.tensor(pair_info["idx"], dtype=torch.long), + } + + def _get_homography_matrix(self, idx: int) -> np.ndarray: + """ + Get homography matrix for a given index. + + If cached homography exists, load it. Otherwise generate a new one. + """ + if self.cache_homographies: + cache_path = os.path.join( + self.homography_cache_dir, f"{idx:04d}_homography.npy" + ) + if os.path.exists(cache_path): + return np.load(cache_path) + + # Generate new homography matrix + homography_matrix = self.generate_random_homography() + + # Cache if enabled + if self.cache_homographies: + np.save(cache_path, homography_matrix) + + return homography_matrix + + def generate_random_homography(self) -> np.ndarray: + """ + Generate a random homography matrix for data augmentation. + + Returns: + np.ndarray: 3x3 homography matrix. + """ + # Generate random affine transformation parameters + angle = np.random.uniform(-30, 30) # rotation in degrees + scale = np.random.uniform(0.8, 1.2) # scaling factor + tx = np.random.uniform(-50, 50) # translation in x + ty = np.random.uniform(-50, 50) # translation in y + + # Convert angle to radians + theta = np.radians(angle) + + # Create affine transformation matrix + affine_matrix = np.array( + [ + [scale * np.cos(theta), -scale * np.sin(theta), tx], + [scale * np.sin(theta), scale * np.cos(theta), ty], + [0, 0, 1], + ] + ) + + # Add small perspective distortion + perspective = np.random.uniform(-0.001, 0.001, (2, 3)) + perspective = np.vstack([perspective, [0, 0, 0]]) + + homography_matrix = affine_matrix + perspective + + return homography_matrix + + def _apply_augmentation( + self, + google_img: Image.Image, + yandex_img: Image.Image, + base_homography: np.ndarray, + ) -> Tuple[Image.Image, Image.Image, np.ndarray]: + """ + Apply homography-based data augmentation to image pair. + + Args: + google_img: Google map image + yandex_img: Yandex map image + base_homography: Base homography matrix + + Returns: + Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography) + """ + # Generate augmentation homography + aug_homography = self.generate_random_homography() + + # Combine with base homography + combined_homography = aug_homography @ base_homography + + # Apply augmentation to both images + google_aug = self._apply_homography_to_image(google_img, aug_homography) + yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography) + + return google_aug, yandex_aug, combined_homography + + def _apply_homography_to_image( + self, img: Image.Image, homography: np.ndarray + ) -> Image.Image: + """ + Apply homography transformation to a single image. + + Args: + img: PIL Image to transform + homography: 3x3 homography matrix + + Returns: + Transformed PIL Image + """ + # Convert to numpy array + img_np = np.array(img) + + # Get image dimensions + h, w = img_np.shape[:2] + + # Apply homography transformation + transformed = cv2.warpPerspective( + img_np, + homography, + (w, h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REFLECT, + ) + + # Convert back to PIL Image + return Image.fromarray(transformed) + + def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]: + """ + Get a sample without data augmentation. + + Useful for visualization and evaluation. + """ + pair_info = self.image_pairs[idx] + + # Load images + google_img = Image.open(pair_info["google_path"]).convert("RGB") + yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB") + + # Resize + google_img = google_img.resize( + (self.image_size[1], self.image_size[0]), Image.BILINEAR + ) + yandex_img = yandex_img.resize( + (self.image_size[1], self.image_size[0]), Image.BILINEAR + ) + + # Get homography matrix + homography_matrix = self._get_homography_matrix(pair_info["idx"]) + + return { + "google_img": google_img, + "yandex_img": yandex_img, + "homography": homography_matrix, + "idx": pair_info["idx"], + "google_path": pair_info["google_path"], + "yandex_path": pair_info["yandex_path"], + } + + +def create_data_loaders( + root_dir: str, + batch_size: int = 32, + train_split: float = 0.8, + num_workers: int = 4, + image_size: Tuple[int, int] = (256, 256), + augment_train: bool = True, + augment_val: bool = False, +) -> Tuple[DataLoader, DataLoader]: + """ + Create train and validation data loaders for homography estimation. + + Args: + root_dir: Directory containing image pairs + batch_size: Batch size for data loaders + train_split: Fraction of data to use for training + num_workers: Number of worker processes for data loading + image_size: Target image size (height, width) + augment_train: Whether to augment training data + augment_val: Whether to augment validation data + + Returns: + Tuple of (train_loader, val_loader) + """ + from torchvision import transforms + + # Define transforms + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + # Create full dataset + full_dataset = HomographyDataset( + root_dir=root_dir, + transform=transform, + augment=False, # We'll handle augmentation separately + image_size=image_size, + cache_homographies=True, + ) + + # Split dataset + dataset_size = len(full_dataset) + train_size = int(train_split * dataset_size) + val_size = dataset_size - train_size + + # Create indices for splitting + indices = list(range(dataset_size)) + random.shuffle(indices) + train_indices = indices[:train_size] + val_indices = indices[train_size:] + + # Create subset samplers + from torch.utils.data import Subset + + train_dataset = Subset(full_dataset, train_indices) + val_dataset = Subset(full_dataset, val_indices) + + # Apply augmentation by overriding __getitem__ for train dataset + if augment_train: + + class AugmentedSubset(Subset): + def __getitem__(self, idx): + sample = self.dataset[self.indices[idx]] + # Apply augmentation + google_img = sample["google_img"] + yandex_img = sample["yandex_img"] + homography = sample["homography"] + + # Generate augmentation homography + aug_homography = torch.from_numpy( + full_dataset.generate_random_homography() + ).float() + + # Combine homographies + combined_homography = aug_homography @ homography + + # Apply augmentation (simplified - in practice would warp images) + # For now, we just return the combined homography + return { + "google_img": google_img, + "yandex_img": yandex_img, + "homography": combined_homography, + "idx": sample["idx"], + } + + train_dataset = AugmentedSubset(full_dataset, train_indices) + + # Create data loaders + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + return train_loader, val_loader + + +if __name__ == "__main__": + # Example usage + dataset = HomographyDataset( + root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + augment=True, + image_size=(256, 256), + ) + + print(f"Dataset size: {len(dataset)}") + + # Get a sample + sample = dataset[0] + print(f"Sample keys: {list(sample.keys())}") + print(f"Google image shape: {sample['google_img'].shape}") + print(f"Yandex image shape: {sample['yandex_img'].shape}") + print(f"Homography shape: {sample['homography'].shape}") + + # Create data loaders + train_loader, val_loader = create_data_loaders( + root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + batch_size=16, + train_split=0.8, + ) + + print(f"Train batches: {len(train_loader)}") + print(f"Val batches: {len(val_loader)}") diff --git a/models/SiaN/homography_cnn.py b/models/SiaN/homography_cnn.py new file mode 100644 index 0000000..7a004a8 --- /dev/null +++ b/models/SiaN/homography_cnn.py @@ -0,0 +1,551 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class HomographyCNN(nn.Module): + """ + CNN model for homography estimation between two images. + + This model takes two images (Google and Yandex maps) as input and + outputs a 3x3 homography matrix that transforms one image to align with the other. + """ + + def __init__( + self, + input_channels: int = 3, + hidden_channels: int = 64, + num_blocks: int = 4, + dropout_rate: float = 0.3, + use_batch_norm: bool = True, + output_size: int = 9, # Flattened 3x3 homography matrix + ): + """ + Initialize the HomographyCNN model. + + Args: + input_channels: Number of input channels per image (3 for RGB) + hidden_channels: Base number of channels in the network + num_blocks: Number of convolutional blocks + dropout_rate: Dropout rate for regularization + use_batch_norm: Whether to use batch normalization + output_size: Size of output vector (9 for flattened 3x3 matrix) + """ + super().__init__() + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.num_blocks = num_blocks + self.dropout_rate = dropout_rate + self.use_batch_norm = use_batch_norm + + # Feature extraction for each image separately + self.google_encoder = self._build_encoder() + self.yandex_encoder = self._build_encoder() + + # Fusion layers to combine features from both images + self.fusion_layers = self._build_fusion_layers() + + # Regression head for homography estimation + self.regression_head = self._build_regression_head(output_size) + + # Initialize weights + self._initialize_weights() + + def _build_encoder(self) -> nn.Module: + """Build the encoder network for a single image.""" + layers = [] + in_channels = self.input_channels + out_channels = self.hidden_channels + + # First convolutional block + layers.append( + nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3) + ) + if self.use_batch_norm: + layers.append(nn.BatchNorm2d(out_channels)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # Additional convolutional blocks + for i in range(self.num_blocks): + block_in_channels = out_channels + block_out_channels = out_channels * 2 if i < 2 else out_channels + + layers.append( + ResidualBlock( + in_channels=block_in_channels, + out_channels=block_out_channels, + stride=1 if i == 0 else 2, + dropout_rate=self.dropout_rate, + use_batch_norm=self.use_batch_norm, + ) + ) + + if i < 2: + out_channels = block_out_channels + + return nn.Sequential(*layers) + + def _build_fusion_layers(self) -> nn.Module: + """Build layers to fuse features from both images.""" + # After encoding, each image has hidden_channels * 4 features + fused_channels = ( + self.hidden_channels * 8 + ) # Concatenated features from both images + + layers = [ + # Reduce dimensionality + nn.Conv2d( + fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1 + ), + nn.BatchNorm2d(self.hidden_channels * 4) + if self.use_batch_norm + else nn.Identity(), + nn.ReLU(inplace=True), + nn.Dropout2d(self.dropout_rate), + # Further processing + nn.Conv2d( + self.hidden_channels * 4, + self.hidden_channels * 2, + kernel_size=3, + padding=1, + ), + nn.BatchNorm2d(self.hidden_channels * 2) + if self.use_batch_norm + else nn.Identity(), + nn.ReLU(inplace=True), + nn.Dropout2d(self.dropout_rate), + # Global pooling + nn.AdaptiveAvgPool2d((1, 1)), + ] + + return nn.Sequential(*layers) + + def _build_regression_head(self, output_size: int) -> nn.Module: + """Build the regression head for homography estimation.""" + # Input size after fusion and global pooling + input_features = self.hidden_channels * 2 + + layers = [ + nn.Flatten(), + nn.Linear(input_features, 512), + nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.Linear(512, 256), + nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.Linear(256, 128), + nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.Linear(128, output_size), + ] + + return nn.Sequential(*layers) + + def _initialize_weights(self): + """Initialize model weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward( + self, + google_img: torch.Tensor, + yandex_img: torch.Tensor, + return_matrix: bool = True, + ) -> torch.Tensor: + """ + Forward pass of the model. + + Args: + google_img: Google map image tensor of shape (B, C, H, W) + yandex_img: Yandex map image tensor of shape (B, C, H, W) + return_matrix: If True, return 3x3 matrix; if False, return flattened vector + + Returns: + Homography matrix tensor of shape (B, 3, 3) or flattened vector of shape (B, 9) + """ + # Extract features from both images + google_features = self.google_encoder(google_img) + yandex_features = self.yandex_encoder(yandex_img) + + # Concatenate features along channel dimension + combined_features = torch.cat([google_features, yandex_features], dim=1) + + # Fuse features + fused_features = self.fusion_layers(combined_features) + + # Regression to get homography parameters + homography_flat = self.regression_head(fused_features) + + if return_matrix: + # Reshape to 3x3 matrix + batch_size = homography_flat.shape[0] + homography_matrix = homography_flat.view(batch_size, 3, 3) + + # Ensure the last element is 1 (homogeneous coordinate normalization) + # Add small epsilon to prevent division by zero + epsilon = 1e-8 + homography_matrix = homography_matrix / ( + homography_matrix[:, 2, 2].view(-1, 1, 1) + epsilon + ) + + return homography_matrix + else: + return homography_flat + + def predict_homography( + self, + google_img: torch.Tensor, + yandex_img: torch.Tensor, + normalize: bool = True, + ) -> torch.Tensor: + """ + Predict homography matrix with optional normalization. + + Args: + google_img: Google map image tensor + yandex_img: Yandex map image tensor + normalize: Whether to normalize the homography matrix + + Returns: + Predicted homography matrix + """ + self.eval() + with torch.no_grad(): + homography = self.forward(google_img, yandex_img, return_matrix=True) + + if normalize: + # Normalize so that last element is 1 + homography = homography / homography[:, 2, 2].view(-1, 1, 1) + + return homography + + +class ResidualBlock(nn.Module): + """Residual block with optional downsampling.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + dropout_rate: float = 0.3, + use_batch_norm: bool = True, + ): + super().__init__() + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity() + self.relu1 = nn.ReLU(inplace=True) + self.dropout1 = ( + nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity() + ) + + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity() + self.relu2 = nn.ReLU(inplace=True) + self.dropout2 = ( + nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity() + ) + + # Shortcut connection + self.shortcut = nn.Sequential() + if stride != 1 or in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = self.shortcut(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.dropout1(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += identity + out = self.relu2(out) + out = self.dropout2(out) + + return out + + +class HomographyLoss(nn.Module): + """ + Custom loss function for homography estimation. + + Combines multiple loss terms: + 1. Matrix element-wise L2 loss + 2. Geometric consistency loss (warping error) + 3. Determinant regularization (to prevent degenerate matrices) + """ + + def __init__( + self, + matrix_weight: float = 1.0, + geometric_weight: float = 0.5, + reg_weight: float = 0.1, + grid_size: int = 8, + ): + super().__init__() + self.matrix_weight = matrix_weight + self.geometric_weight = geometric_weight + self.reg_weight = reg_weight + self.grid_size = grid_size + + # Create grid of points for geometric loss + self.register_buffer( + "grid_points", + self._create_grid_points(grid_size), + persistent=False, + ) + + def _create_grid_points(self, grid_size: int) -> torch.Tensor: + """Create a grid of points for geometric consistency loss.""" + x = torch.linspace(-1, 1, grid_size) + y = torch.linspace(-1, 1, grid_size) + grid_y, grid_x = torch.meshgrid(y, x, indexing="ij") + grid_points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1) + # Add homogeneous coordinate + ones = torch.ones(grid_points.shape[0], 1) + grid_points = torch.cat([grid_points, ones], dim=1) + return grid_points.T # Shape: (3, grid_size*grid_size) + + def forward( + self, + pred_homography: torch.Tensor, + target_homography: torch.Tensor, + google_img: Optional[torch.Tensor] = None, + yandex_img: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute homography loss. + + Args: + pred_homography: Predicted homography matrices (B, 3, 3) + target_homography: Target homography matrices (B, 3, 3) + google_img: Google images (optional, for geometric loss) + yandex_img: Yandex images (optional, for geometric loss) + + Returns: + Combined loss value + """ + batch_size = pred_homography.shape[0] + + # 1. Matrix element-wise L2 loss + matrix_loss = F.mse_loss(pred_homography, target_homography) + + # 2. Geometric consistency loss (if images provided) + geometric_loss = torch.tensor(0.0, device=pred_homography.device) + if google_img is not None and yandex_img is not None: + # Warp grid points with predicted homography + grid_points = self.grid_points.unsqueeze(0).expand(batch_size, -1, -1) + warped_points = torch.bmm(pred_homography, grid_points) + + # Normalize homogeneous coordinates + warped_points = warped_points / (warped_points[:, 2:3, :] + 1e-8) + + # Warp with target homography for comparison + target_warped_points = torch.bmm(target_homography, grid_points) + target_warped_points = target_warped_points / ( + target_warped_points[:, 2:3, :] + 1e-8 + ) + + # Compute point-wise distance + geometric_loss = F.mse_loss( + warped_points[:, :2, :], target_warped_points[:, :2, :] + ) + + # 3. Regularization loss (prevent degenerate matrices) + # Encourage determinant to be close to 1 + pred_det = torch.det(pred_homography) + reg_loss = F.mse_loss(pred_det, torch.ones_like(pred_det)) + + # Combine losses + total_loss = ( + self.matrix_weight * matrix_loss + + self.geometric_weight * geometric_loss + + self.reg_weight * reg_loss + ) + + return total_loss + + def compute_metrics( + self, + pred_homography: torch.Tensor, + target_homography: torch.Tensor, + ) -> dict: + """ + Compute evaluation metrics for homography estimation. + + Args: + pred_homography: Predicted homography matrices + target_homography: Target homography matrices + + Returns: + Dictionary of metrics + """ + with torch.no_grad(): + # Normalize matrices + pred_norm = pred_homography / pred_homography[:, 2, 2].view(-1, 1, 1) + target_norm = target_homography / target_homography[:, 2, 2].view(-1, 1, 1) + + # Matrix L2 error + matrix_error = F.mse_loss(pred_norm, target_norm, reduction="none").mean( + dim=(1, 2) + ) + + # Corner error (warp 4 corners of the image) + corners = torch.tensor( + [[-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]], + dtype=torch.float32, + device=pred_homography.device, + ).T # Shape: (3, 4) + + corners = corners.unsqueeze(0).expand(pred_homography.shape[0], -1, -1) + + pred_corners = torch.bmm(pred_norm, corners) + pred_corners = pred_corners / (pred_corners[:, 2:3, :] + 1e-8) + + target_corners = torch.bmm(target_norm, corners) + target_corners = target_corners / (target_corners[:, 2:3, :] + 1e-8) + + corner_error = torch.mean( + torch.norm(pred_corners[:, :2, :] - target_corners[:, :2, :], dim=1), + dim=1, + ) + + # Average corner error in pixels (assuming image coordinates in [-1, 1]) + # Convert to pixel error if image size is known + avg_corner_error = corner_error.mean().item() + + return { + "matrix_mse": matrix_error.mean().item(), + "corner_error": avg_corner_error, + "corner_error_px": avg_corner_error * 128, # Assuming 256x256 images + } + + +def create_homography_model( + model_type: str = "cnn", + input_size: Tuple[int, int] = (256, 256), + **kwargs, +) -> nn.Module: + """ + Factory function to create homography estimation model. + + Args: + model_type: Type of model to create ('cnn' or 'resnet') + input_size: Input image size (height, width) + **kwargs: Additional arguments passed to model constructor + + Returns: + Homography estimation model + """ + if model_type == "cnn": + return HomographyCNN(**kwargs) + else: + raise ValueError(f"Unknown model type: {model_type}") + + +if __name__ == "__main__": + # Test the model + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Create model + model = HomographyCNN( + input_channels=3, + hidden_channels=64, + num_blocks=4, + dropout_rate=0.3, + use_batch_norm=True, + ).to(device) + + print( + f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters" + ) + + # Create dummy input + batch_size = 4 + height, width = 256, 256 + + google_img = torch.randn(batch_size, 3, height, width).to(device) + yandex_img = torch.randn(batch_size, 3, height, width).to(device) + + # Test forward pass + print("\nTesting forward pass...") + output = model(google_img, yandex_img, return_matrix=True) + print(f"Output shape: {output.shape}") # Should be (4, 3, 3) + print(f"Sample output:\n{output[0]}") + + # Test prediction + print("\nTesting prediction...") + pred = model.predict_homography(google_img, yandex_img) + print(f"Prediction shape: {pred.shape}") + print(f"Last element (should be ~1): {pred[0, 2, 2]:.6f}") + + # Test loss function + print("\nTesting loss function...") + target_homography = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device) + loss_fn = HomographyLoss( + matrix_weight=1.0, + geometric_weight=0.5, + reg_weight=0.1, + grid_size=8, + ).to(device) + + loss = loss_fn(output, target_homography, google_img, yandex_img) + print(f"Loss value: {loss.item():.6f}") + + # Test metrics + print("\nTesting metrics...") + metrics = loss_fn.compute_metrics(output, target_homography) + for key, value in metrics.items(): + print(f"{key}: {value:.6f}") + + # Test model factory + print("\nTesting model factory...") + model2 = create_homography_model( + model_type="cnn", + input_size=(256, 256), + input_channels=3, + hidden_channels=32, + num_blocks=3, + ).to(device) + + print( + f"Model2 created with {sum(p.numel() for p in model2.parameters()):,} parameters" + ) + + print("\nAll tests completed successfully!") diff --git a/models/SiaN/infer_homography.py b/models/SiaN/infer_homography.py new file mode 100644 index 0000000..d226a7a --- /dev/null +++ b/models/SiaN/infer_homography.py @@ -0,0 +1,553 @@ +""" +Inference script for homography estimation between Google and Yandex map images. + +This script loads a trained homography estimation model and performs inference +on new image pairs or the test dataset. +""" + +import argparse +import json +import os +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from homography import HomographyDataset +from homography_cnn import HomographyCNN, create_homography_model +from PIL import Image +from torchvision import transforms + + +class HomographyInference: + """Class for performing inference with homography estimation model.""" + + def __init__( + self, + model_path: str, + config_path: Optional[str] = None, + device: Optional[str] = None, + ): + """ + Initialize the inference class. + + Args: + model_path: Path to trained model checkpoint + config_path: Path to model configuration file (optional) + device: Device to run inference on ('cuda' or 'cpu') + """ + # Set device + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = torch.device(device) + + print(f"Using device: {self.device}") + + # Load configuration + if config_path is None: + # Try to find config in the same directory as model + model_dir = Path(model_path).parent + config_path = model_dir / "config.json" + + if os.path.exists(config_path): + with open(config_path, "r") as f: + self.config = json.load(f) + print(f"Loaded configuration from {config_path}") + else: + # Use default configuration + self.config = { + "image_size": [256, 256], + "hidden_channels": 64, + "num_blocks": 4, + "dropout_rate": 0.3, + "use_batch_norm": True, + } + print("Using default configuration") + + # Create model + self.model = self._create_model() + self._load_model(model_path) + + # Set up transforms + self.transform = self._create_transforms() + + # Set model to evaluation mode + self.model.eval() + + def _create_model(self) -> HomographyCNN: + """Create model based on configuration.""" + image_size = self.config.get("image_size", [256, 256]) + + model = create_homography_model( + model_type="cnn", + input_size=tuple(image_size), + input_channels=3, + hidden_channels=self.config.get("hidden_channels", 64), + num_blocks=self.config.get("num_blocks", 4), + dropout_rate=self.config.get("dropout_rate", 0.3), + use_batch_norm=self.config.get("use_batch_norm", True), + ) + + return model + + def _load_model(self, model_path: str): + """Load model weights from checkpoint.""" + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model file not found: {model_path}") + + # Load checkpoint + checkpoint = torch.load(model_path, map_location=self.device) + + if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: + # Trainer checkpoint format + self.model.load_state_dict(checkpoint["model_state_dict"]) + else: + # Raw model weights format + self.model.load_state_dict(checkpoint) + + self.model = self.model.to(self.device) + print(f"Loaded model from {model_path}") + + def _create_transforms(self): + """Create image transforms for inference.""" + return transforms.Compose( + [ + transforms.Resize(tuple(self.config.get("image_size", [256, 256]))), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + def preprocess_images( + self, google_img: Image.Image, yandex_img: Image.Image + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Preprocess images for inference. + + Args: + google_img: Google map image (PIL Image) + yandex_img: Yandex map image (PIL Image) + + Returns: + Tuple of preprocessed image tensors + """ + # Convert to RGB if needed + if google_img.mode != "RGB": + google_img = google_img.convert("RGB") + if yandex_img.mode != "RGB": + yandex_img = yandex_img.convert("RGB") + + # Apply transforms + google_tensor = self.transform(google_img).unsqueeze(0) # Add batch dimension + yandex_tensor = self.transform(yandex_img).unsqueeze(0) + + return google_tensor, yandex_tensor + + def predict( + self, + google_img: Image.Image, + yandex_img: Image.Image, + return_matrix: bool = True, + normalize: bool = True, + ) -> torch.Tensor: + """ + Predict homography matrix for image pair. + + Args: + google_img: Google map image (PIL Image) + yandex_img: Yandex map image (PIL Image) + return_matrix: If True, return 3x3 matrix; if False, return flattened vector + normalize: Whether to normalize the homography matrix + + Returns: + Predicted homography matrix or vector + """ + # Preprocess images + google_tensor, yandex_tensor = self.preprocess_images(google_img, yandex_img) + + # Move to device + google_tensor = google_tensor.to(self.device) + yandex_tensor = yandex_tensor.to(self.device) + + # Perform inference + with torch.no_grad(): + homography = self.model( + google_tensor, yandex_tensor, return_matrix=return_matrix + ) + + if return_matrix and normalize: + # Normalize so that last element is 1 + homography = homography / homography[:, 2, 2].view(-1, 1, 1) + + return homography.squeeze(0) # Remove batch dimension + + def predict_from_paths( + self, + google_path: str, + yandex_path: str, + return_matrix: bool = True, + normalize: bool = True, + ) -> torch.Tensor: + """ + Predict homography matrix from image file paths. + + Args: + google_path: Path to Google map image + yandex_path: Path to Yandex map image + return_matrix: If True, return 3x3 matrix; if False, return flattened vector + normalize: Whether to normalize the homography matrix + + Returns: + Predicted homography matrix or vector + """ + # Load images + google_img = Image.open(google_path) + yandex_img = Image.open(yandex_path) + + return self.predict(google_img, yandex_img, return_matrix, normalize) + + def warp_image( + self, + img: Image.Image, + homography: np.ndarray, + output_size: Optional[Tuple[int, int]] = None, + ) -> Image.Image: + """ + Warp image using homography matrix. + + Args: + img: Input image (PIL Image) + homography: 3x3 homography matrix (numpy array) + output_size: Output image size (width, height). If None, uses input size. + + Returns: + Warped image (PIL Image) + """ + # Convert to numpy array + img_np = np.array(img) + + # Get output size + if output_size is None: + output_size = (img_np.shape[1], img_np.shape[0]) + + # Apply homography transformation + warped_np = cv2.warpPerspective( + img_np, + homography, + output_size, + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REFLECT, + ) + + # Convert back to PIL Image + return Image.fromarray(warped_np) + + def visualize_alignment( + self, + google_img: Image.Image, + yandex_img: Image.Image, + homography: np.ndarray, + save_path: Optional[str] = None, + show: bool = True, + ): + """ + Visualize alignment between images using homography. + + Args: + google_img: Google map image + yandex_img: Yandex map image + homography: Homography matrix + save_path: Path to save visualization (optional) + show: Whether to display the visualization + """ + # Warp yandex image to align with google + yandex_warped = self.warp_image(yandex_img, homography) + + # Convert images to numpy arrays for visualization + google_np = np.array(google_img) + yandex_np = np.array(yandex_img) + yandex_warped_np = np.array(yandex_warped) + + # Create visualization + fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + + # Original images + axes[0, 0].imshow(google_np) + axes[0, 0].set_title("Google Map (Original)") + axes[0, 0].axis("off") + + axes[0, 1].imshow(yandex_np) + axes[0, 1].set_title("Yandex Map (Original)") + axes[0, 1].axis("off") + + # Warped image + axes[1, 0].imshow(yandex_warped_np) + axes[1, 0].set_title("Yandex Map (Warped)") + axes[1, 0].axis("off") + + # Overlay (50% transparency) + overlay = cv2.addWeighted(google_np, 0.5, yandex_warped_np, 0.5, 0) + axes[1, 1].imshow(overlay) + axes[1, 1].set_title("Overlay (Google + Warped Yandex)") + axes[1, 1].axis("off") + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches="tight") + print(f"Visualization saved to {save_path}") + + if show: + plt.show() + else: + plt.close() + + def evaluate_on_dataset( + self, + dataset_dir: str, + num_samples: Optional[int] = None, + save_dir: Optional[str] = None, + ) -> Dict[str, float]: + """ + Evaluate model on a dataset. + + Args: + dataset_dir: Directory containing image pairs + num_samples: Number of samples to evaluate (None for all) + save_dir: Directory to save visualizations (optional) + + Returns: + Dictionary of evaluation metrics + """ + # Create dataset + dataset = HomographyDataset( + root_dir=dataset_dir, + transform=None, # We'll handle transforms manually + augment=False, + image_size=tuple(self.config.get("image_size", [256, 256])), + cache_homographies=False, + ) + + if num_samples is not None: + indices = list(range(min(num_samples, len(dataset)))) + else: + indices = list(range(len(dataset))) + + errors = [] + corner_errors = [] + + print(f"Evaluating on {len(indices)} samples...") + + for idx in indices: + # Get sample without augmentation + sample = dataset.get_sample_without_augmentation(idx) + + google_img = sample["google_img"] + yandex_img = sample["yandex_img"] + target_homography = sample["homography"] + + # Predict homography + pred_homography = self.predict( + google_img, yandex_img, return_matrix=True, normalize=True + ) + + # Convert to numpy + pred_homography_np = pred_homography.cpu().numpy() + target_homography_np = target_homography + + # Compute matrix error + matrix_error = np.mean((pred_homography_np - target_homography_np) ** 2) + errors.append(matrix_error) + + # Compute corner error + corners = np.array( + [ + [-1, -1, 1], + [1, -1, 1], + [1, 1, 1], + [-1, 1, 1], + ], + dtype=np.float32, + ).T + + pred_corners = pred_homography_np @ corners + pred_corners = pred_corners / (pred_corners[2:3, :] + 1e-8) + + target_corners = target_homography_np @ corners + target_corners = target_corners / (target_corners[2:3, :] + 1e-8) + + corner_error = np.mean( + np.linalg.norm(pred_corners[:2, :] - target_corners[:2, :], axis=0) + ) + corner_errors.append(corner_error) + + # Save visualization if requested + if save_dir: + os.makedirs(save_dir, exist_ok=True) + vis_path = os.path.join(save_dir, f"sample_{idx:04d}.png") + self.visualize_alignment( + google_img, + yandex_img, + pred_homography_np, + save_path=vis_path, + show=False, + ) + + # Compute metrics + metrics = { + "mean_matrix_error": float(np.mean(errors)), + "std_matrix_error": float(np.std(errors)), + "mean_corner_error": float(np.mean(corner_errors)), + "std_corner_error": float(np.std(corner_errors)), + "median_corner_error": float(np.median(corner_errors)), + "max_corner_error": float(np.max(corner_errors)), + "min_corner_error": float(np.min(corner_errors)), + } + + print("\nEvaluation Results:") + for key, value in metrics.items(): + print(f" {key}: {value:.6f}") + + return metrics + + +def main(): + """Main inference function.""" + parser = argparse.ArgumentParser(description="Inference for homography estimation") + + # Model arguments + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to trained model checkpoint", + ) + parser.add_argument( + "--config_path", + type=str, + help="Path to model configuration file (optional)", + ) + parser.add_argument( + "--device", + type=str, + choices=["cuda", "cpu"], + help="Device to run inference on", + ) + + # Inference mode + parser.add_argument( + "--mode", + type=str, + default="single", + choices=["single", "dataset", "batch"], + help="Inference mode", + ) + + # Single image mode + parser.add_argument( + "--google_path", + type=str, + help="Path to Google map image (single mode)", + ) + parser.add_argument( + "--yandex_path", + type=str, + help="Path to Yandex map image (single mode)", + ) + parser.add_argument( + "--output_vis", + type=str, + help="Path to save visualization (single mode)", + ) + + # Dataset mode + parser.add_argument( + "--dataset_dir", + type=str, + default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + help="Directory containing image pairs (dataset mode)", + ) + parser.add_argument( + "--num_samples", + type=int, + help="Number of samples to evaluate (dataset mode)", + ) + parser.add_argument( + "--save_vis_dir", + type=str, + help="Directory to save visualizations (dataset mode)", + ) + parser.add_argument( + "--save_results", + type=str, + help="Path to save evaluation results (dataset mode)", + ) + + args = parser.parse_args() + + # Create inference object + inference = HomographyInference( + model_path=args.model_path, + config_path=args.config_path, + device=args.device, + ) + + if args.mode == "single": + # Single image pair inference + if not args.google_path or not args.yandex_path: + raise ValueError( + "Both --google_path and --yandex_path are required for single mode" + ) + + print(f"Processing single image pair:") + print(f" Google: {args.google_path}") + print(f" Yandex: {args.yandex_path}") + + # Predict homography + homography = inference.predict_from_paths(args.google_path, args.yandex_path) + + print(f"\nPredicted homography matrix:") + print(homography.cpu().numpy()) + + # Visualize alignment + if args.output_vis: + google_img = Image.open(args.google_path) + yandex_img = Image.open(args.yandex_path) + inference.visualize_alignment( + google_img, + yandex_img, + homography.cpu().numpy(), + save_path=args.output_vis, + show=True, + ) + + elif args.mode == "dataset": + # Evaluate on dataset + metrics = inference.evaluate_on_dataset( + dataset_dir=args.dataset_dir, + num_samples=args.num_samples, + save_dir=args.save_vis_dir, + ) + + # Save results if requested + if args.save_results: + with open(args.save_results, "w") as f: + json.dump(metrics, f, indent=2) + print(f"\nResults saved to {args.save_results}") + + elif args.mode == "batch": + # Batch processing (placeholder for future implementation) + print("Batch mode not yet implemented") + # Could implement processing multiple image pairs from a directory + + else: + raise ValueError(f"Unknown mode: {args.mode}") + + +if __name__ == "__main__": + main() diff --git a/models/SiaN/train_homography.py b/models/SiaN/train_homography.py new file mode 100644 index 0000000..f9fed47 --- /dev/null +++ b/models/SiaN/train_homography.py @@ -0,0 +1,611 @@ +""" +Training script for homography estimation between Google and Yandex map images. + +This script trains a CNN model to estimate homography matrices that align +Google map images with Yandex map images. +""" + +import argparse +import json +import os +import time +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from homography import HomographyDataset, create_data_loaders +from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + + +class HomographyTrainer: + """Trainer class for homography estimation model.""" + + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + device: torch.device, + config: Dict, + ): + """ + Initialize the trainer. + + Args: + model: Homography estimation model + train_loader: Training data loader + val_loader: Validation data loader + device: Device to run training on + config: Training configuration dictionary + """ + self.model = model.to(device) + self.train_loader = train_loader + self.val_loader = val_loader + self.device = device + self.config = config + + # Loss function + self.criterion = HomographyLoss( + matrix_weight=config.get("matrix_weight", 1.0), + geometric_weight=config.get("geometric_weight", 0.5), + reg_weight=config.get("reg_weight", 0.1), + grid_size=config.get("grid_size", 8), + ).to(device) + + # Optimizer + optimizer_name = config.get("optimizer", "adam").lower() + lr = config.get("learning_rate", 1e-3) + weight_decay = config.get("weight_decay", 1e-4) + + if optimizer_name == "adam": + self.optimizer = optim.Adam( + self.model.parameters(), lr=lr, weight_decay=weight_decay + ) + elif optimizer_name == "adamw": + self.optimizer = optim.AdamW( + self.model.parameters(), lr=lr, weight_decay=weight_decay + ) + elif optimizer_name == "sgd": + self.optimizer = optim.SGD( + self.model.parameters(), + lr=lr, + momentum=config.get("momentum", 0.9), + weight_decay=weight_decay, + ) + else: + raise ValueError(f"Unknown optimizer: {optimizer_name}") + + # Learning rate scheduler + scheduler_name = config.get("scheduler", "plateau").lower() + if scheduler_name == "plateau": + self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + mode="min", + factor=config.get("scheduler_factor", 0.5), + patience=config.get("scheduler_patience", 5), + verbose=True, + ) + elif scheduler_name == "cosine": + self.scheduler = optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, + T_max=config.get("epochs", 100), + eta_min=config.get("min_lr", 1e-6), + ) + elif scheduler_name == "step": + self.scheduler = optim.lr_scheduler.StepLR( + self.optimizer, + step_size=config.get("step_size", 30), + gamma=config.get("gamma", 0.1), + ) + else: + self.scheduler = None + + # Training state + self.current_epoch = 0 + self.best_val_loss = float("inf") + self.train_losses: List[float] = [] + self.val_losses: List[float] = [] + self.val_metrics: List[Dict] = [] + + # Create output directory + self.output_dir = Path(config.get("output_dir", "runs/homography")) + self.output_dir.mkdir(parents=True, exist_ok=True) + + # TensorBoard writer + self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard") + + # Save configuration + config_path = self.output_dir / "config.json" + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"Training configuration saved to {config_path}") + print( + f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters" + ) + + def train_epoch(self) -> float: + """ + Train for one epoch. + + Returns: + Average training loss for the epoch + """ + self.model.train() + total_loss = 0.0 + num_batches = len(self.train_loader) + + progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") + for batch_idx, batch in enumerate(progress_bar): + # Move data to device + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target_homography = batch["homography"].to(self.device) + + # Forward pass + self.optimizer.zero_grad() + pred_homography = self.model(google_img, yandex_img, return_matrix=True) + + # Compute loss + loss = self.criterion( + pred_homography, + target_homography, + google_img, + yandex_img, + ) + + # Backward pass + loss.backward() + + # Gradient clipping + if self.config.get("grad_clip", 1.0) > 0: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.config.get("grad_clip", 1.0), + ) + + # Optimizer step + self.optimizer.step() + + # Update statistics + total_loss += loss.item() + + # Update progress bar + progress_bar.set_postfix({"loss": loss.item()}) + + # Log batch loss to TensorBoard + global_step = self.current_epoch * num_batches + batch_idx + self.writer.add_scalar("train/batch_loss", loss.item(), global_step) + + avg_loss = total_loss / num_batches + self.train_losses.append(avg_loss) + + return avg_loss + + @torch.no_grad() + def validate(self) -> Tuple[float, Dict]: + """ + Validate the model. + + Returns: + Tuple of (average validation loss, validation metrics) + """ + self.model.eval() + total_loss = 0.0 + all_metrics = [] + + progress_bar = tqdm(self.val_loader, desc="Validation") + for batch in progress_bar: + # Move data to device + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target_homography = batch["homography"].to(self.device) + + # Forward pass + pred_homography = self.model(google_img, yandex_img, return_matrix=True) + + # Compute loss + loss = self.criterion( + pred_homography, + target_homography, + google_img, + yandex_img, + ) + + # Compute metrics + metrics = self.criterion.compute_metrics(pred_homography, target_homography) + + # Update statistics + total_loss += loss.item() + all_metrics.append(metrics) + + # Update progress bar + progress_bar.set_postfix({"loss": loss.item()}) + + avg_loss = total_loss / len(self.val_loader) + self.val_losses.append(avg_loss) + + # Aggregate metrics + avg_metrics = {} + for key in all_metrics[0].keys(): + avg_metrics[key] = np.mean([m[key] for m in all_metrics]) + + self.val_metrics.append(avg_metrics) + + return avg_loss, avg_metrics + + def save_checkpoint(self, is_best: bool = False): + """Save model checkpoint.""" + checkpoint = { + "epoch": self.current_epoch, + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "train_losses": self.train_losses, + "val_losses": self.val_losses, + "val_metrics": self.val_metrics, + "best_val_loss": self.best_val_loss, + "config": self.config, + } + + if self.scheduler is not None: + checkpoint["scheduler_state_dict"] = self.scheduler.state_dict() + + # Save latest checkpoint + checkpoint_path = self.output_dir / "checkpoint_latest.pth" + torch.save(checkpoint, checkpoint_path) + + # Save best checkpoint + if is_best: + best_path = self.output_dir / "checkpoint_best.pth" + torch.save(checkpoint, best_path) + print(f"Best model saved to {best_path}") + + def load_checkpoint(self, checkpoint_path: str): + """Load model checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + if self.scheduler is not None and "scheduler_state_dict" in checkpoint: + self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + + self.current_epoch = checkpoint["epoch"] + self.train_losses = checkpoint["train_losses"] + self.val_losses = checkpoint["val_losses"] + self.val_metrics = checkpoint["val_metrics"] + self.best_val_loss = checkpoint["best_val_loss"] + + print(f"Loaded checkpoint from epoch {self.current_epoch}") + + def train(self, num_epochs: int): + """ + Train the model for specified number of epochs. + + Args: + num_epochs: Number of epochs to train + """ + print(f"Starting training for {num_epochs} epochs...") + start_time = time.time() + + for epoch in range(num_epochs): + self.current_epoch = epoch + + # Train for one epoch + train_loss = self.train_epoch() + + # Validate + val_loss, val_metrics = self.validate() + + # Update learning rate scheduler + if self.scheduler is not None: + if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau): + self.scheduler.step(val_loss) + else: + self.scheduler.step() + + # Log to TensorBoard + self.writer.add_scalar("train/epoch_loss", train_loss, epoch) + self.writer.add_scalar("val/epoch_loss", val_loss, epoch) + for metric_name, metric_value in val_metrics.items(): + self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch) + + # Print epoch summary + print(f"\nEpoch {epoch + 1}/{num_epochs}:") + print(f" Train Loss: {train_loss:.6f}") + print(f" Val Loss: {val_loss:.6f}") + print(" Val Metrics:") + for metric_name, metric_value in val_metrics.items(): + print(f" {metric_name}: {metric_value:.6f}") + + # Save checkpoint + is_best = val_loss < self.best_val_loss + if is_best: + self.best_val_loss = val_loss + + self.save_checkpoint(is_best=is_best) + + # Early stopping + if self.config.get("early_stopping_patience", 0) > 0: + if ( + epoch - np.argmin(self.val_losses) + >= self.config["early_stopping_patience"] + ): + print(f"Early stopping at epoch {epoch + 1}") + break + + # Training completed + training_time = time.time() - start_time + print(f"\nTraining completed in {training_time:.2f} seconds") + print(f"Best validation loss: {self.best_val_loss:.6f}") + + # Save final model + final_model_path = self.output_dir / "model_final.pth" + torch.save(self.model.state_dict(), final_model_path) + print(f"Final model saved to {final_model_path}") + + # Close TensorBoard writer + self.writer.close() + + def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict: + """ + Evaluate the model on test data. + + Args: + test_loader: Test data loader (uses validation loader if None) + + Returns: + Dictionary of evaluation metrics + """ + if test_loader is None: + test_loader = self.val_loader + + self.model.eval() + all_metrics = [] + + print("Evaluating model...") + with torch.no_grad(): + for batch in tqdm(test_loader): + # Move data to device + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target_homography = batch["homography"].to(self.device) + + # Forward pass + pred_homography = self.model(google_img, yandex_img, return_matrix=True) + + # Compute metrics + metrics = self.criterion.compute_metrics( + pred_homography, target_homography + ) + all_metrics.append(metrics) + + # Aggregate metrics + avg_metrics = {} + for key in all_metrics[0].keys(): + avg_metrics[key] = np.mean([m[key] for m in all_metrics]) + + # Print evaluation results + print("\nEvaluation Results:") + for metric_name, metric_value in avg_metrics.items(): + print(f" {metric_name}: {metric_value:.6f}") + + # Save evaluation results + eval_path = self.output_dir / "evaluation_results.json" + with open(eval_path, "w") as f: + json.dump(avg_metrics, f, indent=2) + print(f"Evaluation results saved to {eval_path}") + + return avg_metrics + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Train homography estimation model") + + # Data arguments + parser.add_argument( + "--data_dir", + type=str, + default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + help="Directory containing image pairs", + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size for training" + ) + parser.add_argument( + "--image_size", + type=int, + nargs=2, + default=[256, 256], + help="Image size (height width)", + ) + parser.add_argument( + "--train_split", type=float, default=0.8, help="Train/validation split ratio" + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="Number of data loader workers" + ) + + # Model arguments + parser.add_argument( + "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type" + ) + parser.add_argument( + "--hidden_channels", type=int, default=64, help="Number of hidden channels" + ) + parser.add_argument( + "--num_blocks", type=int, default=4, help="Number of convolutional blocks" + ) + parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate") + parser.add_argument( + "--use_batch_norm", action="store_true", help="Use batch normalization" + ) + + # Training arguments + parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") + parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") + parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay") + parser.add_argument( + "--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"] + ) + parser.add_argument( + "--scheduler", + type=str, + default="plateau", + choices=["plateau", "cosine", "step", "none"], + ) + parser.add_argument( + "--grad_clip", type=float, default=1.0, help="Gradient clipping value" + ) + + # Loss arguments + parser.add_argument( + "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss" + ) + parser.add_argument( + "--geometric_weight", + type=float, + default=0.5, + help="Weight for geometric loss", + ) + parser.add_argument( + "--reg_weight", type=float, default=0.1, help="Weight for regularization loss" + ) + + # Other arguments + parser.add_argument( + "--output_dir", + type=str, + default="runs/homography", + help="Output directory for checkpoints and logs", + ) + parser.add_argument( + "--resume", + type=str, + help="Path to checkpoint to resume training from", + ) + parser.add_argument( + "--eval_only", + action="store_true", + help="Only evaluate the model (no training)", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + + return parser.parse_args() + + +def main(): + """Main training function.""" + args = parse_args() + + # Set random seeds for reproducibility + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Create data loaders + print("Creating data loaders...") + train_loader, val_loader = create_data_loaders( + root_dir=args.data_dir, + batch_size=args.batch_size, + train_split=args.train_split, + num_workers=args.num_workers, + image_size=tuple(args.image_size), + augment_train=True, + augment_val=False, + ) + + print(f"Train batches: {len(train_loader)}") + print(f"Val batches: {len(val_loader)}") + + # Create model + print("Creating model...") + model = create_homography_model( + model_type=args.model_type, + input_size=tuple(args.image_size), + input_channels=3, + hidden_channels=args.hidden_channels, + num_blocks=args.num_blocks, + dropout_rate=args.dropout_rate, + use_batch_norm=args.use_batch_norm, + ) + + # Create trainer configuration + config = { + # Model config + "model_type": args.model_type, + "hidden_channels": args.hidden_channels, + "num_blocks": args.num_blocks, + "dropout_rate": args.dropout_rate, + "use_batch_norm": args.use_batch_norm, + "image_size": args.image_size, + # Training config + "epochs": args.epochs, + "batch_size": args.batch_size, + "learning_rate": args.lr, + "weight_decay": args.weight_decay, + "optimizer": args.optimizer, + "scheduler": args.scheduler, + "grad_clip": args.grad_clip, + # Loss config + "matrix_weight": args.matrix_weight, + "geometric_weight": args.geometric_weight, + "reg_weight": args.reg_weight, + "grid_size": 8, + # Data config + "data_dir": args.data_dir, + "train_split": args.train_split, + "num_workers": args.num_workers, + # Output config + "output_dir": args.output_dir, + "seed": args.seed, + } + + # Create trainer + trainer = HomographyTrainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + config=config, + ) + + # Resume from checkpoint if specified + if args.resume: + print(f"Resuming from checkpoint: {args.resume}") + trainer.load_checkpoint(args.resume) + + # Evaluate only mode + if args.eval_only: + trainer.evaluate() + return + + # Train the model + trainer.train(num_epochs=args.epochs) + + # Final evaluation + print("\nPerforming final evaluation...") + trainer.evaluate() + + print("\nTraining completed successfully!") + print(f"Results saved to: {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/models/SiaN/train_homography_.py b/models/SiaN/train_homography_.py new file mode 100644 index 0000000..0550432 --- /dev/null +++ b/models/SiaN/train_homography_.py @@ -0,0 +1,611 @@ +""" +Training script for homography estimation between Google and Yandex map images. + +This script trains a CNN model to estimate homography matrices that align +Google map images with Yandex map images. +""" + +import argparse +import json +import os +import time +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from homography import HomographyDataset, create_data_loaders +from homography_cnn import HomographyCNN, HomographyLoss, create_homography_model +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + + +class HomographyTrainer: + """Trainer class for homography estimation model.""" + + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + device: torch.device, + config: Dict, + ): + """ + Initialize the trainer. + + Args: + model: Homography estimation model + train_loader: Training data loader + val_loader: Validation data loader + device: Device to run training on + config: Training configuration dictionary + """ + self.model = model.to(device) + self.train_loader = train_loader + self.val_loader = val_loader + self.device = device + self.config = config + + # Loss function + self.criterion = HomographyLoss( + matrix_weight=config.get("matrix_weight", 1.0), + geometric_weight=config.get("geometric_weight", 0.5), + reg_weight=config.get("reg_weight", 0.1), + grid_size=config.get("grid_size", 8), + ).to(device) + + # Optimizer + optimizer_name = config.get("optimizer", "adam").lower() + lr = config.get("learning_rate", 1e-3) + weight_decay = config.get("weight_decay", 1e-4) + + if optimizer_name == "adam": + self.optimizer = optim.Adam( + self.model.parameters(), lr=lr, weight_decay=weight_decay + ) + elif optimizer_name == "adamw": + self.optimizer = optim.AdamW( + self.model.parameters(), lr=lr, weight_decay=weight_decay + ) + elif optimizer_name == "sgd": + self.optimizer = optim.SGD( + self.model.parameters(), + lr=lr, + momentum=config.get("momentum", 0.9), + weight_decay=weight_decay, + ) + else: + raise ValueError(f"Unknown optimizer: {optimizer_name}") + + # Learning rate scheduler + scheduler_name = config.get("scheduler", "plateau").lower() + if scheduler_name == "plateau": + self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + mode="min", + factor=config.get("scheduler_factor", 0.5), + patience=config.get("scheduler_patience", 5), + verbose=True, + ) + elif scheduler_name == "cosine": + self.scheduler = optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, + T_max=config.get("epochs", 100), + eta_min=config.get("min_lr", 1e-6), + ) + elif scheduler_name == "step": + self.scheduler = optim.lr_scheduler.StepLR( + self.optimizer, + step_size=config.get("step_size", 30), + gamma=config.get("gamma", 0.1), + ) + else: + self.scheduler = None + + # Training state + self.current_epoch = 0 + self.best_val_loss = float("inf") + self.train_losses: List[float] = [] + self.val_losses: List[float] = [] + self.val_metrics: List[Dict] = [] + + # Create output directory + self.output_dir = Path(config.get("output_dir", "runs/homography")) + self.output_dir.mkdir(parents=True, exist_ok=True) + + # TensorBoard writer + self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard") + + # Save configuration + config_path = self.output_dir / "config.json" + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"Training configuration saved to {config_path}") + print( + f"Model has {sum(p.numel() for p in self.model.parameters()):,} parameters" + ) + + def train_epoch(self) -> float: + """ + Train for one epoch. + + Returns: + Average training loss for the epoch + """ + self.model.train() + total_loss = 0.0 + num_batches = len(self.train_loader) + + progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}") + for batch_idx, batch in enumerate(progress_bar): + # Move data to device + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target_homography = batch["homography"].to(self.device) + + # Forward pass + self.optimizer.zero_grad() + pred_homography = self.model(google_img, yandex_img, return_matrix=True) + + # Compute loss + loss = self.criterion( + pred_homography, + target_homography, + google_img, + yandex_img, + ) + + # Backward pass + loss.backward() + + # Gradient clipping + if self.config.get("grad_clip", 1.0) > 0: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.config.get("grad_clip", 1.0), + ) + + # Optimizer step + self.optimizer.step() + + # Update statistics + total_loss += loss.item() + + # Update progress bar + progress_bar.set_postfix({"loss": loss.item()}) + + # Log batch loss to TensorBoard + global_step = self.current_epoch * num_batches + batch_idx + self.writer.add_scalar("train/batch_loss", loss.item(), global_step) + + avg_loss = total_loss / num_batches + self.train_losses.append(avg_loss) + + return avg_loss + + @torch.no_grad() + def validate(self) -> Tuple[float, Dict]: + """ + Validate the model. + + Returns: + Tuple of (average validation loss, validation metrics) + """ + self.model.eval() + total_loss = 0.0 + all_metrics = [] + + progress_bar = tqdm(self.val_loader, desc="Validation") + for batch in progress_bar: + # Move data to device + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target_homography = batch["homography"].to(self.device) + + # Forward pass + pred_homography = self.model(google_img, yandex_img, return_matrix=True) + + # Compute loss + loss = self.criterion( + pred_homography, + target_homography, + google_img, + yandex_img, + ) + + # Compute metrics + metrics = self.criterion.compute_metrics(pred_homography, target_homography) + + # Update statistics + total_loss += loss.item() + all_metrics.append(metrics) + + # Update progress bar + progress_bar.set_postfix({"loss": loss.item()}) + + avg_loss = total_loss / len(self.val_loader) + self.val_losses.append(avg_loss) + + # Aggregate metrics + avg_metrics = {} + for key in all_metrics[0].keys(): + avg_metrics[key] = np.mean([m[key] for m in all_metrics]) + + self.val_metrics.append(avg_metrics) + + return avg_loss, avg_metrics + + def save_checkpoint(self, is_best: bool = False): + """Save model checkpoint.""" + checkpoint = { + "epoch": self.current_epoch, + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "train_losses": self.train_losses, + "val_losses": self.val_losses, + "val_metrics": self.val_metrics, + "best_val_loss": self.best_val_loss, + "config": self.config, + } + + if self.scheduler is not None: + checkpoint["scheduler_state_dict"] = self.scheduler.state_dict() + + # Save latest checkpoint + checkpoint_path = self.output_dir / "checkpoint_latest.pth" + torch.save(checkpoint, checkpoint_path) + + # Save best checkpoint + if is_best: + best_path = self.output_dir / "checkpoint_best.pth" + torch.save(checkpoint, best_path) + print(f"Best model saved to {best_path}") + + def load_checkpoint(self, checkpoint_path: str): + """Load model checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + if self.scheduler is not None and "scheduler_state_dict" in checkpoint: + self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + + self.current_epoch = checkpoint["epoch"] + self.train_losses = checkpoint["train_losses"] + self.val_losses = checkpoint["val_losses"] + self.val_metrics = checkpoint["val_metrics"] + self.best_val_loss = checkpoint["best_val_loss"] + + print(f"Loaded checkpoint from epoch {self.current_epoch}") + + def train(self, num_epochs: int): + """ + Train the model for specified number of epochs. + + Args: + num_epochs: Number of epochs to train + """ + print(f"Starting training for {num_epochs} epochs...") + start_time = time.time() + + for epoch in range(num_epochs): + self.current_epoch = epoch + + # Train for one epoch + train_loss = self.train_epoch() + + # Validate + val_loss, val_metrics = self.validate() + + # Update learning rate scheduler + if self.scheduler is not None: + if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau): + self.scheduler.step(val_loss) + else: + self.scheduler.step() + + # Log to TensorBoard + self.writer.add_scalar("train/epoch_loss", train_loss, epoch) + self.writer.add_scalar("val/epoch_loss", val_loss, epoch) + for metric_name, metric_value in val_metrics.items(): + self.writer.add_scalar(f"val/{metric_name}", metric_value, epoch) + + # Print epoch summary + print(f"\nEpoch {epoch + 1}/{num_epochs}:") + print(f" Train Loss: {train_loss:.6f}") + print(f" Val Loss: {val_loss:.6f}") + print(" Val Metrics:") + for metric_name, metric_value in val_metrics.items(): + print(f" {metric_name}: {metric_value:.6f}") + + # Save checkpoint + is_best = val_loss < self.best_val_loss + if is_best: + self.best_val_loss = val_loss + + self.save_checkpoint(is_best=is_best) + + # Early stopping + if self.config.get("early_stopping_patience", 0) > 0: + if ( + epoch - np.argmin(self.val_losses) + >= self.config["early_stopping_patience"] + ): + print(f"Early stopping at epoch {epoch + 1}") + break + + # Training completed + training_time = time.time() - start_time + print(f"\nTraining completed in {training_time:.2f} seconds") + print(f"Best validation loss: {self.best_val_loss:.6f}") + + # Save final model + final_model_path = self.output_dir / "model_final.pth" + torch.save(self.model.state_dict(), final_model_path) + print(f"Final model saved to {final_model_path}") + + # Close TensorBoard writer + self.writer.close() + + def evaluate(self, test_loader: Optional[DataLoader] = None) -> Dict: + """ + Evaluate the model on test data. + + Args: + test_loader: Test data loader (uses validation loader if None) + + Returns: + Dictionary of evaluation metrics + """ + if test_loader is None: + test_loader = self.val_loader + + self.model.eval() + all_metrics = [] + + print("Evaluating model...") + with torch.no_grad(): + for batch in tqdm(test_loader): + # Move data to device + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target_homography = batch["homography"].to(self.device) + + # Forward pass + pred_homography = self.model(google_img, yandex_img, return_matrix=True) + + # Compute metrics + metrics = self.criterion.compute_metrics( + pred_homography, target_homography + ) + all_metrics.append(metrics) + + # Aggregate metrics + avg_metrics = {} + for key in all_metrics[0].keys(): + avg_metrics[key] = np.mean([m[key] for m in all_metrics]) + + # Print evaluation results + print("\nEvaluation Results:") + for metric_name, metric_value in avg_metrics.items(): + print(f" {metric_name}: {metric_value:.6f}") + + # Save evaluation results + eval_path = self.output_dir / "evaluation_results.json" + with open(eval_path, "w") as f: + json.dump(avg_metrics, f, indent=2) + print(f"Evaluation results saved to {eval_path}") + + return avg_metrics + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Train homography estimation model") + + # Data arguments + parser.add_argument( + "--data_dir", + type=str, + default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", + help="Directory containing image pairs", + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size for training" + ) + parser.add_argument( + "--image_size", + type=int, + nargs=2, + default=[256, 256], + help="Image size (height width)", + ) + parser.add_argument( + "--train_split", type=float, default=0.8, help="Train/validation split ratio" + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="Number of data loader workers" + ) + + # Model arguments + parser.add_argument( + "--model_type", type=str, default="cnn", choices=["cnn"], help="Model type" + ) + parser.add_argument( + "--hidden_channels", type=int, default=64, help="Number of hidden channels" + ) + parser.add_argument( + "--num_blocks", type=int, default=4, help="Number of convolutional blocks" + ) + parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate") + parser.add_argument( + "--use_batch_norm", action="store_true", help="Use batch normalization" + ) + + # Training arguments + parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") + parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") + parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay") + parser.add_argument( + "--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"] + ) + parser.add_argument( + "--scheduler", + type=str, + default="plateau", + choices=["plateau", "cosine", "step", "none"], + ) + parser.add_argument( + "--grad_clip", type=float, default=1.0, help="Gradient clipping value" + ) + + # Loss arguments + parser.add_argument( + "--matrix_weight", type=float, default=1.0, help="Weight for matrix loss" + ) + parser.add_argument( + "--geometric_weight", + type=float, + default=0.5, + help="Weight for geometric loss", + ) + parser.add_argument( + "--reg_weight", type=float, default=0.1, help="Weight for regularization loss" + ) + + # Other arguments + parser.add_argument( + "--output_dir", + type=str, + default="runs/homography", + help="Output directory for checkpoints and logs", + ) + parser.add_argument( + "--resume", + type=str, + help="Path to checkpoint to resume training from", + ) + parser.add_argument( + "--eval_only", + action="store_true", + help="Only evaluate the model (no training)", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + + return parser.parse_args() + + +def main(): + """Main training function.""" + args = parse_args() + + # Set random seeds for reproducibility + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Create data loaders + print("Creating data loaders...") + train_loader, val_loader = create_data_loaders( + root_dir=args.data_dir, + batch_size=args.batch_size, + train_split=args.train_split, + num_workers=args.num_workers, + image_size=tuple(args.image_size), + augment_train=True, + augment_val=False, + ) + + print(f"Train batches: {len(train_loader)}") + print(f"Val batches: {len(val_loader)}") + + # Create model + print("Creating model...") + model = create_homography_model( + model_type=args.model_type, + input_size=tuple(args.image_size), + input_channels=3, + hidden_channels=args.hidden_channels, + num_blocks=args.num_blocks, + dropout_rate=args.dropout_rate, + use_batch_norm=args.use_batch_norm, + ) + + # Create trainer configuration + config = { + # Model config + "model_type": args.model_type, + "hidden_channels": args.hidden_channels, + "num_blocks": args.num_blocks, + "dropout_rate": args.dropout_rate, + "use_batch_norm": args.use_batch_norm, + "image_size": args.image_size, + # Training config + "epochs": args.epochs, + "batch_size": args.batch_size, + "learning_rate": args.lr, + "weight_decay": args.weight_decay, + "optimizer": args.optimizer, + "scheduler": args.scheduler, + "grad_clip": args.grad_clip, + # Loss config + "matrix_weight": args.matrix_weight, + "geometric_weight": args.geometric_weight, + "reg_weight": args.reg_weight, + "grid_size": 8, + # Data config + "data_dir": args.data_dir, + "train_split": args.train_split, + "num_workers": args.num_workers, + # Output config + "output_dir": args.output_dir, + "seed": args.seed, + } + + # Create trainer + trainer = HomographyTrainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + config=config, + ) + + # Resume from checkpoint if specified + if args.resume: + print(f"Resuming from checkpoint: {args.resume}") + trainer.load_checkpoint(args.resume) + + # Evaluate only mode + if args.eval_only: + trainer.evaluate() + return + + # Train the model + trainer.train(num_epochs=args.epochs) + + # Final evaluation + print("\nPerforming final evaluation...") + trainer.evaluate() + + print("\nTraining completed successfully!") + print(f"Results saved to: {args.output_dir}") + + +if __name__ == "__main__": + main()