""" Evaluation and visualization for image similarity model. This file contains code for plotting training metrics, analyzing model performance, and testing the trained model. """ import os import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns import torch import torch.nn as nn from dataloader import config, create_data_loaders from model import create_similarity_model from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve from torch.utils.data import DataLoader from train import SimilarityTrainer # Set style for plots plt.style.use("seaborn-v0_8-darkgrid") sns.set_palette("husl") def plot_training_metrics(log_dir="runs/similarity"): """ Plot training and validation metrics from TensorBoard logs or saved metrics. Args: log_dir: Directory containing training logs """ # In a real scenario, we would read from TensorBoard logs # For this example, we'll create simulated data to show what plots would look like # Simulated training data (in reality, you would load this from logs) epochs = list(range(1, 51)) # Simulated metrics train_loss = [0.8 - 0.015 * i + np.random.normal(0, 0.02) for i in range(50)] val_loss = [0.75 - 0.012 * i + np.random.normal(0, 0.03) for i in range(50)] train_acc = [0.55 + 0.008 * i + np.random.normal(0, 0.01) for i in range(50)] val_acc = [0.6 + 0.006 * i + np.random.normal(0, 0.015) for i in range(50)] # Create figure with subplots fig, axes = plt.subplots(2, 2, figsize=(14, 10)) # Plot 1: Training and Validation Loss axes[0, 0].plot(epochs, train_loss, "b-", linewidth=2, label="Training Loss") axes[0, 0].plot(epochs, val_loss, "r-", linewidth=2, label="Validation Loss") axes[0, 0].set_xlabel("Epoch") axes[0, 0].set_ylabel("Loss") axes[0, 0].set_title("Training and Validation Loss") axes[0, 0].legend() axes[0, 0].grid(True, alpha=0.3) # Plot 2: Training and Validation Accuracy axes[0, 1].plot(epochs, train_acc, "b-", linewidth=2, label="Training Accuracy") axes[0, 1].plot(epochs, val_acc, "r-", linewidth=2, label="Validation Accuracy") axes[0, 1].set_xlabel("Epoch") axes[0, 1].set_ylabel("Accuracy") axes[0, 1].set_title("Training and Validation Accuracy") axes[0, 1].legend() axes[0, 1].grid(True, alpha=0.3) # Plot 3: Loss difference (train - val) loss_diff = [t - v for t, v in zip(train_loss, val_loss)] axes[1, 0].plot(epochs, loss_diff, "g-", linewidth=2) axes[1, 0].axhline(y=0, color="r", linestyle="--", alpha=0.5) axes[1, 0].fill_between( epochs, 0, loss_diff, where=np.array(loss_diff) > 0, alpha=0.3, color="red", label="Overfitting (train > val)", ) axes[1, 0].fill_between( epochs, 0, loss_diff, where=np.array(loss_diff) < 0, alpha=0.3, color="green", label="Underfitting (train < val)", ) axes[1, 0].set_xlabel("Epoch") axes[1, 0].set_ylabel("Loss Difference") axes[1, 0].set_title("Train Loss - Val Loss (Overfitting Indicator)") axes[1, 0].legend() axes[1, 0].grid(True, alpha=0.3) # Plot 4: Learning rate schedule (if available) axes[1, 1].plot( epochs, [0.0002 * (0.95**i) for i in range(50)], "purple-", linewidth=2 ) axes[1, 1].set_xlabel("Epoch") axes[1, 1].set_ylabel("Learning Rate") axes[1, 1].set_title("Learning Rate Schedule") axes[1, 1].set_yscale("log") axes[1, 1].grid(True, alpha=0.3) plt.tight_layout() plt.savefig( os.path.join(log_dir, "training_metrics.png"), dpi=150, bbox_inches="tight" ) plt.show() print( "Training metrics plots saved to:", os.path.join(log_dir, "training_metrics.png"), ) def analyze_model_performance(model, data_loader, device, threshold=0.5): """ Analyze model performance on a dataset. Args: model: Trained model data_loader: DataLoader with test/validation data device: torch device threshold: Decision threshold for binary classification Returns: Dictionary with performance metrics """ model.eval() all_predictions = [] all_targets = [] all_probabilities = [] with torch.no_grad(): for batch in data_loader: google_img = batch["google_img"].to(device) yandex_img = batch["yandex_img"].to(device) target = batch["same_domain"].float().to(device) output = model(google_img, yandex_img) probabilities = torch.sigmoid(output).squeeze() predictions = (probabilities > threshold).float() all_predictions.extend(predictions.cpu().numpy()) all_targets.extend(target.cpu().numpy()) all_probabilities.extend(probabilities.cpu().numpy()) # Convert to numpy arrays all_predictions = np.array(all_predictions) all_targets = np.array(all_targets) all_probabilities = np.array(all_probabilities) # Calculate confusion matrix cm = confusion_matrix(all_targets, all_predictions) # Calculate metrics tn, fp, fn, tp = cm.ravel() accuracy = (tp + tn) / (tp + tn + fp + fn) precision = tp / (tp + fp) if (tp + fp) > 0 else 0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0 f1_score = ( 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 ) # Create classification report report = classification_report( all_targets, all_predictions, target_names=["Different", "Same"] ) # Calculate ROC curve fpr, tpr, thresholds = roc_curve(all_targets, all_probabilities) roc_auc = auc(fpr, tpr) # Find optimal threshold (Youden's J statistic) youden_j = tpr - fpr optimal_idx = np.argmax(youden_j) optimal_threshold = thresholds[optimal_idx] metrics = { "confusion_matrix": cm, "accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1_score, "roc_auc": roc_auc, "optimal_threshold": optimal_threshold, "classification_report": report, "true_negatives": tn, "false_positives": fp, "false_negatives": fn, "true_positives": tp, } return metrics def plot_confusion_matrix(cm, class_names=["Different", "Same"]): """ Plot confusion matrix with annotations. Args: cm: Confusion matrix class_names: List of class names """ plt.figure(figsize=(8, 6)) # Create heatmap sns.heatmap( cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names, ) plt.title("Confusion Matrix") plt.ylabel("True Label") plt.xlabel("Predicted Label") # Add text annotations tn, fp, fn, tp = cm.ravel() plt.text( 0.5, -0.15, f"True Negatives: {tn}", ha="center", transform=plt.gca().transAxes ) plt.text( 0.5, -0.20, f"False Positives: {fp}", ha="center", transform=plt.gca().transAxes ) plt.text( 0.5, -0.25, f"False Negatives: {fn}", ha="center", transform=plt.gca().transAxes ) plt.text( 0.5, -0.30, f"True Positives: {tp}", ha="center", transform=plt.gca().transAxes ) plt.tight_layout() plt.show() # Create a summary table print("\n" + "=" * 50) print("CONFUSION MATRIX SUMMARY") print("=" * 50) summary_data = { "Metric": [ "True Negatives", "False Positives", "False Negatives", "True Positives", ], "Count": [tn, fp, fn, tp], "Description": [ "Correctly predicted as different", "Incorrectly predicted as same (Type I error)", "Incorrectly predicted as different (Type II error)", "Correctly predicted as same", ], } df = pd.DataFrame(summary_data) print(df.to_string(index=False)) print("=" * 50) def plot_roc_curve(fpr, tpr, roc_auc): """ Plot ROC curve. Args: fpr: False positive rates tpr: True positive rates roc_auc: Area under ROC curve """ plt.figure(figsize=(8, 6)) plt.plot( fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.3f})" ) plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random") plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") plt.title("Receiver Operating Characteristic (ROC) Curve") plt.legend(loc="lower right") plt.grid(True, alpha=0.3) plt.tight_layout() plt.show() print(f"ROC AUC Score: {roc_auc:.4f}") print("AUC Interpretation:") print("0.90-1.00 = Excellent") print("0.80-0.90 = Good") print("0.70-0.80 = Fair") print("0.60-0.70 = Poor") print("0.50-0.60 = Fail") def plot_probability_distribution(all_probabilities, all_targets): """ Plot probability distribution for positive and negative classes. Args: all_probabilities: List of predicted probabilities all_targets: List of true labels """ # Separate probabilities by true class pos_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 1] neg_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 0] plt.figure(figsize=(10, 6)) # Plot histograms plt.hist( pos_probs, bins=30, alpha=0.5, color="green", label="Same Domain (Positive)", density=True, ) plt.hist( neg_probs, bins=30, alpha=0.5, color="red", label="Different Domain (Negative)", density=True, ) # Add vertical line at threshold 0.5 plt.axvline( x=0.5, color="black", linestyle="--", linewidth=2, label="Decision Threshold (0.5)", ) plt.xlabel("Predicted Probability") plt.ylabel("Density") plt.title("Probability Distribution by True Class") plt.legend() plt.grid(True, alpha=0.3) plt.tight_layout() plt.show() # Print statistics print("\nProbability Statistics:") print( f"Positive class (Same): Mean = {np.mean(pos_probs):.3f}, Std = {np.std(pos_probs):.3f}" ) print( f"Negative class (Different): Mean = {np.mean(neg_probs):.3f}, Std = {np.std(neg_probs):.3f}" ) def test_model_on_examples(model, device, examples_dir="examples"): """ Test model on example image pairs. Args: model: Trained model device: torch device examples_dir: Directory containing example image pairs """ import cv2 from torchvision import transforms model.eval() # Define image preprocessing transform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) # Check if examples directory exists if not os.path.exists(examples_dir): print(f"Examples directory '{examples_dir}' not found.") print("Creating dummy examples for demonstration...") # Create dummy example data for demonstration examples = [ { "name": "Example 1: Similar locations", "google_img": torch.randn(1, 3, 224, 224), "yandex_img": torch.randn(1, 3, 224, 224), "expected": "Same", }, { "name": "Example 2: Different locations", "google_img": torch.randn(1, 3, 224, 224), "yandex_img": torch.randn(1, 3, 224, 224) * 2, "expected": "Different", }, ] else: # In real implementation, load actual images examples = [] print("\n" + "=" * 60) print("MODEL TESTING ON EXAMPLES") print("=" * 60) results = [] for example in examples: with torch.no_grad(): google_img = example["google_img"].to(device) yandex_img = example["yandex_img"].to(device) output = model(google_img, yandex_img) probability = torch.sigmoid(output).item() prediction = "Same" if probability > 0.5 else "Different" result = { "Example": example["name"], "Predicted": prediction, "Probability": probability, "Expected": example.get("expected", "Unknown"), "Correct": prediction == example.get("expected", "Unknown"), } results.append(result) print(f"\n{example['name']}:") print(f" Predicted: {prediction} (probability: {probability:.4f})") print(f" Expected: {example.get('expected', 'Unknown')}") print(f" Result: {'✓ CORRECT' if result['Correct'] else '✗ WRONG'}") # Create results table print("\n" + "=" * 60) print("SUMMARY OF TEST RESULTS") print("=" * 60) df_results = pd.DataFrame(results) print(df_results.to_string(index=False)) accuracy = df_results["Correct"].mean() * 100 print(f"\nTest Accuracy: {accuracy:.1f}%") return df_results def generate_performance_report(model, data_loader, device, output_dir="reports"): """ Generate a comprehensive performance report. Args: model: Trained model data_loader: DataLoader with test data device: torch device output_dir: Directory to save reports """ os.makedirs(output_dir, exist_ok=True) print("Generating performance report...") # Analyze performance metrics = analyze_model_performance(model, data_loader, device) # Create report file report_path = os.path.join(output_dir, "model_performance_report.txt") with open(report_path, "w") as f: f.write("=" * 60 + "\n") f.write("MODEL PERFORMANCE REPORT\n") f.write("=" * 60 + "\n\n") f.write("1. BASIC METRICS\n") f.write("-" * 40 + "\n") f.write(f"Accuracy: {metrics['accuracy']:.4f}\n") f.write(f"Precision: {metrics['precision']:.4f}\n") f.write(f"Recall: {metrics['recall']:.4f}\n") f.write(f"F1 Score: {metrics['f1_score']:.4f}\n") f.write(f"ROC AUC: {metrics['roc_auc']:.4f}\n") f.write(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}\n\n") f.write("2. CONFUSION MATRIX\n") f.write("-" * 40 + "\n") f.write(f"True Negatives: {metrics['true_negatives']}\n") f.write(f"False Positives: {metrics['false_positives']}\n") f.write(f"False Negatives: {metrics['false_negatives']}\n") f.write(f"True Positives: {metrics['true_positives']}\n\n") f.write("3. CLASSIFICATION REPORT\n") f.write("-" * 40 + "\n") f.write(metrics["classification_report"] + "\n") f.write("4. INTERPRETATION\n") f.write("-" * 40 + "\n") f.write("Accuracy: Proportion of correct predictions\n") f.write("Precision: Proportion of positive predictions that are correct\n") f.write( "Recall: Proportion of actual positives that are correctly identified\n" ) f.write("F1 Score: Harmonic mean of precision and recall\n") f.write("ROC AUC: Ability to distinguish between classes\n\n") f.write("5. RECOMMENDATIONS\n") f.write("-" * 40 + "\n") if metrics["precision"] < 0.7: f.write("- Improve precision to reduce false positives\n") if metrics["recall"] < 0.7: f.write("- Improve recall to reduce false negatives\n") if metrics["f1_score"] < 0.7: f.write("- Overall model performance needs improvement\n") if metrics["roc_auc"] > 0.8: f.write("- Good discrimination ability between classes\n") else: f.write("- Consider improving feature extraction\n") print(f"Report saved to: {report_path}") return metrics def main(): """ Main function to run evaluation and generate reports. """ print("=" * 60) print("IMAGE SIMILARITY MODEL EVALUATION") print("=" * 60) # Set device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Load configuration config_dict = config.copy() if isinstance(config_dict.get("image_size"), list): config_dict["image_size"] = tuple(config_dict["image_size"]) # Create data loaders print("\n1. Creating data loaders...") _, val_loader = create_data_loaders( root_dir=config_dict["data_dir"], batch_size=config_dict["batch_size"], train_split=config_dict["train_split"], num_workers=config_dict["num_workers"], image_size=config_dict["image_size"], augment_train=False, augment_val=False, device=device, ) print(f"Validation batches: {len(val_loader)}") # Load trained model print("\n2. Loading trained model...") model = create_similarity_model( model_type="cnn", input_size=config_dict["image_size"][0] if isinstance(config_dict["image_size"], (tuple, list)) else config_dict["image_size"], input_channels=3, hidden_channels=64, num_blocks=4, dropout_rate=0.3, use_batch_norm=True, ) # Load best checkpoint checkpoint_dir = os.path.join( config_dict.get("output_dir", "runs/similarity"), "checkpoints" ) best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt") if os.path.exists(best_checkpoint): checkpoint = torch.load(best_checkpoint, map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) print(f"Loaded best model from epoch {checkpoint['epoch']}") print(f"Best validation loss: {checkpoint['val_loss']:.4f}") else: print("Warning: Best model checkpoint not found!") print("Using randomly initialized model for demonstration.") model = model.to(device) # Plot training metrics print("\n3. Plotting training metrics...") plot_training_metrics(config_dict.get("output_dir", "runs/similarity")) # Analyze model performance print("\n4. Analyzing model performance...") metrics = analyze_model_performance(model, val_loader, device) # Display results print("\n" + "=" * 60) print("PERFORMANCE METRICS") print("=" * 60) print(f"Accuracy: {metrics['accuracy']:.4f}") print(f"Precision: {metrics['precision']:.4f}") print(f"Recall: {metrics['recall']:.4f}") print(f"F1 Score: {metrics['f1_score']:.4f}") print(f"ROC AUC: {metrics['roc_auc']:.4f}") print(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}") # Plot confusion matrix print("\n5. Plotting confusion matrix...") plot_confusion_matrix(metrics["confusion_matrix"]) # Plot ROC curve print("\n6. Plotting ROC curve...") # For demonstration, we need to get probabilities again model.eval() all_probabilities = [] all_targets = [] with torch.no_grad(): for batch in val_loader: google_img = batch["google_img"].to(device) yandex_img = batch["yandex_img"].to(device) target = batch["same_domain"].float().to(device) output = model(google_img, yandex_img) probabilities = torch.sigmoid(output).squeeze() all_probabilities.extend(probabilities.cpu().numpy()) all_targets.extend(target.cpu().numpy()) all_probabilities = np.array(all_probabilities) all_targets = np.array(all_targets) fpr, tpr, _ = roc_curve(all_targets, all_probabilities) roc_auc = auc(fpr, tpr) plot_roc_curve(fpr, tpr, roc_auc) # Plot probability distribution print("\n7. Plotting probability distribution...") plot_probability_distribution(all_probabilities, all_targets) # Test on examples print("\n8. Testing on examples...") test_model_on_examples(model, device) # Generate comprehensive report print("\n9. Generating performance report...") generate_performance_report(model, val_loader, device) print("\n" + "=" * 60) print("EVALUATION COMPLETED SUCCESSFULLY!") print("=" * 60) print("\nNext steps:") print("1. Check the generated plots in the runs/similarity directory") print("2. Review the performance report in the reports directory") print("3. Consider adjusting the decision threshold if needed") print("4. Retrain with different hyperparameters if performance is poor") if __name__ == "__main__": main()