# File: autopilot/models/SiaN-similarity/evaluation.py
# (664 lines, 21 KiB, Python)

"""
Evaluation and visualization for image similarity model.
This file contains code for plotting training metrics, analyzing model performance,
and testing the trained model.
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from dataloader import config, create_data_loaders
from model import create_similarity_model
from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve
from torch.utils.data import DataLoader
from train import SimilarityTrainer
# Set style for plots.
# Both calls execute at import time and change global matplotlib/seaborn
# settings, so every figure drawn after this module is imported uses them.
# NOTE(review): the "seaborn-v0_8-*" style names presumably require a recent
# matplotlib release - confirm against the pinned version.
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("husl")
def plot_training_metrics(log_dir="runs/similarity"):
    """
    Plot training/validation curves and save the figure into ``log_dir``.

    NOTE: the curves are simulated placeholders (linear trends plus Gaussian
    noise); nothing is read from ``log_dir`` yet. Replace the simulated series
    with values loaded from the real training logs once they are available.

    Args:
        log_dir: Directory the figure ``training_metrics.png`` is written to.
            It is created if it does not exist.
    """
    num_epochs = 50
    epochs = list(range(1, num_epochs + 1))
    steps = range(num_epochs)

    # Simulated metrics. The draw order (train loss, val loss, train acc,
    # val acc) is kept stable so a seeded run reproduces the same figure.
    train_loss = [0.8 - 0.015 * i + np.random.normal(0, 0.02) for i in steps]
    val_loss = [0.75 - 0.012 * i + np.random.normal(0, 0.03) for i in steps]
    train_acc = [0.55 + 0.008 * i + np.random.normal(0, 0.01) for i in steps]
    val_acc = [0.6 + 0.006 * i + np.random.normal(0, 0.015) for i in steps]
    learning_rate = [0.0002 * (0.95**i) for i in steps]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plots 1 and 2: training vs. validation curves for loss and accuracy.
    paired_curves = (
        (axes[0, 0], "Loss", train_loss, val_loss),
        (axes[0, 1], "Accuracy", train_acc, val_acc),
    )
    for ax, metric_name, train_values, val_values in paired_curves:
        ax.plot(
            epochs, train_values, "b-", linewidth=2, label=f"Training {metric_name}"
        )
        ax.plot(
            epochs, val_values, "r-", linewidth=2, label=f"Validation {metric_name}"
        )
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric_name)
        ax.set_title(f"Training and Validation {metric_name}")
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Plot 3: loss difference (train - val).
    loss_diff = np.asarray(train_loss) - np.asarray(val_loss)
    gap_ax = axes[1, 0]
    gap_ax.plot(epochs, loss_diff, "g-", linewidth=2)
    gap_ax.axhline(y=0, color="r", linestyle="--", alpha=0.5)
    # Overfitting shows up as validation loss exceeding training loss, i.e. a
    # NEGATIVE (train - val) difference, so that is the region flagged in red.
    gap_ax.fill_between(
        epochs,
        0,
        loss_diff,
        where=loss_diff < 0,
        alpha=0.3,
        color="red",
        label="Overfitting (train < val)",
    )
    gap_ax.fill_between(
        epochs,
        0,
        loss_diff,
        where=loss_diff > 0,
        alpha=0.3,
        color="green",
        label="No overfitting (train > val)",
    )
    gap_ax.set_xlabel("Epoch")
    gap_ax.set_ylabel("Loss Difference")
    gap_ax.set_title("Train Loss - Val Loss (Overfitting Indicator)")
    gap_ax.legend()
    gap_ax.grid(True, alpha=0.3)

    # Plot 4: learning rate schedule.
    # A full colour name cannot be combined with a line style inside a format
    # string ("purple-" raises ValueError), so pass them as keyword arguments.
    lr_ax = axes[1, 1]
    lr_ax.plot(epochs, learning_rate, color="purple", linestyle="-", linewidth=2)
    lr_ax.set_xlabel("Epoch")
    lr_ax.set_ylabel("Learning Rate")
    lr_ax.set_title("Learning Rate Schedule")
    lr_ax.set_yscale("log")
    lr_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # savefig does not create missing directories; make sure the target exists.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    output_path = os.path.join(log_dir, "training_metrics.png")
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.show()
    print("Training metrics plots saved to:", output_path)
def analyze_model_performance(model, data_loader, device, threshold=0.5):
    """
    Analyze binary-classification performance of the model on a dataset.

    Each batch is expected to be a mapping with the keys ``"google_img"``,
    ``"yandex_img"`` and ``"same_domain"``; the model output is treated as a
    logit and passed through a sigmoid.

    Args:
        model: Trained model, called as ``model(google_img, yandex_img)``.
        data_loader: Iterable of batches with test/validation data.
        device: torch device the images are moved to.
        threshold: Decision threshold applied to the predicted probability.

    Returns:
        Dictionary with the keys ``confusion_matrix`` (2x2 array),
        ``accuracy``, ``precision``, ``recall``, ``f1_score``, ``roc_auc``
        (NaN when only one class is present), ``optimal_threshold``,
        ``classification_report`` and the four confusion-matrix counts.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    prob_chunks = []
    target_chunks = []
    with torch.no_grad():
        for batch in data_loader:
            google_img = batch["google_img"].to(device)
            yandex_img = batch["yandex_img"].to(device)
            output = model(google_img, yandex_img)
            # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
            # one into a 0-d tensor, which cannot be iterated or concatenated.
            prob_chunks.append(torch.sigmoid(output).reshape(-1).cpu().numpy())
            target_chunks.append(
                batch["same_domain"].float().reshape(-1).cpu().numpy()
            )

    if not prob_chunks:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    all_probabilities = np.concatenate(prob_chunks)
    all_targets = np.concatenate(target_chunks).astype(int)
    all_predictions = (all_probabilities > threshold).astype(int)

    # Pin the label set so the matrix is always 2x2, even when the data or
    # the predictions contain a single class only.
    cm = confusion_matrix(all_targets, all_predictions, labels=[0, 1])
    tn, fp, fn, tp = (int(count) for count in cm.ravel())

    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = (
        2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    )

    report = classification_report(
        all_targets,
        all_predictions,
        labels=[0, 1],
        target_names=["Different", "Same"],
        zero_division=0,
    )

    if np.unique(all_targets).size < 2:
        # ROC is undefined with a single class; report NaN and keep the
        # caller's threshold instead of emitting a meaningless optimum.
        roc_auc = float("nan")
        optimal_threshold = threshold
    else:
        fpr, tpr, thresholds = roc_curve(all_targets, all_probabilities)
        roc_auc = auc(fpr, tpr)
        # Optimal threshold by Youden's J statistic (tpr - fpr).
        optimal_idx = int(np.argmax(tpr - fpr))
        # The first ROC threshold is a sentinel above every score (inf or
        # max + 1); cap it so the reported value stays a valid probability.
        optimal_threshold = min(float(thresholds[optimal_idx]), 1.0)

    return {
        "confusion_matrix": cm,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "roc_auc": roc_auc,
        "optimal_threshold": optimal_threshold,
        "classification_report": report,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "true_positives": tp,
    }
