feat: add similarity model
This commit is contained in:
146
models/SiaN-similarity/predict.py
Normal file
146
models/SiaN-similarity/predict.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Script for predicting similarity between two images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from model import SimilarityCNN
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
|
||||
"""Load and preprocess image."""
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(image_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
return transform(image).unsqueeze(0) # Add batch dimension
|
||||
|
||||
|
||||
def predict_similarity(
|
||||
model: SimilarityCNN,
|
||||
image1_path: str,
|
||||
image2_path: str,
|
||||
device: torch.device,
|
||||
image_size: tuple = (256, 256),
|
||||
) -> float:
|
||||
"""Predict similarity between two images."""
|
||||
model.eval()
|
||||
|
||||
img1 = load_image(image1_path, image_size).to(device)
|
||||
img2 = load_image(image2_path, image_size).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
similarity = model(img1, img2)
|
||||
|
||||
return similarity.item()
|
||||
|
||||
|
||||
def load_model(
|
||||
checkpoint_path: str,
|
||||
device: torch.device,
|
||||
**model_kwargs,
|
||||
) -> SimilarityCNN:
|
||||
"""Load model from checkpoint."""
|
||||
model = SimilarityCNN(**model_kwargs).to(device)
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Predict similarity between two images"
|
||||
)
|
||||
parser.add_argument("--image1", type=str, required=True, help="Path to first image")
|
||||
parser.add_argument(
|
||||
"--image2", type=str, required=True, help="Path to second image"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
default="runs/similarity/checkpoints/best_model.pt",
|
||||
help="Path to model checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device to use for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Image size for model input",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
device = torch.device(args.device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
if not os.path.exists(args.image1):
|
||||
print(f"Error: Image not found: {args.image1}")
|
||||
return
|
||||
|
||||
if not os.path.exists(args.image2):
|
||||
print(f"Error: Image not found: {args.image2}")
|
||||
return
|
||||
|
||||
if not os.path.exists(args.checkpoint):
|
||||
print(f"Warning: Checkpoint not found: {args.checkpoint}")
|
||||
print("Using randomly initialized model for demonstration")
|
||||
model = SimilarityCNN(
|
||||
input_channels=3,
|
||||
hidden_channels=64,
|
||||
num_blocks=4,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
).to(device)
|
||||
else:
|
||||
print(f"Loading model from: {args.checkpoint}")
|
||||
model = load_model(
|
||||
checkpoint_path=args.checkpoint,
|
||||
device=device,
|
||||
input_channels=3,
|
||||
hidden_channels=64,
|
||||
num_blocks=4,
|
||||
dropout_rate=0.3,
|
||||
use_batch_norm=True,
|
||||
)
|
||||
|
||||
print(
|
||||
f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||
)
|
||||
|
||||
similarity = predict_similarity(
|
||||
model=model,
|
||||
image1_path=args.image1,
|
||||
image2_path=args.image2,
|
||||
device=device,
|
||||
image_size=(args.image_size, args.image_size),
|
||||
)
|
||||
|
||||
print(f"\nSimilarity between images:")
|
||||
print(f" Image 1: {args.image1}")
|
||||
print(f" Image 2: {args.image2}")
|
||||
print(f" Similarity score: {similarity:.4f}")
|
||||
print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}")
|
||||
|
||||
return similarity
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user