"""
Inference script for homography estimation between Google and Yandex map images.

This script loads a trained homography estimation model and performs inference
on new image pairs or the test dataset.
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from homography import HomographyDataset
|
|
from homography_cnn import HomographyCNN, create_homography_model
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
|
|
|
|
class HomographyInference:
    """Run inference with a trained homography estimation model.

    An instance owns the model (moved to ``self.device`` and put in eval mode),
    the configuration dictionary used to build it, and the image transform
    that is applied to both input images before they are fed to the network.
    """

    # Homogeneous points (one per column) at which the corner error is measured.
    _UNIT_CORNERS = np.array(
        [
            [-1, -1, 1],
            [1, -1, 1],
            [1, 1, 1],
            [-1, 1, 1],
        ],
        dtype=np.float64,
    ).T

    def __init__(
        self,
        model_path: str,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """
        Initialize the inference class.

        Args:
            model_path: Path to trained model checkpoint
            config_path: Path to model configuration file (optional). When
                omitted, ``config.json`` next to the checkpoint is tried.
            device: Device to run inference on ('cuda' or 'cpu'). When omitted,
                CUDA is used if available, otherwise the CPU.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        self.config = self._load_config(model_path, config_path)

        # Build the network first, then populate it with the stored weights.
        self.model = self._create_model()
        self._load_model(model_path)

        self.transform = self._create_transforms()

        # Inference only: disable dropout and freeze batch-norm statistics.
        self.model.eval()

    @staticmethod
    def _load_config(model_path: str, config_path: Optional[str]) -> Dict:
        """Read the JSON configuration, or return the built-in defaults."""
        explicit = config_path is not None
        if not explicit:
            # Try to find the config in the same directory as the model.
            config_path = Path(model_path).parent / "config.json"

        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            print(f"Loaded configuration from {config_path}")
            return config

        if explicit:
            # The caller asked for a specific file; do not fall back silently.
            print(f"Warning: configuration file not found: {config_path}")
        print("Using default configuration")
        return {
            "image_size": [256, 256],
            "hidden_channels": 64,
            "num_blocks": 4,
            "dropout_rate": 0.3,
            "use_batch_norm": True,
        }

    def _create_model(self) -> HomographyCNN:
        """Create model based on configuration."""
        image_size = self.config.get("image_size", [256, 256])

        return create_homography_model(
            model_type="cnn",
            input_size=tuple(image_size),
            input_channels=3,
            hidden_channels=self.config.get("hidden_channels", 64),
            num_blocks=self.config.get("num_blocks", 4),
            dropout_rate=self.config.get("dropout_rate", 0.3),
            use_batch_norm=self.config.get("use_batch_norm", True),
        )

    def _load_model(self, model_path: str):
        """
        Load model weights from checkpoint.

        Accepts either a dict holding the weights under ``"model_state_dict"``
        or a bare state dict.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)

        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            # Trainer checkpoint format
            state_dict = checkpoint["model_state_dict"]
        else:
            # Raw model weights format
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)

        self.model = self.model.to(self.device)
        print(f"Loaded model from {model_path}")

    def _create_transforms(self):
        """Create image transforms for inference (resize, to-tensor, normalize)."""
        image_size = tuple(self.config.get("image_size", [256, 256]))
        return transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def preprocess_images(
        self, google_img: Image.Image, yandex_img: Image.Image
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess images for inference.

        Args:
            google_img: Google map image (PIL Image)
            yandex_img: Yandex map image (PIL Image)

        Returns:
            Tuple of preprocessed image tensors, each with a leading batch
            dimension of size 1
        """
        # Convert to RGB if needed
        if google_img.mode != "RGB":
            google_img = google_img.convert("RGB")
        if yandex_img.mode != "RGB":
            yandex_img = yandex_img.convert("RGB")

        # Apply transforms and add the batch dimension
        google_tensor = self.transform(google_img).unsqueeze(0)
        yandex_tensor = self.transform(yandex_img).unsqueeze(0)

        return google_tensor, yandex_tensor

    def predict(
        self,
        google_img: Image.Image,
        yandex_img: Image.Image,
        return_matrix: bool = True,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix for image pair.

        Args:
            google_img: Google map image (PIL Image)
            yandex_img: Yandex map image (PIL Image)
            return_matrix: If True, return 3x3 matrix; if False, return flattened vector
            normalize: Whether to normalize the homography matrix so that its
                bottom-right element is 1. Only applied when ``return_matrix``
                is True. A matrix whose bottom-right element is (numerically)
                zero cannot be normalized this way and is returned unscaled
                instead of being turned into inf/NaN.

        Returns:
            Predicted homography matrix or vector, without the batch dimension
        """
        google_tensor, yandex_tensor = self.preprocess_images(google_img, yandex_img)

        google_tensor = google_tensor.to(self.device)
        yandex_tensor = yandex_tensor.to(self.device)

        with torch.no_grad():
            homography = self.model(
                google_tensor, yandex_tensor, return_matrix=return_matrix
            )

            if return_matrix and normalize:
                scale = homography[:, 2, 2]
                # Guard against division by zero for degenerate predictions.
                safe_scale = torch.where(
                    scale.abs() < 1e-12, torch.ones_like(scale), scale
                )
                homography = homography / safe_scale.view(-1, 1, 1)

        return homography.squeeze(0)  # Remove batch dimension

    def predict_from_paths(
        self,
        google_path: str,
        yandex_path: str,
        return_matrix: bool = True,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Predict homography matrix from image file paths.

        Args:
            google_path: Path to Google map image
            yandex_path: Path to Yandex map image
            return_matrix: If True, return 3x3 matrix; if False, return flattened vector
            normalize: Whether to normalize the homography matrix

        Returns:
            Predicted homography matrix or vector
        """
        # Image.open keeps the file handle open until the image is closed, so
        # both images are used inside context managers to release them promptly.
        with Image.open(google_path) as google_img, Image.open(
            yandex_path
        ) as yandex_img:
            return self.predict(google_img, yandex_img, return_matrix, normalize)

    def warp_image(
        self,
        img: Image.Image,
        homography: np.ndarray,
        output_size: Optional[Tuple[int, int]] = None,
    ) -> Image.Image:
        """
        Warp image using homography matrix.

        Args:
            img: Input image (PIL Image)
            homography: 3x3 homography matrix (numpy array)
            output_size: Output image size (width, height). If None, uses input size.

        Returns:
            Warped image (PIL Image)
        """
        # NOTE(review): cv2.warpPerspective applies the matrix in pixel
        # coordinates, while evaluate_on_dataset measures corner error at
        # (+-1, +-1). If the model predicts in normalised coordinates the
        # matrix would need rescaling before warping - confirm the convention.
        img_np = np.array(img)

        if output_size is None:
            output_size = (img_np.shape[1], img_np.shape[0])

        # OpenCV needs a floating-point 3x3 matrix; accept any array-like.
        matrix = np.asarray(homography, dtype=np.float64)

        warped_np = cv2.warpPerspective(
            img_np,
            matrix,
            output_size,
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT,
        )

        return Image.fromarray(warped_np)

    def visualize_alignment(
        self,
        google_img: Image.Image,
        yandex_img: Image.Image,
        homography: np.ndarray,
        save_path: Optional[str] = None,
        show: bool = True,
    ):
        """
        Visualize alignment between images using homography.

        Draws a 2x2 figure: both originals, the warped Yandex image, and a
        50/50 overlay of the Google image with the warped Yandex image.

        Args:
            google_img: Google map image
            yandex_img: Yandex map image
            homography: Homography matrix
            save_path: Path to save visualization (optional)
            show: Whether to display the visualization
        """
        # Blending requires identical shape and channel count, so both images
        # are brought to RGB (the same conversion the model input receives).
        if google_img.mode != "RGB":
            google_img = google_img.convert("RGB")
        if yandex_img.mode != "RGB":
            yandex_img = yandex_img.convert("RGB")

        # Warp yandex image onto the google image's canvas so that the two can
        # be overlaid even when the source images differ in size.
        yandex_warped = self.warp_image(
            yandex_img, homography, output_size=google_img.size
        )

        google_np = np.array(google_img)
        yandex_np = np.array(yandex_img)
        yandex_warped_np = np.array(yandex_warped)

        # Overlay (50% transparency)
        overlay = cv2.addWeighted(google_np, 0.5, yandex_warped_np, 0.5, 0)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        panels = (
            (axes[0, 0], google_np, "Google Map (Original)"),
            (axes[0, 1], yandex_np, "Yandex Map (Original)"),
            (axes[1, 0], yandex_warped_np, "Yandex Map (Warped)"),
            (axes[1, 1], overlay, "Overlay (Google + Warped Yandex)"),
        )
        for ax, image, title in panels:
            ax.imshow(image)
            ax.set_title(title)
            ax.axis("off")

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"Visualization saved to {save_path}")

        if show:
            plt.show()
        else:
            # Release the figure explicitly; matters when called in a loop.
            plt.close(fig)

    @staticmethod
    def _as_matrix(homography) -> np.ndarray:
        """Convert a tensor or array-like homography to a 3x3 float64 array."""
        if isinstance(homography, torch.Tensor):
            homography = homography.detach().cpu().numpy()
        return np.asarray(homography, dtype=np.float64).reshape(3, 3)

    @classmethod
    def _project_corners(cls, homography: np.ndarray) -> np.ndarray:
        """Map the reference corners through ``homography``; returns a 2x4 array."""
        projected = homography @ cls._UNIT_CORNERS
        # Small epsilon avoids division by an exactly-zero homogeneous coordinate.
        return projected[:2, :] / (projected[2:3, :] + 1e-8)

    @classmethod
    def _mean_corner_error(cls, predicted: np.ndarray, target: np.ndarray) -> float:
        """Mean Euclidean distance between corners projected by both matrices."""
        difference = cls._project_corners(predicted) - cls._project_corners(target)
        return float(np.mean(np.linalg.norm(difference, axis=0)))

    def evaluate_on_dataset(
        self,
        dataset_dir: str,
        num_samples: Optional[int] = None,
        save_dir: Optional[str] = None,
    ) -> Dict[str, float]:
        """
        Evaluate model on a dataset.

        Args:
            dataset_dir: Directory containing image pairs
            num_samples: Number of samples to evaluate (None for all). The
                first ``num_samples`` items of the dataset are used.
            save_dir: Directory to save visualizations (optional)

        Returns:
            Dictionary of evaluation metrics (mean/std of the matrix MSE and
            mean/std/median/max/min of the corner error)

        Raises:
            ValueError: If there is nothing to evaluate (empty dataset or
                ``num_samples`` <= 0).
        """
        dataset = HomographyDataset(
            root_dir=dataset_dir,
            transform=None,  # We'll handle transforms manually
            augment=False,
            image_size=tuple(self.config.get("image_size", [256, 256])),
            cache_homographies=False,
        )

        total = len(dataset)
        count = total if num_samples is None else max(0, min(num_samples, total))
        if count == 0:
            # Without this the statistics below fail with an obscure numpy error.
            raise ValueError(
                f"No samples to evaluate in {dataset_dir!r} "
                f"(dataset size: {total}, num_samples: {num_samples})"
            )

        print(f"Evaluating on {count} samples...")

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        matrix_errors = []
        corner_errors = []

        for idx in range(count):
            sample = dataset.get_sample_without_augmentation(idx)
            google_img = sample["google_img"]
            yandex_img = sample["yandex_img"]

            prediction = self._as_matrix(
                self.predict(google_img, yandex_img, return_matrix=True, normalize=True)
            )
            # The ground truth may arrive as a tensor or an array; normalise it
            # to the same 3x3 numpy form as the prediction before comparing.
            target = self._as_matrix(sample["homography"])

            matrix_errors.append(np.mean((prediction - target) ** 2))
            corner_errors.append(self._mean_corner_error(prediction, target))

            if save_dir:
                vis_path = os.path.join(save_dir, f"sample_{idx:04d}.png")
                self.visualize_alignment(
                    google_img,
                    yandex_img,
                    prediction,
                    save_path=vis_path,
                    show=False,
                )

        metrics = {
            "mean_matrix_error": float(np.mean(matrix_errors)),
            "std_matrix_error": float(np.std(matrix_errors)),
            "mean_corner_error": float(np.mean(corner_errors)),
            "std_corner_error": float(np.std(corner_errors)),
            "median_corner_error": float(np.median(corner_errors)),
            "max_corner_error": float(np.max(corner_errors)),
            "min_corner_error": float(np.min(corner_errors)),
        }

        print("\nEvaluation Results:")
        for key, value in metrics.items():
            print(f" {key}: {value:.6f}")

        return metrics