def plot_confusion_matrix(cm, class_names=None):
    """
    Plot a 2x2 confusion matrix as an annotated heatmap and print a summary.

    Args:
        cm: 2x2 confusion matrix of integer counts laid out as
            ``[[tn, fp], [fn, tp]]`` (array or nested sequence).
        class_names: Tick labels for the negative and positive class.
            Defaults to ``["Different", "Same"]``.

    Raises:
        ValueError: If ``cm`` is not a 2x2 matrix.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if class_names is None:
        class_names = ["Different", "Same"]

    cm = np.asarray(cm)
    # Validate up front so a malformed matrix fails before a figure is shown.
    if cm.shape != (2, 2):
        raise ValueError(f"expected a 2x2 confusion matrix, got shape {cm.shape}")
    tn, fp, fn, tp = cm.ravel()

    plt.figure(figsize=(8, 6))
    ax = sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    ax.set_title("Confusion Matrix")
    ax.set_ylabel("True Label")
    ax.set_xlabel("Predicted Label")

    # Text annotations below the heatmap, positioned in axes coordinates.
    annotations = (
        (-0.15, f"True Negatives: {tn}"),
        (-0.20, f"False Positives: {fp}"),
        (-0.25, f"False Negatives: {fn}"),
        (-0.30, f"True Positives: {tp}"),
    )
    for y_pos, text in annotations:
        ax.text(0.5, y_pos, text, ha="center", transform=ax.transAxes)

    plt.tight_layout()
    plt.show()

    # Summary table on stdout.
    print("\n" + "=" * 50)
    print("CONFUSION MATRIX SUMMARY")
    print("=" * 50)
    summary_data = {
        "Metric": [
            "True Negatives",
            "False Positives",
            "False Negatives",
            "True Positives",
        ],
        "Count": [tn, fp, fn, tp],
        "Description": [
            "Correctly predicted as different",
            "Incorrectly predicted as same (Type I error)",
            "Incorrectly predicted as different (Type II error)",
            "Correctly predicted as same",
        ],
    }
    df = pd.DataFrame(summary_data)
    print(df.to_string(index=False))
    print("=" * 50)
def plot_roc_curve(fpr, tpr, roc_auc):
    """
    Draw the ROC curve next to the chance diagonal and print the AUC score.

    Args:
        fpr: False positive rates (x values of the curve)
        tpr: True positive rates (y values of the curve)
        roc_auc: Area under the ROC curve, shown in the legend and printed
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    curve_label = f"ROC curve (AUC = {roc_auc:.3f})"
    ax.plot(fpr, tpr, color="darkorange", lw=2, label=curve_label)
    # Diagonal reference line labelled as the random baseline.
    ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random")

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver Operating Characteristic (ROC) Curve")
    ax.legend(loc="lower right")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    print(f"ROC AUC Score: {roc_auc:.4f}")
    interpretation = (
        "AUC Interpretation:",
        "0.90-1.00 = Excellent",
        "0.80-0.90 = Good",
        "0.70-0.80 = Fair",
        "0.60-0.70 = Poor",
        "0.50-0.60 = Fail",
    )
    for line in interpretation:
        print(line)
