# File-viewer header (extraction residue, not part of the script):
# 147 lines, 4.0 KiB, Python

"""
Script for predicting similarity between two images.
"""
import argparse
import os
from pathlib import Path
import torch
from model import SimilarityCNN
from PIL import Image
from torchvision import transforms
def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
    """Load an image from disk and preprocess it into a batched tensor.

    The image is converted to RGB, resized, turned into a tensor and
    normalized per channel, then given a leading batch dimension of size 1.

    Args:
        image_path: Path of the image file to open.
        image_size: Target size handed to ``transforms.Resize``. The default
            is a ``(256, 256)`` pair.

    Returns:
        The preprocessed image tensor with a batch dimension prepended.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            # Per-channel mean/std (the commonly used ImageNet statistics).
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # Image.open is lazy and keeps the file handle open; the context manager
    # guarantees it is released once the pixel data has been copied by
    # convert(), instead of relying on garbage collection.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    return transform(image).unsqueeze(0)  # Add batch dimension
def predict_similarity(
    model: SimilarityCNN,
    image1_path: str,
    image2_path: str,
    device: torch.device,
    image_size: tuple = (256, 256),
) -> float:
    """Score how similar two image files are according to ``model``.

    Args:
        model: Network called as ``model(img1, img2)``. It is switched to
            evaluation mode via ``model.eval()`` as a side effect.
        image1_path: Path to the first image.
        image2_path: Path to the second image.
        device: Device the image tensors are moved to before the forward pass.
        image_size: Size forwarded to ``load_image`` for both images.

    Returns:
        The model output converted to a Python scalar with ``.item()``
        (assumes the model yields a single-element tensor -- TODO confirm).
    """
    model.eval()
    # Both inputs go through the same preprocessing, first image first.
    first, second = (
        load_image(path, image_size).to(device)
        for path in (image1_path, image2_path)
    )
    # Gradients are not needed for inference.
    with torch.no_grad():
        score = model(first, second).item()
    return score
def load_model(
    checkpoint_path: str,
    device: torch.device,
    **model_kwargs,
) -> SimilarityCNN:
    """Build a ``SimilarityCNN`` and restore its weights from a checkpoint.

    Args:
        checkpoint_path: File passed to ``torch.load``. The loaded object
            must provide a ``"model_state_dict"`` entry.
        device: Device the model is moved to; also used as ``map_location``
            when reading the checkpoint.
        **model_kwargs: Forwarded unchanged to the ``SimilarityCNN``
            constructor.

    Returns:
        The model on ``device`` with the checkpoint weights loaded.
    """
    net = SimilarityCNN(**model_kwargs)
    net = net.to(device)
    # NOTE(review): depending on the installed PyTorch version, torch.load may
    # unpickle arbitrary objects -- only load checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state["model_state_dict"])
    return net
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the similarity prediction script."""
    parser = argparse.ArgumentParser(
        description="Predict similarity between two images"
    )
    parser.add_argument(
        "--image1", type=str, required=True, help="Path to first image"
    )
    parser.add_argument(
        "--image2", type=str, required=True, help="Path to second image"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="runs/similarity/checkpoints/best_model.pt",
        help="Path to model checkpoint",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for inference",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Image size for model input",
    )
    return parser


def main():
    """Parse CLI arguments, load the model and print the similarity score.

    Returns:
        The similarity score as a float, or ``None`` when one of the input
        images is not an existing file (an error is printed to stderr).
    """
    import sys  # local import: only needed for stderr diagnostics

    args = _build_arg_parser().parse_args()
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # isfile (not exists): a directory path must be rejected here instead of
    # failing later inside Image.open.
    for image_path in (args.image1, args.image2):
        if not os.path.isfile(image_path):
            print(f"Error: Image not found: {image_path}", file=sys.stderr)
            return None

    # Single source of truth for the architecture, shared by both the
    # checkpoint and the random-initialization paths.
    model_config = {
        "input_channels": 3,
        "hidden_channels": 64,
        "num_blocks": 4,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
    }

    if not os.path.exists(args.checkpoint):
        print(f"Warning: Checkpoint not found: {args.checkpoint}", file=sys.stderr)
        print("Using randomly initialized model for demonstration", file=sys.stderr)
        model = SimilarityCNN(**model_config).to(device)
    else:
        print(f"Loading model from: {args.checkpoint}")
        model = load_model(
            checkpoint_path=args.checkpoint,
            device=device,
            **model_config,
        )

    print(
        f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    similarity = predict_similarity(
        model=model,
        image1_path=args.image1,
        image2_path=args.image2,
        device=device,
        image_size=(args.image_size, args.image_size),
    )

    print("\nSimilarity between images:")
    print(f" Image 1: {args.image1}")
    print(f" Image 2: {args.image2}")
    print(f" Similarity score: {similarity:.4f}")
    # Scores above 0.5 are reported as "Similar", everything else as "Different".
    print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}")
    return similarity
if __name__ == "__main__":
    # main() returns None when it bails out on a missing input image; map
    # that to a non-zero exit status so shell callers can detect the failure.
    raise SystemExit(0 if main() is not None else 1)