664 lines
21 KiB
Python
664 lines
21 KiB
Python
"""
|
|
Evaluation and visualization for image similarity model.
|
|
This file contains code for plotting training metrics, analyzing model performance,
|
|
and testing the trained model.
|
|
"""
|
|
|
|
import os
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
import torch
|
|
import torch.nn as nn
|
|
from dataloader import config, create_data_loaders
|
|
from model import create_similarity_model
|
|
from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve
|
|
from torch.utils.data import DataLoader
|
|
from train import SimilarityTrainer
|
|
|
|
# Set style for plots
# These two calls run at import time and change global plotting defaults:
# the matplotlib style sheet and seaborn's default colour palette.
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("husl")
|
|
|
|
|
|
def plot_training_metrics(log_dir="runs/similarity"):
    """
    Plot training and validation metrics and save the figure into ``log_dir``.

    The curves are simulated (a linear trend plus Gaussian noise); nothing is
    read from TensorBoard logs yet. A 2x2 figure is drawn with loss, accuracy,
    the train/val loss gap and a learning-rate schedule. It is saved as
    ``training_metrics.png`` inside ``log_dir`` and then displayed.

    Args:
        log_dir: Directory the PNG is written to. Created if it is missing.
    """
    # In a real scenario, we would read from TensorBoard logs.
    # For this example, simulated data shows what the plots would look like.
    num_epochs = 50
    epochs = list(range(1, num_epochs + 1))

    # Simulated metrics
    train_loss = [
        0.8 - 0.015 * i + np.random.normal(0, 0.02) for i in range(num_epochs)
    ]
    val_loss = [
        0.75 - 0.012 * i + np.random.normal(0, 0.03) for i in range(num_epochs)
    ]
    train_acc = [
        0.55 + 0.008 * i + np.random.normal(0, 0.01) for i in range(num_epochs)
    ]
    val_acc = [
        0.6 + 0.006 * i + np.random.normal(0, 0.015) for i in range(num_epochs)
    ]
    learning_rates = [0.0002 * (0.95**i) for i in range(num_epochs)]

    # Create figure with subplots
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Training and Validation Loss
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, train_loss, "b-", linewidth=2, label="Training Loss")
    loss_ax.plot(epochs, val_loss, "r-", linewidth=2, label="Validation Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training and Validation Loss")
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Plot 2: Training and Validation Accuracy
    acc_ax = axes[0, 1]
    acc_ax.plot(epochs, train_acc, "b-", linewidth=2, label="Training Accuracy")
    acc_ax.plot(epochs, val_acc, "r-", linewidth=2, label="Validation Accuracy")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.set_title("Training and Validation Accuracy")
    acc_ax.legend()
    acc_ax.grid(True, alpha=0.3)

    # Plot 3: Loss difference (train - val)
    loss_diff = np.array(train_loss) - np.array(val_loss)
    gap_ax = axes[1, 0]
    gap_ax.plot(epochs, loss_diff, "g-", linewidth=2)
    gap_ax.axhline(y=0, color="r", linestyle="--", alpha=0.5)
    # A negative gap means the validation loss sits above the training loss,
    # which is the overfitting signal; a positive gap is the opposite case.
    gap_ax.fill_between(
        epochs,
        0,
        loss_diff,
        where=loss_diff < 0,
        alpha=0.3,
        color="red",
        label="Overfitting (train < val)",
    )
    gap_ax.fill_between(
        epochs,
        0,
        loss_diff,
        where=loss_diff > 0,
        alpha=0.3,
        color="green",
        label="Underfitting (train > val)",
    )
    gap_ax.set_xlabel("Epoch")
    gap_ax.set_ylabel("Loss Difference")
    gap_ax.set_title("Train Loss - Val Loss (Overfitting Indicator)")
    gap_ax.legend()
    gap_ax.grid(True, alpha=0.3)

    # Plot 4: Learning rate schedule (if available)
    # The colour is passed by keyword: a full colour name is only accepted in
    # the format string when it is the entire string, so "purple-" is rejected.
    lr_ax = axes[1, 1]
    lr_ax.plot(epochs, learning_rates, "-", color="purple", linewidth=2)
    lr_ax.set_xlabel("Epoch")
    lr_ax.set_ylabel("Learning Rate")
    lr_ax.set_title("Learning Rate Schedule")
    lr_ax.set_yscale("log")
    lr_ax.grid(True, alpha=0.3)

    fig.tight_layout()

    # savefig does not create missing directories, so make sure it exists.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    output_path = os.path.join(log_dir, "training_metrics.png")
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.show()

    print("Training metrics plots saved to:", output_path)