def plot_probability_distribution(all_probabilities, all_targets, threshold=0.5):
    """
    Plot the predicted-probability histograms of the positive and negative class.

    Args:
        all_probabilities: Predicted probabilities, one per sample.
        all_targets: True labels (1 = same domain, 0 = different domain).
        threshold: Decision threshold drawn as a vertical line (default 0.5).

    Raises:
        ValueError: If the two inputs do not contain the same number of items.
    """
    probabilities = np.asarray(all_probabilities, dtype=float).reshape(-1)
    targets = np.asarray(all_targets).reshape(-1)
    if probabilities.shape != targets.shape:
        raise ValueError(
            "all_probabilities and all_targets must have the same length "
            f"({probabilities.size} != {targets.size})"
        )

    # Separate probabilities by true class.
    pos_probs = probabilities[targets == 1]
    neg_probs = probabilities[targets == 0]

    plt.figure(figsize=(10, 6))
    class_histograms = (
        (pos_probs, "green", "Same Domain (Positive)"),
        (neg_probs, "red", "Different Domain (Negative)"),
    )
    for class_probs, color, label in class_histograms:
        # A density histogram of an empty sample is undefined; skip it.
        if class_probs.size:
            plt.hist(
                class_probs,
                bins=30,
                alpha=0.5,
                color=color,
                label=label,
                density=True,
            )

    plt.axvline(
        x=threshold,
        color="black",
        linestyle="--",
        linewidth=2,
        label=f"Decision Threshold ({threshold})",
    )
    plt.xlabel("Predicted Probability")
    plt.ylabel("Density")
    plt.title("Probability Distribution by True Class")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("\nProbability Statistics:")
    class_statistics = (
        ("Positive class (Same)", pos_probs),
        ("Negative class (Different)", neg_probs),
    )
    for title, class_probs in class_statistics:
        if class_probs.size:
            print(
                f"{title}: Mean = {np.mean(class_probs):.3f}, Std = {np.std(class_probs):.3f}"
            )
        else:
            # Mean/std of an empty sample would be NaN plus a RuntimeWarning.
            print(f"{title}: no samples")