|
|
|
|
|
|
def main():
    """Parse command-line arguments and run the selected inference mode.

    Modes:
        single:  predict the homography for one Google/Yandex image pair and
                 optionally save/show an alignment visualization.
        dataset: evaluate the model on a dataset directory and optionally save
                 the metrics as JSON.
        batch:   placeholder, not implemented yet.

    Raises:
        ValueError: In single mode when --google_path or --yandex_path is missing.
    """
    parser = argparse.ArgumentParser(description="Inference for homography estimation")

    # Model arguments
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to trained model checkpoint",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to model configuration file (optional)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cuda", "cpu"],
        help="Device to run inference on",
    )

    # Inference mode
    parser.add_argument(
        "--mode",
        type=str,
        default="single",
        choices=["single", "dataset", "batch"],
        help="Inference mode",
    )

    # Single image mode
    parser.add_argument(
        "--google_path",
        type=str,
        help="Path to Google map image (single mode)",
    )
    parser.add_argument(
        "--yandex_path",
        type=str,
        help="Path to Yandex map image (single mode)",
    )
    parser.add_argument(
        "--output_vis",
        type=str,
        help="Path to save visualization (single mode)",
    )

    # Dataset mode
    # NOTE(review): the default below is a machine-specific path; pass
    # --dataset_dir explicitly on any other machine.
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        help="Directory containing image pairs (dataset mode)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        help="Number of samples to evaluate (dataset mode)",
    )
    parser.add_argument(
        "--save_vis_dir",
        type=str,
        help="Directory to save visualizations (dataset mode)",
    )
    parser.add_argument(
        "--save_results",
        type=str,
        help="Path to save evaluation results (dataset mode)",
    )

    args = parser.parse_args()

    # Fail fast on missing inputs before paying for model construction/loading.
    if args.mode == "single" and not (args.google_path and args.yandex_path):
        raise ValueError(
            "Both --google_path and --yandex_path are required for single mode"
        )

    inference = HomographyInference(
        model_path=args.model_path,
        config_path=args.config_path,
        device=args.device,
    )

    if args.mode == "single":
        print("Processing single image pair:")
        print(f" Google: {args.google_path}")
        print(f" Yandex: {args.yandex_path}")

        homography = inference.predict_from_paths(args.google_path, args.yandex_path)
        homography_np = homography.cpu().numpy()

        print("\nPredicted homography matrix:")
        print(homography_np)

        if args.output_vis:
            # Context managers release the image files once the figure is done.
            with Image.open(args.google_path) as google_img, Image.open(
                args.yandex_path
            ) as yandex_img:
                inference.visualize_alignment(
                    google_img,
                    yandex_img,
                    homography_np,
                    save_path=args.output_vis,
                    show=True,
                )

    elif args.mode == "dataset":
        metrics = inference.evaluate_on_dataset(
            dataset_dir=args.dataset_dir,
            num_samples=args.num_samples,
            save_dir=args.save_vis_dir,
        )

        if args.save_results:
            with open(args.save_results, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2)
            print(f"\nResults saved to {args.save_results}")

    elif args.mode == "batch":
        # Batch processing (placeholder for future implementation); could
        # process multiple image pairs from a directory.
        print("Batch mode not yet implemented")

    # No final else: argparse `choices` already rejects any other mode.


if __name__ == "__main__":
    main()