|
|
|
|
|
|
def analyze_model_performance(model, data_loader, device, threshold=0.5):
    """
    Analyze model performance on a dataset.

    The model is put in eval mode and run over every batch without gradients.
    Each batch is read through the keys ``"google_img"``, ``"yandex_img"`` and
    ``"same_domain"``. The model output is passed through a sigmoid and
    compared against ``threshold`` to obtain binary predictions.

    Args:
        model: Trained model, called as ``model(google_img, yandex_img)``
        data_loader: DataLoader with test/validation data
        device: torch device the image tensors are moved to
        threshold: Decision threshold for binary classification

    Returns:
        Dictionary with the keys ``confusion_matrix`` (2x2, rows are true
        labels), ``accuracy``, ``precision``, ``recall``, ``f1_score``,
        ``roc_auc``, ``optimal_threshold`` (maximises Youden's J),
        ``classification_report`` (text), ``true_negatives``,
        ``false_positives``, ``false_negatives`` and ``true_positives``.
        ``roc_auc`` and ``optimal_threshold`` are NaN when the targets contain
        a single class, because the ROC curve is undefined in that case.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()

    probability_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in data_loader:
            google_img = batch["google_img"].to(device)
            yandex_img = batch["yandex_img"].to(device)
            target = batch["same_domain"].float()

            output = model(google_img, yandex_img)
            # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
            # one into a 0-d tensor, which cannot be accumulated per sample.
            probabilities = torch.sigmoid(output).reshape(-1)

            probability_chunks.append(probabilities.cpu().numpy())
            target_chunks.append(target.reshape(-1).cpu().numpy())

    if not probability_chunks:
        raise ValueError("data_loader produced no batches; nothing to evaluate")

    all_probabilities = np.concatenate(probability_chunks)
    all_targets = np.concatenate(target_chunks)
    all_predictions = (all_probabilities > threshold).astype(all_probabilities.dtype)

    # Pin both labels so the matrix is always 2x2, even when a class is
    # missing from the targets or from the predictions.
    class_labels = [0, 1]
    cm = confusion_matrix(all_targets, all_predictions, labels=class_labels)
    tn, fp, fn, tp = cm.ravel()

    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0
    )

    # Create classification report
    report = classification_report(
        all_targets,
        all_predictions,
        labels=class_labels,
        target_names=["Different", "Same"],
        zero_division=0,
    )

    if np.unique(all_targets).size < 2:
        # With one class present, TPR or FPR is undefined, so is the ROC curve.
        roc_auc = float("nan")
        optimal_threshold = float("nan")
    else:
        fpr, tpr, thresholds = roc_curve(all_targets, all_probabilities)
        roc_auc = auc(fpr, tpr)

        # Find optimal threshold (Youden's J statistic). Index 0 is the
        # artificial "predict nothing positive" point whose threshold is a
        # sentinel rather than a usable cut-off, so it is excluded.
        youden_j = tpr - fpr
        optimal_idx = int(np.argmax(youden_j[1:])) + 1
        optimal_threshold = thresholds[optimal_idx]

    metrics = {
        "confusion_matrix": cm,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "roc_auc": roc_auc,
        "optimal_threshold": optimal_threshold,
        "classification_report": report,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "true_positives": tp,
    }

    return metrics
|
|
|
|
|
|
def plot_confusion_matrix(cm, class_names=None):
    """
    Plot a binary confusion matrix with annotations and print a summary table.

    Args:
        cm: 2x2 confusion matrix of integer counts, laid out as
            ``[[tn, fp], [fn, tp]]`` (rows are true labels).
        class_names: Tick labels for the two classes, negative class first.
            Defaults to ``["Different", "Same"]``.

    Raises:
        ValueError: If ``cm`` is not 2x2.
    """
    # None stands in for the default so no list object is shared across calls.
    if class_names is None:
        class_names = ["Different", "Same"]

    # Validate before drawing so a bad input does not leave a half-built figure.
    cm = np.asarray(cm)
    if cm.shape != (2, 2):
        raise ValueError(
            f"Expected a 2x2 binary confusion matrix, got shape {cm.shape}"
        )
    tn, fp, fn, tp = cm.ravel()

    plt.figure(figsize=(8, 6))

    # Create heatmap
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")

    # Add text annotations below the axes (positions are in axes coordinates)
    annotations = (
        (-0.15, f"True Negatives: {tn}"),
        (-0.20, f"False Positives: {fp}"),
        (-0.25, f"False Negatives: {fn}"),
        (-0.30, f"True Positives: {tp}"),
    )
    axes_transform = plt.gca().transAxes
    for y_offset, text in annotations:
        plt.text(0.5, y_offset, text, ha="center", transform=axes_transform)

    plt.tight_layout()
    plt.show()

    # Create a summary table
    print("\n" + "=" * 50)
    print("CONFUSION MATRIX SUMMARY")
    print("=" * 50)

    summary_data = {
        "Metric": [
            "True Negatives",
            "False Positives",
            "False Negatives",
            "True Positives",
        ],
        "Count": [tn, fp, fn, tp],
        "Description": [
            "Correctly predicted as different",
            "Incorrectly predicted as same (Type I error)",
            "Incorrectly predicted as different (Type II error)",
            "Correctly predicted as same",
        ],
    }

    df = pd.DataFrame(summary_data)
    print(df.to_string(index=False))
    print("=" * 50)
|
|
|
|
|
|
def plot_roc_curve(fpr, tpr, roc_auc):
    """
    Draw the ROC curve against the chance diagonal and print the AUC.

    Args:
        fpr: False positive rates (x values of the curve)
        tpr: True positive rates (y values of the curve)
        roc_auc: Area under the ROC curve, shown in the legend and printed
    """
    plt.figure(figsize=(8, 6))

    curve_label = f"ROC curve (AUC = {roc_auc:.3f})"
    plt.plot(fpr, tpr, color="darkorange", lw=2, label=curve_label)
    # Diagonal reference line, labelled "Random" in the legend
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random")

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC) Curve")
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"ROC AUC Score: {roc_auc:.4f}")
    interpretation_lines = (
        "AUC Interpretation:",
        "0.90-1.00 = Excellent",
        "0.80-0.90 = Good",
        "0.70-0.80 = Fair",
        "0.60-0.70 = Poor",
        "0.50-0.60 = Fail",
    )
    for line in interpretation_lines:
        print(line)
|
|
|
|
|
|
def plot_probability_distribution(all_probabilities, all_targets, threshold=0.5):
    """
    Plot probability distribution for positive and negative classes.

    Draws one density histogram per true class, marks the decision threshold
    with a vertical line, and prints the mean and standard deviation of the
    predicted probabilities for each class.

    Args:
        all_probabilities: Predicted probabilities, one per sample
        all_targets: True labels (1 = same, 0 = different), one per sample
        threshold: Position of the decision-threshold line. Defaults to 0.5.
    """
    probabilities = np.asarray(all_probabilities).ravel()
    targets = np.asarray(all_targets).ravel()
    # Pair the inputs up to the shorter length, as zip() would.
    sample_count = min(probabilities.size, targets.size)
    probabilities = probabilities[:sample_count]
    targets = targets[:sample_count]

    # Separate probabilities by true class
    pos_probs = probabilities[targets == 1]
    neg_probs = probabilities[targets == 0]

    plt.figure(figsize=(10, 6))

    # Plot histograms. An empty class is skipped: a density-normalised
    # histogram of no samples only produces NaN and warnings.
    histogram_specs = (
        (pos_probs, "green", "Same Domain (Positive)"),
        (neg_probs, "red", "Different Domain (Negative)"),
    )
    for values, color, label in histogram_specs:
        if values.size:
            plt.hist(
                values,
                bins=30,
                alpha=0.5,
                color=color,
                label=label,
                density=True,
            )

    # Add vertical line at the decision threshold
    plt.axvline(
        x=threshold,
        color="black",
        linestyle="--",
        linewidth=2,
        label=f"Decision Threshold ({threshold})",
    )

    plt.xlabel("Predicted Probability")
    plt.ylabel("Density")
    plt.title("Probability Distribution by True Class")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print statistics
    print("\nProbability Statistics:")
    statistic_specs = (
        (pos_probs, "Positive class (Same)"),
        (neg_probs, "Negative class (Different)"),
    )
    for values, title in statistic_specs:
        if values.size:
            print(
                f"{title}: Mean = {np.mean(values):.3f}, Std = {np.std(values):.3f}"
            )
        else:
            print(f"{title}: no samples")
|
|
|
|
|
|
def test_model_on_examples(model, device, examples_dir="examples"):
    """
    Test model on example image pairs.

    When ``examples_dir`` does not exist, two random-tensor example pairs are
    generated for demonstration. When it does exist, no examples are loaded
    yet (image loading is not implemented), so an empty result table is
    returned instead of failing.

    Args:
        model: Trained model, called as ``model(google_img, yandex_img)``
        device: torch device the example tensors are moved to
        examples_dir: Directory containing example image pairs

    Returns:
        pandas DataFrame with one row per example and the columns
        ``Example``, ``Predicted``, ``Probability``, ``Expected``, ``Correct``.
    """
    result_columns = ["Example", "Predicted", "Probability", "Expected", "Correct"]

    model.eval()

    # Check if examples directory exists
    if not os.path.exists(examples_dir):
        print(f"Examples directory '{examples_dir}' not found.")
        print("Creating dummy examples for demonstration...")

        # Create dummy example data for demonstration
        examples = [
            {
                "name": "Example 1: Similar locations",
                "google_img": torch.randn(1, 3, 224, 224),
                "yandex_img": torch.randn(1, 3, 224, 224),
                "expected": "Same",
            },
            {
                "name": "Example 2: Different locations",
                "google_img": torch.randn(1, 3, 224, 224),
                "yandex_img": torch.randn(1, 3, 224, 224) * 2,
                "expected": "Different",
            },
        ]
    else:
        # In real implementation, load actual images
        examples = []

    print("\n" + "=" * 60)
    print("MODEL TESTING ON EXAMPLES")
    print("=" * 60)

    results = []

    for example in examples:
        with torch.no_grad():
            google_img = example["google_img"].to(device)
            yandex_img = example["yandex_img"].to(device)

            output = model(google_img, yandex_img)
            probability = torch.sigmoid(output).item()
            prediction = "Same" if probability > 0.5 else "Different"

        expected = example.get("expected", "Unknown")
        result = {
            "Example": example["name"],
            "Predicted": prediction,
            "Probability": probability,
            "Expected": expected,
            "Correct": prediction == expected,
        }
        results.append(result)

        print(f"\n{example['name']}:")
        print(f" Predicted: {prediction} (probability: {probability:.4f})")
        print(f" Expected: {expected}")
        print(f" Result: {'✓ CORRECT' if result['Correct'] else '✗ WRONG'}")

    # Create results table
    print("\n" + "=" * 60)
    print("SUMMARY OF TEST RESULTS")
    print("=" * 60)

    # Naming the columns keeps them present even when there are no rows.
    df_results = pd.DataFrame(results, columns=result_columns)

    if df_results.empty:
        # Without this guard the accuracy computation below has nothing to average.
        print("No examples were evaluated.")
        return df_results

    print(df_results.to_string(index=False))

    accuracy = df_results["Correct"].mean() * 100
    print(f"\nTest Accuracy: {accuracy:.1f}%")

    return df_results


# The name starts with "test_", so pytest would try to collect this helper as
# a test whenever it is imported into a test module; opt it out explicitly.
test_model_on_examples.__test__ = False
|
|
|
|
|
|
def generate_performance_report(model, data_loader, device, output_dir="reports"):
    """
    Evaluate the model and write a plain-text performance report.

    The report is written to ``model_performance_report.txt`` inside
    ``output_dir`` (created if needed). It has five sections: basic metrics,
    confusion-matrix counts, the classification report, a short explanation of
    each metric, and threshold-based recommendations.

    Args:
        model: Trained model
        data_loader: DataLoader with test data
        device: torch device
        output_dir: Directory to save reports

    Returns:
        The metrics dictionary produced by ``analyze_model_performance``.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Generating performance report...")

    # Analyze performance
    metrics = analyze_model_performance(model, data_loader, device)

    report_path = os.path.join(output_dir, "model_performance_report.txt")

    banner = "=" * 60 + "\n"
    divider = "-" * 40 + "\n"

    with open(report_path, "w") as f:
        f.writelines([banner, "MODEL PERFORMANCE REPORT\n", banner, "\n"])

        f.writelines(
            [
                "1. BASIC METRICS\n",
                divider,
                f"Accuracy: {metrics['accuracy']:.4f}\n",
                f"Precision: {metrics['precision']:.4f}\n",
                f"Recall: {metrics['recall']:.4f}\n",
                f"F1 Score: {metrics['f1_score']:.4f}\n",
                f"ROC AUC: {metrics['roc_auc']:.4f}\n",
                f"Optimal Threshold: {metrics['optimal_threshold']:.4f}\n\n",
            ]
        )

        f.writelines(
            [
                "2. CONFUSION MATRIX\n",
                divider,
                f"True Negatives: {metrics['true_negatives']}\n",
                f"False Positives: {metrics['false_positives']}\n",
                f"False Negatives: {metrics['false_negatives']}\n",
                f"True Positives: {metrics['true_positives']}\n\n",
            ]
        )

        f.writelines(
            [
                "3. CLASSIFICATION REPORT\n",
                divider,
                metrics["classification_report"] + "\n",
            ]
        )

        f.writelines(
            [
                "4. INTERPRETATION\n",
                divider,
                "Accuracy: Proportion of correct predictions\n",
                "Precision: Proportion of positive predictions that are correct\n",
                "Recall: Proportion of actual positives that are correctly identified\n",
                "F1 Score: Harmonic mean of precision and recall\n",
                "ROC AUC: Ability to distinguish between classes\n\n",
            ]
        )

        # Each recommendation is emitted only when its metric misses the cut-off.
        recommendations = ["5. RECOMMENDATIONS\n", divider]
        if metrics["precision"] < 0.7:
            recommendations.append("- Improve precision to reduce false positives\n")
        if metrics["recall"] < 0.7:
            recommendations.append("- Improve recall to reduce false negatives\n")
        if metrics["f1_score"] < 0.7:
            recommendations.append("- Overall model performance needs improvement\n")
        if metrics["roc_auc"] > 0.8:
            recommendations.append("- Good discrimination ability between classes\n")
        else:
            recommendations.append("- Consider improving feature extraction\n")
        f.writelines(recommendations)

    print(f"Report saved to: {report_path}")

    return metrics