def test_model_on_examples(model, device, examples_dir="examples"):
    """
    Run the model on example image pairs and tabulate the predictions.

    When ``examples_dir`` does not exist, two random dummy pairs are generated
    for demonstration. Loading real image pairs from an existing directory is
    not implemented yet, so in that case no examples are evaluated and an
    empty table is returned.

    Args:
        model: Trained model, called as ``model(google_img, yandex_img)``;
            its output is treated as a single logit.
        device: torch device the example tensors are moved to.
        examples_dir: Directory containing example image pairs.

    Returns:
        pandas DataFrame with the columns ``Example``, ``Predicted``,
        ``Probability``, ``Expected`` and ``Correct`` (one row per example).
    """
    # NOTE: the former function-level cv2/torchvision imports were unused and
    # made this function fail whenever those optional packages were missing.
    model.eval()

    if not os.path.exists(examples_dir):
        print(f"Examples directory '{examples_dir}' not found.")
        print("Creating dummy examples for demonstration...")
        # Dummy example data for demonstration.
        examples = [
            {
                "name": "Example 1: Similar locations",
                "google_img": torch.randn(1, 3, 224, 224),
                "yandex_img": torch.randn(1, 3, 224, 224),
                "expected": "Same",
            },
            {
                "name": "Example 2: Different locations",
                "google_img": torch.randn(1, 3, 224, 224),
                "yandex_img": torch.randn(1, 3, 224, 224) * 2,
                "expected": "Different",
            },
        ]
    else:
        # In a real implementation, load and preprocess the actual images.
        examples = []

    print("\n" + "=" * 60)
    print("MODEL TESTING ON EXAMPLES")
    print("=" * 60)

    results = []
    with torch.no_grad():
        for example in examples:
            google_img = example["google_img"].to(device)
            yandex_img = example["yandex_img"].to(device)
            output = model(google_img, yandex_img)
            probability = torch.sigmoid(output).item()
            prediction = "Same" if probability > 0.5 else "Different"
            expected = example.get("expected", "Unknown")
            result = {
                "Example": example["name"],
                "Predicted": prediction,
                "Probability": probability,
                "Expected": expected,
                "Correct": prediction == expected,
            }
            results.append(result)
            print(f"\n{example['name']}:")
            print(f" Predicted: {prediction} (probability: {probability:.4f})")
            print(f" Expected: {expected}")
            print(f" Result: {'✓ CORRECT' if result['Correct'] else '✗ WRONG'}")

    print("\n" + "=" * 60)
    print("SUMMARY OF TEST RESULTS")
    print("=" * 60)
    # Explicit columns keep the table well-formed when no example was run;
    # otherwise the "Correct" lookup below raises KeyError on an empty frame.
    df_results = pd.DataFrame(
        results,
        columns=["Example", "Predicted", "Probability", "Expected", "Correct"],
    )
    if df_results.empty:
        print("No examples were evaluated.")
        return df_results

    print(df_results.to_string(index=False))
    accuracy = df_results["Correct"].mean() * 100
    print(f"\nTest Accuracy: {accuracy:.1f}%")
    return df_results


