# Files
# autopilot/models/SiaN-similarity/demo_evaluation.ipynb.py
#
# 341 lines
# 12 KiB
# Python

"""
Demo Evaluation Notebook-style File
====================================
This file demonstrates how to use the evaluation functions from evaluation.py
in a notebook-like style. You can run this file directly to see all the plots
and analysis.
Think of this as the next cell in your notebook after training!
"""
import os
import sys
# Add the current directory to the path so we can import our modules
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Third-party libraries for plotting and numerical work
import matplotlib.pyplot as plt
import numpy as np
# PyTorch, followed by our own project modules (data loading, evaluation, model)
import torch
from dataloader import config, create_data_loaders
from evaluation import (
analyze_model_performance,
generate_performance_report,
plot_confusion_matrix,
plot_probability_distribution,
plot_roc_curve,
plot_training_metrics,
test_model_on_examples,
)
from model import create_similarity_model
# Opening banner: a 70-character rule above and below the title, followed by
# a two-line description of what the demo covers.
print("=" * 70, "DEMO: EVALUATING IMAGE SIMILARITY MODEL", "=" * 70, sep="\n")
print(
    "\nThis demo shows you how to analyze your trained model.",
    "Think of this as the 'results' section of your notebook!\n",
    sep="\n",
)
# ============================================================================
# STEP 1: SETUP
# ============================================================================
print("STEP 1: Setting up the environment", "-" * 40, sep="\n")

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✓ Using device: {device}")

# Work on a copy of the imported configuration so the conversion below is
# applied to `config_dict` rather than to the imported `config` object.
config_dict = config.copy()

# Only a list-valued image size is converted; it becomes a tuple.
if isinstance(config_dict.get("image_size"), list):
    config_dict["image_size"] = tuple(config_dict["image_size"])

print(f"✓ Image size: {config_dict['image_size']}")
print(f"✓ Batch size: {config_dict['batch_size']}")
# ============================================================================
# STEP 2: LOAD DATA
# ============================================================================
print("\nSTEP 2: Loading validation data", "-" * 40, sep="\n")

# create_data_loaders returns a pair; only the second element is kept, as the
# validation loader. Augmentation is switched off for both splits. The four
# settings in the tuple below are forwarded from the configuration unchanged.
_, val_loader = create_data_loaders(
    root_dir=config_dict["data_dir"],
    augment_train=False,
    augment_val=False,
    device=device,
    **{
        name: config_dict[name]
        for name in ("batch_size", "train_split", "num_workers", "image_size")
    },
)

print(f"✓ Validation batches loaded: {len(val_loader)}")
print(f"✓ Each batch has {config_dict['batch_size']} image pairs")
# ============================================================================
# STEP 3: LOAD TRAINED MODEL
# ============================================================================
print("\nSTEP 3: Loading the trained model", "-" * 40, sep="\n")

# Build the network. The input size is taken from the first entry of the
# configured image size; the remaining hyper-parameters are fixed here.
# NOTE(review): presumably these must match the values used during training
# for load_state_dict to accept the checkpoint — confirm against the trainer.
model = create_similarity_model(
    model_type="cnn",
    input_size=config_dict["image_size"][0],
    input_channels=3,
    hidden_channels=64,
    num_blocks=4,
    dropout_rate=0.3,
    use_batch_norm=True,
)

# The best checkpoint is looked up under <output_dir>/checkpoints/best_model.pt,
# with "runs/similarity" as the output directory when none is configured.
checkpoint_dir = os.path.join(
    config_dict.get("output_dir", "runs/similarity"), "checkpoints"
)
best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")

if not os.path.exists(best_checkpoint):
    # No checkpoint on disk: continue with the freshly built (untrained) model.
    print("⚠ Warning: Best model checkpoint not found!")
    print(" Using randomly initialized model for demonstration.")
    print(" (This is normal if you haven't trained the model yet)")