|
|
|
|
|
|
def _collect_probabilities(model, data_loader, device):
    """Run the model over the loader without gradients and return flat
    ``(probabilities, targets)`` NumPy arrays, one entry per sample."""
    model.eval()
    probability_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in data_loader:
            google_img = batch["google_img"].to(device)
            yandex_img = batch["yandex_img"].to(device)
            target = batch["same_domain"].float()

            output = model(google_img, yandex_img)
            # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
            # one into a 0-d tensor, which cannot be accumulated per sample.
            probabilities = torch.sigmoid(output).reshape(-1)

            probability_chunks.append(probabilities.cpu().numpy())
            target_chunks.append(target.reshape(-1).cpu().numpy())

    if not probability_chunks:
        return np.array([]), np.array([])
    return np.concatenate(probability_chunks), np.concatenate(target_chunks)


def main():
    """
    Main function to run evaluation and generate reports.

    Steps: build the validation loader from the shared config, build the CNN
    similarity model and load ``best_model.pt`` if present (otherwise the
    randomly initialised model is used), then plot training metrics, print
    performance metrics, plot the confusion matrix, ROC curve and probability
    distribution, run the example test and write the text report.
    """
    print("=" * 60)
    print("IMAGE SIMILARITY MODEL EVALUATION")
    print("=" * 60)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load configuration
    config_dict = config.copy()
    if isinstance(config_dict.get("image_size"), list):
        config_dict["image_size"] = tuple(config_dict["image_size"])
    output_dir = config_dict.get("output_dir", "runs/similarity")

    # Create data loaders
    print("\n1. Creating data loaders...")
    _, val_loader = create_data_loaders(
        root_dir=config_dict["data_dir"],
        batch_size=config_dict["batch_size"],
        train_split=config_dict["train_split"],
        num_workers=config_dict["num_workers"],
        image_size=config_dict["image_size"],
        augment_train=False,
        augment_val=False,
        device=device,
    )

    print(f"Validation batches: {len(val_loader)}")

    # Load trained model
    print("\n2. Loading trained model...")
    image_size = config_dict["image_size"]
    model = create_similarity_model(
        model_type="cnn",
        input_size=image_size[0]
        if isinstance(image_size, (tuple, list))
        else image_size,
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )

    # Load best checkpoint
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")

    if os.path.exists(best_checkpoint):
        checkpoint = torch.load(best_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        print(f"Loaded best model from epoch {checkpoint['epoch']}")
        print(f"Best validation loss: {checkpoint['val_loss']:.4f}")
    else:
        print("Warning: Best model checkpoint not found!")
        print("Using randomly initialized model for demonstration.")

    model = model.to(device)

    # Plot training metrics
    print("\n3. Plotting training metrics...")
    # The figure is saved into output_dir, which may not exist when no
    # training run has created it yet.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plot_training_metrics(output_dir)

    # Analyze model performance
    print("\n4. Analyzing model performance...")
    metrics = analyze_model_performance(model, val_loader, device)

    # Display results
    print("\n" + "=" * 60)
    print("PERFORMANCE METRICS")
    print("=" * 60)
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1_score']:.4f}")
    print(f"ROC AUC: {metrics['roc_auc']:.4f}")
    print(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}")

    # Plot confusion matrix
    print("\n5. Plotting confusion matrix...")
    plot_confusion_matrix(metrics["confusion_matrix"])

    # Plot ROC curve
    print("\n6. Plotting ROC curve...")
    # The metrics dictionary does not carry per-sample probabilities, so they
    # are collected again here for the ROC and distribution plots.
    all_probabilities, all_targets = _collect_probabilities(
        model, val_loader, device
    )

    fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
    roc_auc = auc(fpr, tpr)
    plot_roc_curve(fpr, tpr, roc_auc)

    # Plot probability distribution
    print("\n7. Plotting probability distribution...")
    plot_probability_distribution(all_probabilities, all_targets)

    # Test on examples
    print("\n8. Testing on examples...")
    test_model_on_examples(model, device)

    # Generate comprehensive report
    print("\n9. Generating performance report...")
    generate_performance_report(model, val_loader, device)

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Check the generated plots in the runs/similarity directory")
    print("2. Review the performance report in the reports directory")
    print("3. Consider adjusting the decision threshold if needed")
    print("4. Retrain with different hyperparameters if performance is poor")
|
|
|
|
|
|
# Run the full evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|