# The "test_" prefix would otherwise make pytest collect this helper (and fail
# on its required arguments) whenever it is imported into a test module.
test_model_on_examples.__test__ = False
def generate_performance_report(model, data_loader, device, output_dir="reports"):
    """
    Evaluate the model and write a plain-text performance report.

    The report is written to ``model_performance_report.txt`` inside
    ``output_dir`` (created when missing) and has five sections: basic
    metrics, confusion matrix, classification report, interpretation and
    recommendations.

    Args:
        model: Trained model
        data_loader: DataLoader with test data
        device: torch device
        output_dir: Directory to save reports

    Returns:
        The metrics dictionary produced by ``analyze_model_performance``.
    """
    os.makedirs(output_dir, exist_ok=True)
    print("Generating performance report...")

    metrics = analyze_model_performance(model, data_loader, device)

    heavy_rule = "=" * 60 + "\n"
    light_rule = "-" * 40 + "\n"

    # Assemble the whole report in memory, then write it in one call.
    lines = [
        heavy_rule,
        "MODEL PERFORMANCE REPORT\n",
        heavy_rule + "\n",
        "1. BASIC METRICS\n",
        light_rule,
        f"Accuracy: {metrics['accuracy']:.4f}\n",
        f"Precision: {metrics['precision']:.4f}\n",
        f"Recall: {metrics['recall']:.4f}\n",
        f"F1 Score: {metrics['f1_score']:.4f}\n",
        f"ROC AUC: {metrics['roc_auc']:.4f}\n",
        f"Optimal Threshold: {metrics['optimal_threshold']:.4f}\n\n",
        "2. CONFUSION MATRIX\n",
        light_rule,
        f"True Negatives: {metrics['true_negatives']}\n",
        f"False Positives: {metrics['false_positives']}\n",
        f"False Negatives: {metrics['false_negatives']}\n",
        f"True Positives: {metrics['true_positives']}\n\n",
        "3. CLASSIFICATION REPORT\n",
        light_rule,
        metrics["classification_report"] + "\n",
        "4. INTERPRETATION\n",
        light_rule,
        "Accuracy: Proportion of correct predictions\n",
        "Precision: Proportion of positive predictions that are correct\n",
        "Recall: Proportion of actual positives that are correctly identified\n",
        "F1 Score: Harmonic mean of precision and recall\n",
        "ROC AUC: Ability to distinguish between classes\n\n",
        "5. RECOMMENDATIONS\n",
        light_rule,
    ]

    # One recommendation per metric that falls below the 0.7 cut-off.
    low_score_advice = (
        ("precision", "- Improve precision to reduce false positives\n"),
        ("recall", "- Improve recall to reduce false negatives\n"),
        ("f1_score", "- Overall model performance needs improvement\n"),
    )
    for metric_key, advice in low_score_advice:
        if metrics[metric_key] < 0.7:
            lines.append(advice)

    if metrics["roc_auc"] > 0.8:
        lines.append("- Good discrimination ability between classes\n")
    else:
        lines.append("- Consider improving feature extraction\n")

    report_path = os.path.join(output_dir, "model_performance_report.txt")
    with open(report_path, "w") as report_file:
        report_file.writelines(lines)

    print(f"Report saved to: {report_path}")
    return metrics
def _collect_probabilities(model, data_loader, device):
    """Run the model over ``data_loader``; return flat (probabilities, targets) arrays."""
    model.eval()
    prob_chunks = []
    target_chunks = []
    with torch.no_grad():
        for batch in data_loader:
            google_img = batch["google_img"].to(device)
            yandex_img = batch["yandex_img"].to(device)
            output = model(google_img, yandex_img)
            # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
            # one into a 0-d tensor, which cannot be iterated or concatenated.
            prob_chunks.append(torch.sigmoid(output).reshape(-1).cpu().numpy())
            target_chunks.append(
                batch["same_domain"].float().reshape(-1).cpu().numpy()
            )
    if not prob_chunks:
        return np.empty(0), np.empty(0)
    return np.concatenate(prob_chunks), np.concatenate(target_chunks)