else:
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(best_checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
    print(f"✓ Best validation loss: {checkpoint['val_loss']:.4f}")

model = model.to(device)
print(f"✓ Model moved to {device}")

# Count all parameters and the subset that requires gradients in one pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count
print(f"✓ Total parameters: {total_params:,}")
print(f"✓ Trainable parameters: {trainable_params:,}")
# ============================================================================
# STEP 4: PLOT TRAINING METRICS
# ============================================================================
print(
    "\nSTEP 4: Plotting training metrics",
    "-" * 40,
    "This shows how the model learned over time:",
    sep="\n",
)

# The plotting helper is pointed at the run's output directory
# ("runs/similarity" when none is configured). According to the original
# notes for this step it is expected to draw four panels:
#   1. Training and validation loss
#   2. Training and validation accuracy
#   3. Overfitting indicator
#   4. Learning rate schedule
plot_training_metrics(config_dict.get("output_dir", "runs/similarity"))

print(
    "✓ Training metrics plotted!",
    " Look for 'training_metrics.png' in your runs directory",
    sep="\n",
)
# ============================================================================
# STEP 5: ANALYZE MODEL PERFORMANCE
# ============================================================================
print(
    "\nSTEP 5: Analyzing model performance on validation set",
    "-" * 40,
    "Calculating metrics like accuracy, precision, recall, F1 score...",
    sep="\n",
)

# Evaluate the model on the validation loader with a 0.5 decision threshold.
# The result is indexed like a mapping below and by later steps.
metrics = analyze_model_performance(model, val_loader, device, threshold=0.5)

# The first four values are rendered as percentages, ROC AUC as a plain
# number with four decimals.
print("\n📊 PERFORMANCE METRICS:")
print(f" Accuracy: {metrics['accuracy']:.2%}")
print(f" Precision: {metrics['precision']:.2%}")
print(f" Recall: {metrics['recall']:.2%}")
print(f" F1 Score: {metrics['f1_score']:.2%}")
print(f" ROC AUC: {metrics['roc_auc']:.4f}")
# ============================================================================
# STEP 6: SHOW CONFUSION MATRIX
# ============================================================================
print(
    "\nSTEP 6: Confusion Matrix",
    "-" * 40,
    "This shows how many predictions were correct/wrong:",
    sep="\n",
)

# Pass the confusion matrix stored in `metrics` to the plotting helper.
plot_confusion_matrix(metrics["confusion_matrix"])
# ============================================================================
# STEP 7: ROC CURVE
# ============================================================================
# scikit-learn is only needed from this step on; importing it before the
# inference loop makes a missing installation fail fast.
from sklearn.metrics import auc, roc_curve

print("\nSTEP 7: ROC Curve")
print("-" * 40)
print("This shows how well the model distinguishes between classes:")

# Score every validation pair once, in eval mode and without gradients.
model.eval()
all_probabilities = []
all_targets = []
with torch.no_grad():
    for batch in val_loader:
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        target = batch["same_domain"].float().to(device)
        output = model(google_img, yandex_img)
        # reshape(-1) instead of squeeze(): squeeze() also removes the batch
        # dimension when a batch holds a single pair, leaving a 0-d tensor
        # whose NumPy array cannot be iterated by list.extend.
        probabilities = torch.sigmoid(output).reshape(-1)
        all_probabilities.extend(probabilities.cpu().numpy())
        # Flatten the targets the same way so both lists stay aligned
        # element-for-element.
        all_targets.extend(target.reshape(-1).cpu().numpy())

all_probabilities = np.array(all_probabilities)
all_targets = np.array(all_targets)

# Build the ROC curve from the collected scores and plot it together with
# its area under the curve.
fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
roc_auc = auc(fpr, tpr)
plot_roc_curve(fpr, tpr, roc_auc)
# ============================================================================
# STEP 8: PROBABILITY DISTRIBUTION
# ============================================================================
print(
    "\nSTEP 8: Probability Distribution",
    "-" * 40,
    "This shows how confident the model is for different classes:",
    sep="\n",
)

# Plot the predicted probabilities against the true labels, using the arrays
# currently bound to `all_probabilities` and `all_targets`.
plot_probability_distribution(all_probabilities, all_targets)
# ============================================================================
# STEP 9: TEST ON EXAMPLE IMAGES
# ============================================================================
print(
    "\nSTEP 9: Testing on example images",
    "-" * 40,
    "Let's see how the model performs on some examples:",
    sep="\n",
)

# Run the example-based check from the evaluation module and keep whatever
# it returns in `test_results`.
test_results = test_model_on_examples(model, device)
# ============================================================================
# STEP 10: GENERATE REPORT
# ============================================================================
print("\nSTEP 10: Generating performance report")
print("-" * 40)
print("Creating a detailed report with all metrics...")

# Produce the report; the returned mapping is read below for the
# "optimal_threshold" entry.
final_metrics = generate_performance_report(model, val_loader, device)

print("\n" + "=" * 70)
print("🎉 DEMO COMPLETED SUCCESSFULLY!")
print("=" * 70)

print("\n📁 What was created:")
# Report the directory that was actually passed to the plotting step instead
# of a hard-coded path, so the message stays correct when "output_dir" is
# configured. Joining with "" appends a trailing separator when one is
# missing, which keeps the default output identical ("runs/similarity/").
print(
    " 1. Training metrics plots (saved to {})".format(
        os.path.join(config_dict.get("output_dir", "runs/similarity"), "")
    )
)
print(" 2. Confusion matrix visualization")
print(" 3. ROC curve plot")
print(" 4. Probability distribution plot")
print(" 5. Performance report (saved to reports/)")

print("\n🔍 Key things to check in your model:")
print(" ✓ Accuracy should be above 70% for a good model")
print(" ✓ Precision: High = few false positives")
print(" ✓ Recall: High = few false negatives")
print(" ✓ ROC AUC: Above 0.8 = good discrimination")

print("\n🔄 If results are poor, try:")
print(" 1. Train for more epochs")
print(" 2. Adjust learning rate")
print(" 3. Use more training data")
print(" 4. Try different model architecture")

# The report supplies its own decision threshold; show it with three decimals.
print(
    "\n💡 Pro tip: The optimal threshold is {:.3f}".format(
        final_metrics["optimal_threshold"]
    )
)
print(" You can use this instead of 0.5 for better results!")
# ============================================================================
# BONUS: QUICK DIAGNOSTICS TABLE
# ============================================================================
print("\n" + "=" * 70, "BONUS: Quick Diagnostics Table", "=" * 70, sep="\n")

# The table starts with a header row and a dashed underline whose widths
# match the four column widths used when printing (15, 10, 30, 15).
diagnostics = [
    ["Metric", "Value", "What it means", "Is it good?"],
    ["-" * 15, "-" * 10, "-" * 30, "-" * 15],
]
# One row per metric: label, formatted value, explanation, rule of thumb.
# Each entry names the key in `metrics` and the format spec for its value.
diagnostics += [
    [label, format(metrics[key], spec), meaning, verdict]
    for label, key, spec, meaning, verdict in (
        ("Accuracy", "accuracy", ".2%", "Overall correctness", ">70% is good"),
        ("Precision", "precision", ".2%", "Few false positives", ">70% is good"),
        ("Recall", "recall", ".2%", "Few false negatives", ">70% is good"),
        (
            "F1 Score",
            "f1_score",
            ".2%",
            "Balance of precision/recall",
            ">70% is good",
        ),
        ("ROC AUC", "roc_auc", ".4f", "Discrimination ability", ">0.8 is good"),
    )
]

# Left-align every cell in its fixed-width column.
for row in diagnostics:
    print("{:<15} {:<10} {:<30} {:<15}".format(*row))

print(
    "\n" + "=" * 70,
    "To run this again, just execute: python demo_evaluation.ipynb.py",
    "=" * 70,
    sep="\n",
)
# ============================================================================
# EXTRA: SAVE PREDICTIONS FOR FURTHER ANALYSIS
# ============================================================================
# pandas is only needed for the CSV export in this section.
import pandas as pd

print("\n💾 Saving predictions for further analysis...")

# Run the model over the validation loader once more, keeping the
# probability, the thresholded prediction (cut at 0.5) and the true label
# for every pair.
model.eval()
all_predictions = []
all_targets = []
all_probabilities = []
with torch.no_grad():
    for batch in val_loader:
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        target = batch["same_domain"].float().to(device)
        output = model(google_img, yandex_img)
        # reshape(-1) instead of squeeze(): squeeze() also removes the batch
        # dimension when a batch holds a single pair, leaving a 0-d tensor
        # whose NumPy array cannot be iterated by list.extend.
        probabilities = torch.sigmoid(output).reshape(-1)
        predictions = (probabilities > 0.5).float()
        all_predictions.extend(predictions.cpu().numpy())
        all_targets.extend(target.reshape(-1).cpu().numpy())
        all_probabilities.extend(probabilities.cpu().numpy())

# Number the pairs by their position in iteration order. Deriving the index
# from the total count (rather than batch_idx * len(target)) keeps it unique
# and contiguous even when the final batch is smaller than the others.
# NOTE(review): the index identifies a dataset item only if the validation
# loader iterates in a fixed order — confirm it is not shuffled.
image_indices = list(range(len(all_targets)))

predictions_df = pd.DataFrame(
    {
        "image_pair_index": image_indices,
        "true_label": all_targets,
        "predicted_label": all_predictions,
        "probability": all_probabilities,
        "correct": np.array(all_targets) == np.array(all_predictions),
    }
)

# Write the CSV into the run's output directory, creating it first so the
# export also works when nothing has been written there yet.
predictions_dir = config_dict.get("output_dir", "runs/similarity")
if predictions_dir:
    os.makedirs(predictions_dir, exist_ok=True)
predictions_path = os.path.join(predictions_dir, "predictions_analysis.csv")
predictions_df.to_csv(predictions_path, index=False)

print(f"✓ Predictions saved to: {predictions_path}")
print(f"✓ Total predictions: {len(predictions_df)}")
print(
    f"✓ Correct predictions: {predictions_df['correct'].sum()} "
    f"({predictions_df['correct'].mean():.2%})"
)
print("\n" + "🎯 You can now analyze individual predictions in the CSV file!")
print(" Look for patterns in the mistakes your model makes.")