341 lines
12 KiB
Python
341 lines
12 KiB
Python
"""
Demo Evaluation Notebook-style File
====================================

This script demonstrates how to use the evaluation functions from
evaluation.py in a notebook-like, step-by-step style. Run it directly to
produce all of the plots and analysis.

Think of this as the next cell in your notebook after training!
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# Put this script's own directory at the front of the import path so the
# project modules imported below resolve regardless of the current working
# directory the script is launched from.
_this_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _this_dir)
|
|
|
|
# Import our evaluation module
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
# Import other necessary modules
|
|
import torch
|
|
from dataloader import config, create_data_loaders
|
|
from evaluation import (
|
|
analyze_model_performance,
|
|
generate_performance_report,
|
|
plot_confusion_matrix,
|
|
plot_probability_distribution,
|
|
plot_roc_curve,
|
|
plot_training_metrics,
|
|
test_model_on_examples,
|
|
)
|
|
from model import create_similarity_model
|
|
|
|
# Opening banner: one print call, each piece on its own line, so the console
# output reads like the header cell of a notebook.
print(
    "=" * 70,
    "DEMO: EVALUATING IMAGE SIMILARITY MODEL",
    "=" * 70,
    "\nThis demo shows you how to analyze your trained model.",
    "Think of this as the 'results' section of your notebook!\n",
    sep="\n",
)
|
|
|
|
# ============================================================================
# STEP 1: SETUP
# ============================================================================
print("STEP 1: Setting up the environment")
print("-" * 40)

# Pick the accelerator: CUDA when PyTorch reports one is available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"✓ Using device: {device}")

# Work on a (shallow) copy of the shared config imported from dataloader so
# that the key reassignment below does not leak back into it.
config_dict = config.copy()
# A list-valued image size is normalised to a tuple before it is handed to
# the data loader and model factory.
if isinstance(config_dict.get("image_size"), list):
    config_dict["image_size"] = tuple(config_dict["image_size"])

print(f"✓ Image size: {config_dict['image_size']}")
print(f"✓ Batch size: {config_dict['batch_size']}")
|
|
|
|
# ============================================================================
# STEP 2: LOAD DATA
# ============================================================================
print("\nSTEP 2: Loading validation data")
print("-" * 40)

# Build both loaders with augmentation switched off; only the validation
# loader is kept, the training loader is discarded.
loader_settings = {
    "root_dir": config_dict["data_dir"],
    "batch_size": config_dict["batch_size"],
    "train_split": config_dict["train_split"],
    "num_workers": config_dict["num_workers"],
    "image_size": config_dict["image_size"],
    "augment_train": False,
    "augment_val": False,
    "device": device,
}
_, val_loader = create_data_loaders(**loader_settings)

print(f"✓ Validation batches loaded: {len(val_loader)}")
print(f"✓ Each batch has {config_dict['batch_size']} image pairs")
|
|
|
|
# ============================================================================
# STEP 3: LOAD TRAINED MODEL
# ============================================================================
print("\nSTEP 3: Loading the trained model")
print("-" * 40)

# Create model architecture. These hyper-parameters must match the ones the
# checkpoint was trained with, otherwise load_state_dict() (strict by default)
# will reject the weights.
model = create_similarity_model(
    model_type="cnn",
    input_size=config_dict["image_size"][0],
    input_channels=3,
    hidden_channels=64,
    num_blocks=4,
    dropout_rate=0.3,
    use_batch_norm=True,
)

# Try to load the best checkpoint: <output_dir>/checkpoints/best_model.pt
checkpoint_dir = os.path.join(
    config_dict.get("output_dir", "runs/similarity"), "checkpoints"
)
best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")

if os.path.exists(best_checkpoint):
    checkpoint = torch.load(best_checkpoint, map_location=device)
    # Accept both a full training checkpoint (weights under "model_state_dict"
    # plus metadata) and a bare state_dict saved via torch.save(model.state_dict()).
    model.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
    # Metadata keys are optional: report them only when present instead of
    # raising KeyError after the weights have already been loaded.
    if "epoch" in checkpoint:
        print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
    else:
        print("✓ Loaded best model (epoch not recorded in checkpoint)")
    if "val_loss" in checkpoint:
        print(f"✓ Best validation loss: {checkpoint['val_loss']:.4f}")
else:
    print("⚠ Warning: Best model checkpoint not found!")
    print(" Using randomly initialized model for demonstration.")
    print(" (This is normal if you haven't trained the model yet)")

model = model.to(device)
print(f"✓ Model moved to {device}")

# Count parameters (all, and those that would receive gradients).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✓ Total parameters: {total_params:,}")
print(f"✓ Trainable parameters: {trainable_params:,}")
|
|
|
|
# ============================================================================
# STEP 4: PLOT TRAINING METRICS
# ============================================================================
print("\nSTEP 4: Plotting training metrics")
print("-" * 40)
print("This shows how the model learned over time:")

# plot_training_metrics is given the run directory and is expected to show
# four panels: train/val loss, train/val accuracy, an overfitting indicator,
# and the learning-rate schedule.
run_dir = config_dict.get("output_dir", "runs/similarity")
plot_training_metrics(run_dir)

print("✓ Training metrics plotted!")
print(" Look for 'training_metrics.png' in your runs directory")
|
|
|
|
# ============================================================================
# STEP 5: ANALYZE MODEL PERFORMANCE
# ============================================================================
print("\nSTEP 5: Analyzing model performance on validation set")
print("-" * 40)
print("Calculating metrics like accuracy, precision, recall, F1 score...")

# Evaluate on the validation loader with an explicit 0.5 decision threshold.
metrics = analyze_model_performance(model, val_loader, device, threshold=0.5)

print("\n📊 PERFORMANCE METRICS:")
# (line template, metrics key) pairs, printed in this fixed order.
metric_lines = (
    (" Accuracy: {:.2%}", "accuracy"),
    (" Precision: {:.2%}", "precision"),
    (" Recall: {:.2%}", "recall"),
    (" F1 Score: {:.2%}", "f1_score"),
    (" ROC AUC: {:.4f}", "roc_auc"),
)
for template, key in metric_lines:
    print(template.format(metrics[key]))
|
|
|
|
# ============================================================================
# STEP 6: SHOW CONFUSION MATRIX
# ============================================================================
print(
    "\nSTEP 6: Confusion Matrix",
    "-" * 40,
    "This shows how many predictions were correct/wrong:",
    sep="\n",
)

# Visualise the confusion matrix returned by analyze_model_performance in STEP 5.
confusion = metrics["confusion_matrix"]
plot_confusion_matrix(confusion)
|
|
|
|
# ============================================================================
# STEP 7: ROC CURVE
# ============================================================================
print("\nSTEP 7: ROC Curve")
print("-" * 40)
print("This shows how well the model distinguishes between classes:")

# Collect per-pair probabilities and labels over the whole validation set.
model.eval()
all_probabilities = []
all_targets = []

with torch.no_grad():
    for batch in val_loader:
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        # Labels are only consumed on the CPU (numpy/sklearn), so they are not
        # sent to the model's device first.
        target = batch["same_domain"].float().reshape(-1)

        output = model(google_img, yandex_img)
        # reshape(-1) instead of squeeze(): squeeze() turns a final batch of a
        # single pair into a 0-d tensor, and list.extend() on a 0-d ndarray
        # raises TypeError. Assumes one logit per pair, as the original did.
        probabilities = torch.sigmoid(output).reshape(-1)

        all_probabilities.extend(probabilities.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

all_probabilities = np.array(all_probabilities)
all_targets = np.array(all_targets)

from sklearn.metrics import auc, roc_curve

fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
roc_auc = auc(fpr, tpr)

plot_roc_curve(fpr, tpr, roc_auc)
|
|
|
|
# ============================================================================
# STEP 8: PROBABILITY DISTRIBUTION
# ============================================================================
print(
    "\nSTEP 8: Probability Distribution",
    "-" * 40,
    "This shows how confident the model is for different classes:",
    sep="\n",
)

# Reuse the probabilities and labels gathered for the ROC curve in STEP 7.
plot_probability_distribution(all_probabilities, all_targets)
|
|
|
|
# ============================================================================
# STEP 9: TEST ON EXAMPLE IMAGES
# ============================================================================
print(
    "\nSTEP 9: Testing on example images",
    "-" * 40,
    "Let's see how the model performs on some examples:",
    sep="\n",
)

# Delegate to the evaluation helper; its return value is kept for inspection.
test_results = test_model_on_examples(model, device)
|
|
|
|
# ============================================================================
# STEP 10: GENERATE REPORT
# ============================================================================
print("\nSTEP 10: Generating performance report")
print("-" * 40)
print("Creating a detailed report with all metrics...")

final_metrics = generate_performance_report(model, val_loader, device)

print("\n" + "=" * 70)
print("🎉 DEMO COMPLETED SUCCESSFULLY!")
print("=" * 70)

# Report the configured run directory rather than a hard-coded default so the
# message stays correct when config["output_dir"] is overridden (the same
# directory is passed to plot_training_metrics in STEP 4).
run_dir = config_dict.get("output_dir", "runs/similarity")
print("\n📁 What was created:")
print(f" 1. Training metrics plots (saved to {run_dir}/)")
print(" 2. Confusion matrix visualization")
print(" 3. ROC curve plot")
print(" 4. Probability distribution plot")
print(" 5. Performance report (saved to reports/)")

print("\n🔍 Key things to check in your model:")
print(" ✓ Accuracy should be above 70% for a good model")
print(" ✓ Precision: High = few false positives")
print(" ✓ Recall: High = few false negatives")
print(" ✓ ROC AUC: Above 0.8 = good discrimination")

print("\n🔄 If results are poor, try:")
print(" 1. Train for more epochs")
print(" 2. Adjust learning rate")
print(" 3. Use more training data")
print(" 4. Try different model architecture")

# NOTE(review): assumes generate_performance_report returns a dict. The tip is
# skipped, rather than crashing at the very end, if no threshold was reported.
optimal_threshold = final_metrics.get("optimal_threshold")
if optimal_threshold is not None:
    print("\n💡 Pro tip: The optimal threshold is {:.3f}".format(optimal_threshold))
    print(" You can use this instead of 0.5 for better results!")
|
|
|
|
# ============================================================================
# BONUS: QUICK DIAGNOSTICS TABLE
# ============================================================================
print("\n" + "=" * 70)
print("BONUS: Quick Diagnostics Table")
print("=" * 70)

# A simple table of what each STEP 5 metric means and a rule-of-thumb target.
diagnostics = [
    ["Metric", "Value", "What it means", "Is it good?"],
    ["-" * 15, "-" * 10, "-" * 30, "-" * 15],
    ["Accuracy", f"{metrics['accuracy']:.2%}", "Overall correctness", ">70% is good"],
    ["Precision", f"{metrics['precision']:.2%}", "Few false positives", ">70% is good"],
    ["Recall", f"{metrics['recall']:.2%}", "Few false negatives", ">70% is good"],
    [
        "F1 Score",
        f"{metrics['f1_score']:.2%}",
        "Balance of precision/recall",
        ">70% is good",
    ],
    ["ROC AUC", f"{metrics['roc_auc']:.4f}", "Discrimination ability", ">0.8 is good"],
]

# Fixed-width, left-aligned columns.
for row in diagnostics:
    print("{:<15} {:<10} {:<30} {:<15}".format(*row))

print("\n" + "=" * 70)
# Name the script that is actually running instead of a hard-coded filename,
# so the hint stays correct if the file is renamed (__file__ is already relied
# on for the sys.path setup at the top of the script).
print(f"To run this again, just execute: python {os.path.basename(__file__)}")
print("=" * 70)
|
|
|
|
# ============================================================================
# EXTRA: SAVE PREDICTIONS FOR FURTHER ANALYSIS
# ============================================================================
print("\n💾 Saving predictions for further analysis...")

# Get all predictions (0.5 threshold, matching STEP 5).
model.eval()
all_predictions = []
all_targets = []
all_probabilities = []
image_indices = []
# Running position in validation-loader iteration order. The previous
# batch_idx * len(target) formula mis-numbered (and duplicated) the indices of
# a smaller final batch. NOTE(review): this maps to dataset positions only if
# val_loader does not shuffle — confirm in create_data_loaders.
next_index = 0

with torch.no_grad():
    for batch in val_loader:
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        target = batch["same_domain"].float().reshape(-1)

        output = model(google_img, yandex_img)
        # reshape(-1) instead of squeeze(): a final batch of a single pair would
        # otherwise become a 0-d tensor and break list.extend() below.
        probabilities = torch.sigmoid(output).reshape(-1)
        predictions = (probabilities > 0.5).float()

        all_predictions.extend(predictions.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
        all_probabilities.extend(probabilities.cpu().numpy())

        batch_len = len(target)
        image_indices.extend(range(next_index, next_index + batch_len))
        next_index += batch_len

# Save to CSV for further analysis
import pandas as pd

predictions_df = pd.DataFrame(
    {
        "image_pair_index": image_indices,
        "true_label": all_targets,
        "predicted_label": all_predictions,
        "probability": all_probabilities,
        "correct": np.array(all_targets) == np.array(all_predictions),
    }
)

output_dir = config_dict.get("output_dir", "runs/similarity")
# The run directory may not exist yet (e.g. when demoing an untrained model),
# and to_csv does not create parent directories.
os.makedirs(output_dir, exist_ok=True)
predictions_path = os.path.join(output_dir, "predictions_analysis.csv")
predictions_df.to_csv(predictions_path, index=False)

print(f"✓ Predictions saved to: {predictions_path}")
print(f"✓ Total predictions: {len(predictions_df)}")
print(
    f"✓ Correct predictions: {predictions_df['correct'].sum()} ({predictions_df['correct'].mean():.2%})"
)

print("\n" + "🎯 You can now analyze individual predictions in the CSV file!")
print(" Look for patterns in the mistakes your model makes.")
|