def main():
    """
    Run the full evaluation pipeline and generate plots and reports.

    Steps: build the validation loader from ``config``, create the model and
    load the best checkpoint when one exists (otherwise continue with random
    weights for demonstration), then plot the training metrics, confusion
    matrix, ROC curve and probability distribution, run the example test and
    write the text report.
    """
    print("=" * 60)
    print("IMAGE SIMILARITY MODEL EVALUATION")
    print("=" * 60)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load configuration
    config_dict = config.copy()
    if isinstance(config_dict.get("image_size"), list):
        config_dict["image_size"] = tuple(config_dict["image_size"])
    # Resolve once; used for checkpoints, plots and the closing hint.
    output_dir = config_dict.get("output_dir", "runs/similarity")

    # Create data loaders
    print("\n1. Creating data loaders...")
    _, val_loader = create_data_loaders(
        root_dir=config_dict["data_dir"],
        batch_size=config_dict["batch_size"],
        train_split=config_dict["train_split"],
        num_workers=config_dict["num_workers"],
        image_size=config_dict["image_size"],
        augment_train=False,
        augment_val=False,
        device=device,
    )
    print(f"Validation batches: {len(val_loader)}")

    # Load trained model
    print("\n2. Loading trained model...")
    image_size = config_dict["image_size"]
    model = create_similarity_model(
        model_type="cnn",
        input_size=image_size[0] if isinstance(image_size, (tuple, list)) else image_size,
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )

    # Load best checkpoint
    best_checkpoint = os.path.join(output_dir, "checkpoints", "best_model.pt")
    if os.path.exists(best_checkpoint):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(best_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        # The bookkeeping entries are informational only; tolerate checkpoints
        # that were saved without them.
        print(f"Loaded best model from epoch {checkpoint.get('epoch', 'unknown')}")
        best_val_loss = checkpoint.get("val_loss")
        if best_val_loss is not None:
            print(f"Best validation loss: {best_val_loss:.4f}")
    else:
        print("Warning: Best model checkpoint not found!")
        print("Using randomly initialized model for demonstration.")
    model = model.to(device)

    # Plot training metrics
    print("\n3. Plotting training metrics...")
    plot_training_metrics(output_dir)

    # Analyze model performance
    print("\n4. Analyzing model performance...")
    metrics = analyze_model_performance(model, val_loader, device)

    # Display results
    print("\n" + "=" * 60)
    print("PERFORMANCE METRICS")
    print("=" * 60)
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1_score']:.4f}")
    print(f"ROC AUC: {metrics['roc_auc']:.4f}")
    print(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}")

    # Plot confusion matrix
    print("\n5. Plotting confusion matrix...")
    plot_confusion_matrix(metrics["confusion_matrix"])

    # Plot ROC curve. The metrics dict does not carry per-sample
    # probabilities, so they are collected again here.
    print("\n6. Plotting ROC curve...")
    all_probabilities, all_targets = _collect_probabilities(model, val_loader, device)
    fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
    roc_auc = auc(fpr, tpr)
    plot_roc_curve(fpr, tpr, roc_auc)

    # Plot probability distribution
    print("\n7. Plotting probability distribution...")
    plot_probability_distribution(all_probabilities, all_targets)

    # Test on examples
    print("\n8. Testing on examples...")
    test_model_on_examples(model, device)

    # Generate comprehensive report
    print("\n9. Generating performance report...")
    generate_performance_report(model, val_loader, device)

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print("\nNext steps:")
    # Point at the directory actually used, not a hard-coded default.
    print(f"1. Check the generated plots in the {output_dir} directory")
    print("2. Review the performance report in the reports directory")
    print("3. Consider adjusting the decision threshold if needed")
    print("4. Retrain with different hyperparameters if performance is poor")


if __name__ == "__main__":
    main()