""" Script for predicting similarity between two images. """ import argparse import os from pathlib import Path import torch from model import SimilarityCNN from PIL import Image from torchvision import transforms def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor: """Load and preprocess image.""" transform = transforms.Compose( [ transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) image = Image.open(image_path).convert("RGB") return transform(image).unsqueeze(0) # Add batch dimension def predict_similarity( model: SimilarityCNN, image1_path: str, image2_path: str, device: torch.device, image_size: tuple = (256, 256), ) -> float: """Predict similarity between two images.""" model.eval() img1 = load_image(image1_path, image_size).to(device) img2 = load_image(image2_path, image_size).to(device) with torch.no_grad(): similarity = model(img1, img2) return similarity.item() def load_model( checkpoint_path: str, device: torch.device, **model_kwargs, ) -> SimilarityCNN: """Load model from checkpoint.""" model = SimilarityCNN(**model_kwargs).to(device) checkpoint = torch.load(checkpoint_path, map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) return model def main(): parser = argparse.ArgumentParser( description="Predict similarity between two images" ) parser.add_argument("--image1", type=str, required=True, help="Path to first image") parser.add_argument( "--image2", type=str, required=True, help="Path to second image" ) parser.add_argument( "--checkpoint", type=str, default="runs/similarity/checkpoints/best_model.pt", help="Path to model checkpoint", ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for inference", ) parser.add_argument( "--image_size", type=int, default=256, help="Image size for model input", ) args = parser.parse_args() device = torch.device(args.device) print(f"Using device: {device}") if not os.path.exists(args.image1): print(f"Error: Image not found: {args.image1}") return if not os.path.exists(args.image2): print(f"Error: Image not found: {args.image2}") return if not os.path.exists(args.checkpoint): print(f"Warning: Checkpoint not found: {args.checkpoint}") print("Using randomly initialized model for demonstration") model = SimilarityCNN( input_channels=3, hidden_channels=64, num_blocks=4, dropout_rate=0.3, use_batch_norm=True, ).to(device) else: print(f"Loading model from: {args.checkpoint}") model = load_model( checkpoint_path=args.checkpoint, device=device, input_channels=3, hidden_channels=64, num_blocks=4, dropout_rate=0.3, use_batch_norm=True, ) print( f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters" ) similarity = predict_similarity( model=model, image1_path=args.image1, image2_path=args.image2, device=device, image_size=(args.image_size, args.image_size), ) print(f"\nSimilarity between images:") print(f" Image 1: {args.image1}") print(f" Image 2: {args.image2}") print(f" Similarity score: {similarity:.4f}") print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}") return similarity if __name__ == "__main__": main()