# Files
# autopilot/models/SiaN/infer_homography.py
# 2026-02-16 19:07:31 +03:00
#
# 554 lines
# 18 KiB
# Python

"""
Inference script for homography estimation between Google and Yandex map images.
This script loads a trained homography estimation model and performs inference
on new image pairs or the test dataset.
"""
import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from homography import HomographyDataset
from homography_cnn import HomographyCNN, create_homography_model
from PIL import Image
from torchvision import transforms
class HomographyInference:
    """Class for performing inference with homography estimation model.

    Loads a model configuration and checkpoint, applies the same
    resize/normalize preprocessing to both map images, and provides helpers
    for single-pair prediction, image warping, visualization and dataset
    evaluation.
    """

    # Homogeneous corners of the [-1, 1] square, one point per column (3x4).
    # Used by the corner-error metric in evaluate_on_dataset.
    _CORNERS = np.array(
        [
            [-1, -1, 1],
            [1, -1, 1],
            [1, 1, 1],
            [-1, 1, 1],
        ],
        dtype=np.float32,
    ).T

    def __init__(
        self,
        model_path: str,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """
        Initialize the inference class.

        Args:
            model_path: Path to trained model checkpoint.
            config_path: Path to model configuration JSON file (optional).
                When omitted, ``config.json`` next to the checkpoint is tried.
                If the file does not exist, a built-in default configuration
                is used (with a warning when the path was given explicitly).
            device: Device to run inference on ('cuda' or 'cpu'). When
                omitted, CUDA is used if available, otherwise CPU.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # Load configuration
        config_was_given = config_path is not None
        if config_path is None:
            # Try to find config in the same directory as model
            config_path = Path(model_path).parent / "config.json"
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
            print(f"Loaded configuration from {config_path}")
        else:
            # An explicitly requested config that is missing is reported
            # rather than silently replaced by the defaults.
            if config_was_given:
                print(f"Warning: configuration file not found: {config_path}")
            # Use default configuration
            self.config = {
                "image_size": [256, 256],
                "hidden_channels": 64,
                "num_blocks": 4,
                "dropout_rate": 0.3,
                "use_batch_norm": True,
            }
            print("Using default configuration")

        # Create model
        self.model = self._create_model()
        self._load_model(model_path)
        # Set up transforms
        self.transform = self._create_transforms()
        # Set model to evaluation mode (disables dropout / uses BN running stats)
        self.model.eval()

    def _create_model(self) -> HomographyCNN:
        """Create the CNN model from ``self.config`` (missing keys use defaults)."""
        image_size = self.config.get("image_size", [256, 256])
        model = create_homography_model(
            model_type="cnn",
            input_size=tuple(image_size),
            input_channels=3,
            hidden_channels=self.config.get("hidden_channels", 64),
            num_blocks=self.config.get("num_blocks", 4),
            dropout_rate=self.config.get("dropout_rate", 0.3),
            use_batch_norm=self.config.get("use_batch_norm", True),
        )
        return model

    def _load_model(self, model_path: str):
        """Load model weights from a checkpoint and move the model to the device.

        Accepts either a trainer checkpoint (a dict with a
        ``"model_state_dict"`` entry) or a raw state dict.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints. Newer torch versions default to weights_only=True, which
        # may reject trainer checkpoints holding non-tensor objects — confirm.
        checkpoint = torch.load(model_path, map_location=self.device)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            # Trainer checkpoint format
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            # Raw model weights format
            self.model.load_state_dict(checkpoint)
        self.model = self.model.to(self.device)
        print(f"Loaded model from {model_path}")

    def _create_transforms(self):
        """Create the inference transform: resize, to-tensor, ImageNet normalization."""
        return transforms.Compose(
            [
                transforms.Resize(tuple(self.config.get("image_size", [256, 256]))),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def preprocess_images(
        self, google_img: Image.Image, yandex_img: Image.Image
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess images for inference.

        Args:
            google_img: Google map image (PIL Image).
            yandex_img: Yandex map image (PIL Image).

        Returns:
            Tuple of preprocessed image tensors, each with a leading batch
            dimension of size 1.
        """
        # Convert to RGB if needed
        if google_img.mode != "RGB":
            google_img = google_img.convert("RGB")
        if yandex_img.mode != "RGB":
            yandex_img = yandex_img.convert("RGB")
        # Apply transforms
        google_tensor = self.transform(google_img).unsqueeze(0)  # Add batch dimension
        yandex_tensor = self.transform(yandex_img).unsqueeze(0)
        return google_tensor, yandex_tensor

    def predict(
        self,
        google_img: Image.Image,
        yandex_img: Image.Image,
        return_matrix: bool = True,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix for image pair.

        Args:
            google_img: Google map image (PIL Image).
            yandex_img: Yandex map image (PIL Image).
            return_matrix: If True, return 3x3 matrix; if False, return
                flattened vector.
            normalize: If True (and ``return_matrix`` is True), scale the
                matrix so that its bottom-right element equals 1. No guard is
                applied if that element is zero.

        Returns:
            Predicted homography matrix or vector, without the batch dimension.
        """
        # Preprocess images
        google_tensor, yandex_tensor = self.preprocess_images(google_img, yandex_img)
        # Move to device
        google_tensor = google_tensor.to(self.device)
        yandex_tensor = yandex_tensor.to(self.device)
        # Perform inference
        with torch.no_grad():
            homography = self.model(
                google_tensor, yandex_tensor, return_matrix=return_matrix
            )
        if return_matrix and normalize:
            # Normalize so that last element is 1
            homography = homography / homography[:, 2, 2].view(-1, 1, 1)
        return homography.squeeze(0)  # Remove batch dimension

    def predict_from_paths(
        self,
        google_path: str,
        yandex_path: str,
        return_matrix: bool = True,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix from image file paths.

        Args:
            google_path: Path to Google map image.
            yandex_path: Path to Yandex map image.
            return_matrix: If True, return 3x3 matrix; if False, return
                flattened vector.
            normalize: Whether to normalize the homography matrix.

        Returns:
            Predicted homography matrix or vector.
        """
        # Keep the files open only while the pixels are read by predict(),
        # then release the handles.
        with Image.open(google_path) as google_img, Image.open(yandex_path) as yandex_img:
            return self.predict(google_img, yandex_img, return_matrix, normalize)

    def warp_image(
        self,
        img: Image.Image,
        homography: np.ndarray,
        output_size: Optional[Tuple[int, int]] = None,
    ) -> Image.Image:
        """
        Warp image using homography matrix.

        Args:
            img: Input image (PIL Image).
            homography: 3x3 homography matrix (array-like), applied in the
                input image's pixel coordinates.
            output_size: Output image size (width, height). If None, uses the
                input size.

        Returns:
            Warped image (PIL Image). Out-of-bounds pixels are filled by
            reflecting the border.
        """
        img_np = np.array(img)
        if output_size is None:
            output_size = (img_np.shape[1], img_np.shape[0])
        # cv2 expects a floating-point 3x3 matrix; accept lists/tensors too.
        matrix = np.asarray(homography, dtype=np.float64)
        warped_np = cv2.warpPerspective(
            img_np,
            matrix,
            output_size,
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT,
        )
        return Image.fromarray(warped_np)

    def visualize_alignment(
        self,
        google_img: Image.Image,
        yandex_img: Image.Image,
        homography: np.ndarray,
        save_path: Optional[str] = None,
        show: bool = True,
    ):
        """
        Visualize alignment between images using homography.

        Draws a 2x2 grid: original Google, original Yandex, Yandex warped into
        the Google frame, and a 50/50 overlay of Google and warped Yandex.

        Args:
            google_img: Google map image.
            yandex_img: Yandex map image.
            homography: Homography matrix mapping Yandex onto Google.
            save_path: Path to save visualization (optional).
            show: Whether to display the visualization; if False the figure
                is closed after saving.
        """
        # Work in RGB so both images share a channel layout; palette,
        # grayscale or RGBA inputs would otherwise break the overlay.
        if google_img.mode != "RGB":
            google_img = google_img.convert("RGB")
        if yandex_img.mode != "RGB":
            yandex_img = yandex_img.convert("RGB")
        google_np = np.array(google_img)
        yandex_np = np.array(yandex_img)

        # Warp Yandex into the Google frame (Google's width/height) so the
        # overlay operands have identical shapes even for differently sized
        # inputs.
        # NOTE(review): the corner metric in evaluate_on_dataset uses [-1, 1]
        # coordinates, while warping here happens in raw pixel coordinates.
        # Confirm the model's homography convention matches.
        yandex_warped = self.warp_image(
            yandex_img,
            homography,
            output_size=(google_np.shape[1], google_np.shape[0]),
        )
        yandex_warped_np = np.array(yandex_warped)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        # Original images
        axes[0, 0].imshow(google_np)
        axes[0, 0].set_title("Google Map (Original)")
        axes[0, 0].axis("off")
        axes[0, 1].imshow(yandex_np)
        axes[0, 1].set_title("Yandex Map (Original)")
        axes[0, 1].axis("off")
        # Warped image
        axes[1, 0].imshow(yandex_warped_np)
        axes[1, 0].set_title("Yandex Map (Warped)")
        axes[1, 0].axis("off")
        # Overlay (50% transparency)
        overlay = cv2.addWeighted(google_np, 0.5, yandex_warped_np, 0.5, 0)
        axes[1, 1].imshow(overlay)
        axes[1, 1].set_title("Overlay (Google + Warped Yandex)")
        axes[1, 1].axis("off")
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"Visualization saved to {save_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)

    @staticmethod
    def _to_normalized_numpy(homography) -> np.ndarray:
        """Return *homography* as a float64 array scaled so H[2, 2] == 1 when nonzero."""
        matrix = np.asarray(homography, dtype=np.float64)
        scale = matrix[2, 2]
        # Homographies are defined up to scale; match the normalization that
        # predict(normalize=True) applies to the prediction.
        return matrix / scale if scale != 0 else matrix

    @classmethod
    def _corner_error(cls, pred: np.ndarray, target: np.ndarray) -> float:
        """Mean Euclidean distance between the [-1, 1] corners mapped by both matrices."""
        pred_pts = pred @ cls._CORNERS
        pred_pts = pred_pts / (pred_pts[2:3, :] + 1e-8)
        target_pts = target @ cls._CORNERS
        target_pts = target_pts / (target_pts[2:3, :] + 1e-8)
        return float(
            np.mean(np.linalg.norm(pred_pts[:2, :] - target_pts[:2, :], axis=0))
        )

    def evaluate_on_dataset(
        self,
        dataset_dir: str,
        num_samples: Optional[int] = None,
        save_dir: Optional[str] = None,
    ) -> Dict[str, float]:
        """
        Evaluate model on a dataset.

        Args:
            dataset_dir: Directory containing image pairs.
            num_samples: Number of leading samples to evaluate (None for all).
            save_dir: Directory to save per-sample visualizations (optional).

        Returns:
            Dictionary with mean/std of the matrix MSE and
            mean/std/median/max/min of the corner error.

        Raises:
            ValueError: If the selection contains no samples.
        """
        dataset = HomographyDataset(
            root_dir=dataset_dir,
            transform=None,  # We'll handle transforms manually
            augment=False,
            image_size=tuple(self.config.get("image_size", [256, 256])),
            cache_homographies=False,
        )
        total = len(dataset)
        count = total if num_samples is None else min(num_samples, total)
        indices = list(range(count))  # a negative count yields an empty range
        if not indices:
            raise ValueError(
                f"No samples to evaluate in {dataset_dir} (num_samples={num_samples})"
            )
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        errors: List[float] = []
        corner_errors: List[float] = []
        print(f"Evaluating on {len(indices)} samples...")
        for idx in indices:
            # Get sample without augmentation
            sample = dataset.get_sample_without_augmentation(idx)
            google_img = sample["google_img"]
            yandex_img = sample["yandex_img"]

            pred_homography_np = (
                self.predict(google_img, yandex_img, return_matrix=True, normalize=True)
                .cpu()
                .numpy()
            )
            # NOTE(review): assumes sample["homography"] is a 3x3 array-like in
            # the same coordinate convention as the model output — confirm.
            target_homography_np = self._to_normalized_numpy(sample["homography"])

            errors.append(
                float(np.mean((pred_homography_np - target_homography_np) ** 2))
            )
            corner_errors.append(
                self._corner_error(pred_homography_np, target_homography_np)
            )

            if save_dir:
                vis_path = os.path.join(save_dir, f"sample_{idx:04d}.png")
                self.visualize_alignment(
                    google_img,
                    yandex_img,
                    pred_homography_np,
                    save_path=vis_path,
                    show=False,
                )

        metrics = {
            "mean_matrix_error": float(np.mean(errors)),
            "std_matrix_error": float(np.std(errors)),
            "mean_corner_error": float(np.mean(corner_errors)),
            "std_corner_error": float(np.std(corner_errors)),
            "median_corner_error": float(np.median(corner_errors)),
            "max_corner_error": float(np.max(corner_errors)),
            "min_corner_error": float(np.min(corner_errors)),
        }
        print("\nEvaluation Results:")
        for key, value in metrics.items():
            print(f" {key}: {value:.6f}")
        return metrics
def main():
    """Main inference function: parse CLI arguments and run the selected mode.

    Modes:
        single  -- predict the homography for one Google/Yandex image pair,
                   print it, and optionally save/show an alignment figure.
        dataset -- evaluate on a directory of image pairs; optionally save
                   per-sample visualizations and a JSON metrics report.
        batch   -- placeholder, not implemented yet.
    """
    parser = argparse.ArgumentParser(description="Inference for homography estimation")
    # Model arguments
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to trained model checkpoint",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to model configuration file (optional)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cuda", "cpu"],
        help="Device to run inference on",
    )
    # Inference mode
    parser.add_argument(
        "--mode",
        type=str,
        default="single",
        choices=["single", "dataset", "batch"],
        help="Inference mode",
    )
    # Single image mode
    parser.add_argument(
        "--google_path",
        type=str,
        help="Path to Google map image (single mode)",
    )
    parser.add_argument(
        "--yandex_path",
        type=str,
        help="Path to Yandex map image (single mode)",
    )
    parser.add_argument(
        "--output_vis",
        type=str,
        help="Path to save visualization (single mode)",
    )
    # Dataset mode
    # NOTE(review): the default is a machine-specific absolute path; consider
    # passing --dataset_dir explicitly on other machines.
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        help="Directory containing image pairs (dataset mode)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        help="Number of samples to evaluate (dataset mode)",
    )
    parser.add_argument(
        "--save_vis_dir",
        type=str,
        help="Directory to save visualizations (dataset mode)",
    )
    parser.add_argument(
        "--save_results",
        type=str,
        help="Path to save evaluation results (dataset mode)",
    )
    args = parser.parse_args()

    # Validate mode-specific arguments before the (potentially slow) model load.
    if args.mode == "single" and not (args.google_path and args.yandex_path):
        parser.error("Both --google_path and --yandex_path are required for single mode")

    # Create inference object
    inference = HomographyInference(
        model_path=args.model_path,
        config_path=args.config_path,
        device=args.device,
    )

    if args.mode == "single":
        # Single image pair inference
        print("Processing single image pair:")
        print(f" Google: {args.google_path}")
        print(f" Yandex: {args.yandex_path}")
        homography_np = (
            inference.predict_from_paths(args.google_path, args.yandex_path)
            .cpu()
            .numpy()
        )
        print("\nPredicted homography matrix:")
        print(homography_np)
        # Visualize alignment
        if args.output_vis:
            # convert() returns a fully loaded copy, so the files can be
            # closed right away.
            with Image.open(args.google_path) as img:
                google_img = img.convert("RGB")
            with Image.open(args.yandex_path) as img:
                yandex_img = img.convert("RGB")
            inference.visualize_alignment(
                google_img,
                yandex_img,
                homography_np,
                save_path=args.output_vis,
                show=True,
            )
    elif args.mode == "dataset":
        # Evaluate on dataset
        metrics = inference.evaluate_on_dataset(
            dataset_dir=args.dataset_dir,
            num_samples=args.num_samples,
            save_dir=args.save_vis_dir,
        )
        # Save results if requested
        if args.save_results:
            with open(args.save_results, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2)
            print(f"\nResults saved to {args.save_results}")
    elif args.mode == "batch":
        # Batch processing (placeholder for future implementation), e.g.
        # processing multiple image pairs from a directory.
        print("Batch mode not yet implemented")
    else:
        # Unreachable via argparse choices; kept as a defensive guard.
        raise ValueError(f"Unknown mode: {args.mode}")


if __name__ == "__main__":
    main()