147 lines
4.0 KiB
Python
147 lines
4.0 KiB
Python
"""
|
|
Script for predicting similarity between two images.
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from model import SimilarityCNN
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
|
|
|
|
def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
    """Load an image from disk and preprocess it into a single-image batch.

    The image is converted to RGB, resized, turned into a float tensor and
    normalised per channel with fixed mean/std values (the commonly used
    ImageNet statistics).

    Args:
        image_path: Path to the image file.
        image_size: Size passed to ``transforms.Resize``. A ``(height, width)``
            tuple yields exactly that output size; an int resizes the shorter
            edge and keeps the aspect ratio.

    Returns:
        A tensor of shape ``(1, 3, H, W)``, i.e. a batch containing one image.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist (from ``Image.open``).
        PIL.UnidentifiedImageError: If the file cannot be decoded as an image.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Image.open is lazy and keeps the file handle open until the image is
    # closed. The context manager releases it deterministically. convert()
    # loads the pixel data and returns a new image, so the result stays valid
    # after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    return transform(image).unsqueeze(0)  # Add batch dimension
|
|
|
|
|
|
def predict_similarity(
    model: SimilarityCNN,
    image1_path: str,
    image2_path: str,
    device: torch.device,
    image_size: tuple = (256, 256),
) -> float:
    """Run ``model`` on a pair of images and return its output as a Python number.

    This switches ``model`` into evaluation mode (``model.eval()``) and leaves
    it there.

    Args:
        model: Two-input network, called as ``model(img1, img2)``.
        image1_path: Path to the first image.
        image2_path: Path to the second image.
        device: Device the preprocessed image tensors are moved to; it should
            be the device the model lives on.
        image_size: Resize target forwarded to ``load_image``.

    Returns:
        ``model(img1, img2).item()``. Assumes the model returns a
        single-element tensor (the similarity score) — TODO confirm against
        SimilarityCNN.
    """
    model.eval()

    # Preprocess the two images (first, then second) and move them to device.
    img_a, img_b = (
        load_image(path, image_size).to(device)
        for path in (image1_path, image2_path)
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(img_a, img_b)

    return output.item()
|
|
|
|
|
|
def load_model(
    checkpoint_path: str,
    device: torch.device,
    **model_kwargs,
) -> SimilarityCNN:
    """Build a ``SimilarityCNN`` and load trained weights from a checkpoint.

    Args:
        checkpoint_path: File written by ``torch.save``. It may be either a dict
            that stores the weights under ``"model_state_dict"`` (the layout
            this script was written for) or a bare ``state_dict``.
        device: Device the model is created on. Checkpoint tensors are mapped
            onto it via ``map_location``.
        **model_kwargs: Forwarded unchanged to ``SimilarityCNN``. They must
            describe the same architecture the checkpoint was trained with.

    Returns:
        The model with weights loaded. It is *not* switched to eval mode here.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the stored
            weights do not match the architecture built from ``model_kwargs``.
    """
    model = SimilarityCNN(**model_kwargs).to(device)

    # SECURITY NOTE(review): torch.load unpickles the file and can execute
    # arbitrary code from an untrusted checkpoint. Only load trusted files.
    # Consider weights_only=True once the checkpoint contents are confirmed to
    # be tensors/plain containers only.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Accept both a training checkpoint ({"model_state_dict": ..., ...}) and a
    # bare state dict saved via torch.save(model.state_dict(), path).
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    return model
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description="Predict similarity between two images"
    )
    parser.add_argument("--image1", type=str, required=True, help="Path to first image")
    parser.add_argument(
        "--image2", type=str, required=True, help="Path to second image"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="runs/similarity/checkpoints/best_model.pt",
        help="Path to model checkpoint",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for inference",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Image size for model input",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Scores above this value are reported as 'Similar' (default: 0.5)",
    )
    return parser


def _build_model(checkpoint_path: str, device: torch.device) -> SimilarityCNN:
    """Load the model from ``checkpoint_path``, or fall back to random init if it is missing."""
    import sys

    # Architecture hyper-parameters, defined once so the random-init and
    # checkpoint paths cannot drift apart. They must match the trained model.
    model_kwargs = dict(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )

    if not os.path.exists(checkpoint_path):
        # Diagnostics go to stderr so stdout carries only the results.
        print(f"Warning: Checkpoint not found: {checkpoint_path}", file=sys.stderr)
        print("Using randomly initialized model for demonstration", file=sys.stderr)
        return SimilarityCNN(**model_kwargs).to(device)

    print(f"Loading model from: {checkpoint_path}")
    return load_model(checkpoint_path=checkpoint_path, device=device, **model_kwargs)


def main():
    """Command-line entry point: score the similarity of two images and print it.

    Invalid input is reported through ``parser.error``, which prints to stderr
    and exits with status 2. This covers a missing image file, a non-positive
    ``--image_size``, an unknown device string, or CUDA requested when it is
    unavailable. A missing checkpoint is not an error: a randomly initialised
    model is used for demonstration instead.

    Returns:
        The similarity score, so the function can also be used programmatically.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.image_size <= 0:
        parser.error(f"--image_size must be positive, got {args.image_size}")

    # Validate the device up front instead of failing later inside .to(device).
    try:
        device = torch.device(args.device)
    except RuntimeError as err:
        parser.error(f"invalid --device {args.device!r}: {err}")
    if device.type == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda was requested but CUDA is not available")
    print(f"Using device: {device}")

    # Fail with a non-zero exit status (previously the script printed an error
    # and still exited 0). isfile also rejects directories.
    for image_path in (args.image1, args.image2):
        if not os.path.isfile(image_path):
            parser.error(f"Image not found: {image_path}")

    model = _build_model(args.checkpoint, device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded with {num_params:,} parameters")

    similarity = predict_similarity(
        model=model,
        image1_path=args.image1,
        image2_path=args.image2,
        device=device,
        image_size=(args.image_size, args.image_size),
    )

    verdict = "Similar" if similarity > args.threshold else "Different"
    print("\nSimilarity between images:")
    print(f" Image 1: {args.image1}")
    print(f" Image 2: {args.image2}")
    print(f" Similarity score: {similarity:.4f}")
    print(f" Interpretation: {verdict}")

    return similarity
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|