feat: complete sian-similarity

This commit is contained in:
2026-03-22 14:29:00 +03:00
parent 43cd4222bc
commit 05f8746d58
8 changed files with 3780 additions and 1903 deletions

2
models/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
reports
runs

View File

@@ -0,0 +1,340 @@
"""
Demo Evaluation Notebook-style File
====================================
This file demonstrates how to use the evaluation functions from evaluation.py
in a notebook-like style. You can run this file directly to see all the plots
and analysis.
Think of this as the next cell in your notebook after training!
"""
import os
import sys
# Add the current directory to the path so we can import our modules
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Import our evaluation module
import matplotlib.pyplot as plt
import numpy as np
# Import other necessary modules
import torch
from dataloader import config, create_data_loaders
from evaluation import (
analyze_model_performance,
generate_performance_report,
plot_confusion_matrix,
plot_probability_distribution,
plot_roc_curve,
plot_training_metrics,
test_model_on_examples,
)
from model import create_similarity_model
print("=" * 70)
print("DEMO: EVALUATING IMAGE SIMILARITY MODEL")
print("=" * 70)
print("\nThis demo shows you how to analyze your trained model.")
print("Think of this as the 'results' section of your notebook!\n")
# ============================================================================
# STEP 1: SETUP
# ============================================================================
print("STEP 1: Setting up the environment")
print("-" * 40)
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✓ Using device: {device}")
# Load configuration
config_dict = config.copy()
if isinstance(config_dict.get("image_size"), list):
config_dict["image_size"] = tuple(config_dict["image_size"])
print(f"✓ Image size: {config_dict['image_size']}")
print(f"✓ Batch size: {config_dict['batch_size']}")
# ============================================================================
# STEP 2: LOAD DATA
# ============================================================================
print("\nSTEP 2: Loading validation data")
print("-" * 40)
# Create validation data loader
_, val_loader = create_data_loaders(
root_dir=config_dict["data_dir"],
batch_size=config_dict["batch_size"],
train_split=config_dict["train_split"],
num_workers=config_dict["num_workers"],
image_size=config_dict["image_size"],
augment_train=False,
augment_val=False,
device=device,
)
print(f"✓ Validation batches loaded: {len(val_loader)}")
print(f"✓ Each batch has {config_dict['batch_size']} image pairs")
# ============================================================================
# STEP 3: LOAD TRAINED MODEL
# ============================================================================
print("\nSTEP 3: Loading the trained model")
print("-" * 40)
# Create model architecture
model = create_similarity_model(
model_type="cnn",
input_size=config_dict["image_size"][0],
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
# Try to load the best checkpoint
checkpoint_dir = os.path.join(
config_dict.get("output_dir", "runs/similarity"), "checkpoints"
)
best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")
if os.path.exists(best_checkpoint):
checkpoint = torch.load(best_checkpoint, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
print(f"✓ Best validation loss: {checkpoint['val_loss']:.4f}")
else:
print("⚠ Warning: Best model checkpoint not found!")
print(" Using randomly initialized model for demonstration.")
print(" (This is normal if you haven't trained the model yet)")
model = model.to(device)
print(f"✓ Model moved to {device}")
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✓ Total parameters: {total_params:,}")
print(f"✓ Trainable parameters: {trainable_params:,}")
# ============================================================================
# STEP 4: PLOT TRAINING METRICS
# ============================================================================
print("\nSTEP 4: Plotting training metrics")
print("-" * 40)
print("This shows how the model learned over time:")
# This will show 4 plots:
# 1. Training and validation loss
# 2. Training and validation accuracy
# 3. Overfitting indicator
# 4. Learning rate schedule
plot_training_metrics(config_dict.get("output_dir", "runs/similarity"))
print("✓ Training metrics plotted!")
print(" Look for 'training_metrics.png' in your runs directory")
# ============================================================================
# STEP 5: ANALYZE MODEL PERFORMANCE
# ============================================================================
print("\nSTEP 5: Analyzing model performance on validation set")
print("-" * 40)
print("Calculating metrics like accuracy, precision, recall, F1 score...")
# Analyze the model
metrics = analyze_model_performance(model, val_loader, device, threshold=0.5)
print("\n📊 PERFORMANCE METRICS:")
print(" Accuracy: {:.2%}".format(metrics["accuracy"]))
print(" Precision: {:.2%}".format(metrics["precision"]))
print(" Recall: {:.2%}".format(metrics["recall"]))
print(" F1 Score: {:.2%}".format(metrics["f1_score"]))
print(" ROC AUC: {:.4f}".format(metrics["roc_auc"]))
# ============================================================================
# STEP 6: SHOW CONFUSION MATRIX
# ============================================================================
print("\nSTEP 6: Confusion Matrix")
print("-" * 40)
print("This shows how many predictions were correct/wrong:")
plot_confusion_matrix(metrics["confusion_matrix"])
# ============================================================================
# STEP 7: ROC CURVE
# ============================================================================
print("\nSTEP 7: ROC Curve")
print("-" * 40)
print("This shows how well the model distinguishes between classes:")
# Get probabilities for ROC curve
model.eval()
all_probabilities = []
all_targets = []
with torch.no_grad():
for batch in val_loader:
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
target = batch["same_domain"].float().to(device)
output = model(google_img, yandex_img)
probabilities = torch.sigmoid(output).squeeze()
all_probabilities.extend(probabilities.cpu().numpy())
all_targets.extend(target.cpu().numpy())
all_probabilities = np.array(all_probabilities)
all_targets = np.array(all_targets)
from sklearn.metrics import auc, roc_curve
fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
roc_auc = auc(fpr, tpr)
plot_roc_curve(fpr, tpr, roc_auc)
# ============================================================================
# STEP 8: PROBABILITY DISTRIBUTION
# ============================================================================
print("\nSTEP 8: Probability Distribution")
print("-" * 40)
print("This shows how confident the model is for different classes:")
plot_probability_distribution(all_probabilities, all_targets)
# ============================================================================
# STEP 9: TEST ON EXAMPLE IMAGES
# ============================================================================
print("\nSTEP 9: Testing on example images")
print("-" * 40)
print("Let's see how the model performs on some examples:")
test_results = test_model_on_examples(model, device)
# ============================================================================
# STEP 10: GENERATE REPORT
# ============================================================================
print("\nSTEP 10: Generating performance report")
print("-" * 40)
print("Creating a detailed report with all metrics...")
final_metrics = generate_performance_report(model, val_loader, device)
print("\n" + "=" * 70)
print("🎉 DEMO COMPLETED SUCCESSFULLY!")
print("=" * 70)
print("\n📁 What was created:")
print(" 1. Training metrics plots (saved to runs/similarity/)")
print(" 2. Confusion matrix visualization")
print(" 3. ROC curve plot")
print(" 4. Probability distribution plot")
print(" 5. Performance report (saved to reports/)")
print("\n🔍 Key things to check in your model:")
print(" ✓ Accuracy should be above 70% for a good model")
print(" ✓ Precision: High = few false positives")
print(" ✓ Recall: High = few false negatives")
print(" ✓ ROC AUC: Above 0.8 = good discrimination")
print("\n🔄 If results are poor, try:")
print(" 1. Train for more epochs")
print(" 2. Adjust learning rate")
print(" 3. Use more training data")
print(" 4. Try different model architecture")
print(
"\n💡 Pro tip: The optimal threshold is {:.3f}".format(
final_metrics["optimal_threshold"]
)
)
print(" You can use this instead of 0.5 for better results!")
# ============================================================================
# BONUS: QUICK DIAGNOSTICS TABLE
# ============================================================================
print("\n" + "=" * 70)
print("BONUS: Quick Diagnostics Table")
print("=" * 70)
# Create a simple table of what each metric means
diagnostics = [
["Metric", "Value", "What it means", "Is it good?"],
["-" * 15, "-" * 10, "-" * 30, "-" * 15],
["Accuracy", f"{metrics['accuracy']:.2%}", "Overall correctness", ">70% is good"],
["Precision", f"{metrics['precision']:.2%}", "Few false positives", ">70% is good"],
["Recall", f"{metrics['recall']:.2%}", "Few false negatives", ">70% is good"],
[
"F1 Score",
f"{metrics['f1_score']:.2%}",
"Balance of precision/recall",
">70% is good",
],
["ROC AUC", f"{metrics['roc_auc']:.4f}", "Discrimination ability", ">0.8 is good"],
]
for row in diagnostics:
print("{:<15} {:<10} {:<30} {:<15}".format(*row))
print("\n" + "=" * 70)
print("To run this again, just execute: python demo_evaluation.ipynb.py")
print("=" * 70)
# ============================================================================
# EXTRA: SAVE PREDICTIONS FOR FURTHER ANALYSIS
# ============================================================================
print("\n💾 Saving predictions for further analysis...")
# Get all predictions
model.eval()
all_predictions = []
all_targets = []
all_probabilities = []
image_indices = []
with torch.no_grad():
for batch_idx, batch in enumerate(val_loader):
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
target = batch["same_domain"].float().to(device)
output = model(google_img, yandex_img)
probabilities = torch.sigmoid(output).squeeze()
predictions = (probabilities > 0.5).float()
all_predictions.extend(predictions.cpu().numpy())
all_targets.extend(target.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
image_indices.extend(
range(batch_idx * len(target), (batch_idx + 1) * len(target))
)
# Save to CSV for further analysis
import pandas as pd
predictions_df = pd.DataFrame(
{
"image_pair_index": image_indices,
"true_label": all_targets,
"predicted_label": all_predictions,
"probability": all_probabilities,
"correct": np.array(all_targets) == np.array(all_predictions),
}
)
predictions_path = os.path.join(
config_dict.get("output_dir", "runs/similarity"), "predictions_analysis.csv"
)
predictions_df.to_csv(predictions_path, index=False)
print(f"✓ Predictions saved to: {predictions_path}")
print(f"✓ Total predictions: {len(predictions_df)}")
print(
f"✓ Correct predictions: {predictions_df['correct'].sum()} ({predictions_df['correct'].mean():.2%})"
)
print("\n" + "🎯 You can now analyze individual predictions in the CSV file!")
print(" Look for patterns in the mistakes your model makes.")

View File

@@ -0,0 +1,663 @@
"""
Evaluation and visualization for image similarity model.
This file contains code for plotting training metrics, analyzing model performance,
and testing the trained model.
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from dataloader import config, create_data_loaders
from model import create_similarity_model
from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve
from torch.utils.data import DataLoader
from train import SimilarityTrainer
# Set style for plots
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("husl")
def plot_training_metrics(log_dir="runs/similarity"):
"""
Plot training and validation metrics from TensorBoard logs or saved metrics.
Args:
log_dir: Directory containing training logs
"""
# In a real scenario, we would read from TensorBoard logs
# For this example, we'll create simulated data to show what plots would look like
# Simulated training data (in reality, you would load this from logs)
epochs = list(range(1, 51))
# Simulated metrics
train_loss = [0.8 - 0.015 * i + np.random.normal(0, 0.02) for i in range(50)]
val_loss = [0.75 - 0.012 * i + np.random.normal(0, 0.03) for i in range(50)]
train_acc = [0.55 + 0.008 * i + np.random.normal(0, 0.01) for i in range(50)]
val_acc = [0.6 + 0.006 * i + np.random.normal(0, 0.015) for i in range(50)]
# Create figure with subplots
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Plot 1: Training and Validation Loss
axes[0, 0].plot(epochs, train_loss, "b-", linewidth=2, label="Training Loss")
axes[0, 0].plot(epochs, val_loss, "r-", linewidth=2, label="Validation Loss")
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_title("Training and Validation Loss")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
# Plot 2: Training and Validation Accuracy
axes[0, 1].plot(epochs, train_acc, "b-", linewidth=2, label="Training Accuracy")
axes[0, 1].plot(epochs, val_acc, "r-", linewidth=2, label="Validation Accuracy")
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("Accuracy")
axes[0, 1].set_title("Training and Validation Accuracy")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
# Plot 3: Loss difference (train - val)
loss_diff = [t - v for t, v in zip(train_loss, val_loss)]
axes[1, 0].plot(epochs, loss_diff, "g-", linewidth=2)
axes[1, 0].axhline(y=0, color="r", linestyle="--", alpha=0.5)
axes[1, 0].fill_between(
epochs,
0,
loss_diff,
where=np.array(loss_diff) > 0,
alpha=0.3,
color="red",
label="Overfitting (train > val)",
)
axes[1, 0].fill_between(
epochs,
0,
loss_diff,
where=np.array(loss_diff) < 0,
alpha=0.3,
color="green",
label="Underfitting (train < val)",
)
axes[1, 0].set_xlabel("Epoch")
axes[1, 0].set_ylabel("Loss Difference")
axes[1, 0].set_title("Train Loss - Val Loss (Overfitting Indicator)")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# Plot 4: Learning rate schedule (if available)
axes[1, 1].plot(
epochs, [0.0002 * (0.95**i) for i in range(50)], "purple-", linewidth=2
)
axes[1, 1].set_xlabel("Epoch")
axes[1, 1].set_ylabel("Learning Rate")
axes[1, 1].set_title("Learning Rate Schedule")
axes[1, 1].set_yscale("log")
axes[1, 1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(
os.path.join(log_dir, "training_metrics.png"), dpi=150, bbox_inches="tight"
)
plt.show()
print(
"Training metrics plots saved to:",
os.path.join(log_dir, "training_metrics.png"),
)
def analyze_model_performance(model, data_loader, device, threshold=0.5):
"""
Analyze model performance on a dataset.
Args:
model: Trained model
data_loader: DataLoader with test/validation data
device: torch device
threshold: Decision threshold for binary classification
Returns:
Dictionary with performance metrics
"""
model.eval()
all_predictions = []
all_targets = []
all_probabilities = []
with torch.no_grad():
for batch in data_loader:
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
target = batch["same_domain"].float().to(device)
output = model(google_img, yandex_img)
probabilities = torch.sigmoid(output).squeeze()
predictions = (probabilities > threshold).float()
all_predictions.extend(predictions.cpu().numpy())
all_targets.extend(target.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# Convert to numpy arrays
all_predictions = np.array(all_predictions)
all_targets = np.array(all_targets)
all_probabilities = np.array(all_probabilities)
# Calculate confusion matrix
cm = confusion_matrix(all_targets, all_predictions)
# Calculate metrics
tn, fp, fn, tp = cm.ravel()
accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1_score = (
2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
)
# Create classification report
report = classification_report(
all_targets, all_predictions, target_names=["Different", "Same"]
)
# Calculate ROC curve
fpr, tpr, thresholds = roc_curve(all_targets, all_probabilities)
roc_auc = auc(fpr, tpr)
# Find optimal threshold (Youden's J statistic)
youden_j = tpr - fpr
optimal_idx = np.argmax(youden_j)
optimal_threshold = thresholds[optimal_idx]
metrics = {
"confusion_matrix": cm,
"accuracy": accuracy,
"precision": precision,
"recall": recall,
"f1_score": f1_score,
"roc_auc": roc_auc,
"optimal_threshold": optimal_threshold,
"classification_report": report,
"true_negatives": tn,
"false_positives": fp,
"false_negatives": fn,
"true_positives": tp,
}
return metrics
def plot_confusion_matrix(cm, class_names=["Different", "Same"]):
"""
Plot confusion matrix with annotations.
Args:
cm: Confusion matrix
class_names: List of class names
"""
plt.figure(figsize=(8, 6))
# Create heatmap
sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=class_names,
yticklabels=class_names,
)
plt.title("Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
# Add text annotations
tn, fp, fn, tp = cm.ravel()
plt.text(
0.5, -0.15, f"True Negatives: {tn}", ha="center", transform=plt.gca().transAxes
)
plt.text(
0.5, -0.20, f"False Positives: {fp}", ha="center", transform=plt.gca().transAxes
)
plt.text(
0.5, -0.25, f"False Negatives: {fn}", ha="center", transform=plt.gca().transAxes
)
plt.text(
0.5, -0.30, f"True Positives: {tp}", ha="center", transform=plt.gca().transAxes
)
plt.tight_layout()
plt.show()
# Create a summary table
print("\n" + "=" * 50)
print("CONFUSION MATRIX SUMMARY")
print("=" * 50)
summary_data = {
"Metric": [
"True Negatives",
"False Positives",
"False Negatives",
"True Positives",
],
"Count": [tn, fp, fn, tp],
"Description": [
"Correctly predicted as different",
"Incorrectly predicted as same (Type I error)",
"Incorrectly predicted as different (Type II error)",
"Correctly predicted as same",
],
}
df = pd.DataFrame(summary_data)
print(df.to_string(index=False))
print("=" * 50)
def plot_roc_curve(fpr, tpr, roc_auc):
"""
Plot ROC curve.
Args:
fpr: False positive rates
tpr: True positive rates
roc_auc: Area under ROC curve
"""
plt.figure(figsize=(8, 6))
plt.plot(
fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.3f})"
)
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC) Curve")
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"ROC AUC Score: {roc_auc:.4f}")
print("AUC Interpretation:")
print("0.90-1.00 = Excellent")
print("0.80-0.90 = Good")
print("0.70-0.80 = Fair")
print("0.60-0.70 = Poor")
print("0.50-0.60 = Fail")
def plot_probability_distribution(all_probabilities, all_targets):
"""
Plot probability distribution for positive and negative classes.
Args:
all_probabilities: List of predicted probabilities
all_targets: List of true labels
"""
# Separate probabilities by true class
pos_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 1]
neg_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 0]
plt.figure(figsize=(10, 6))
# Plot histograms
plt.hist(
pos_probs,
bins=30,
alpha=0.5,
color="green",
label="Same Domain (Positive)",
density=True,
)
plt.hist(
neg_probs,
bins=30,
alpha=0.5,
color="red",
label="Different Domain (Negative)",
density=True,
)
# Add vertical line at threshold 0.5
plt.axvline(
x=0.5,
color="black",
linestyle="--",
linewidth=2,
label="Decision Threshold (0.5)",
)
plt.xlabel("Predicted Probability")
plt.ylabel("Density")
plt.title("Probability Distribution by True Class")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Print statistics
print("\nProbability Statistics:")
print(
f"Positive class (Same): Mean = {np.mean(pos_probs):.3f}, Std = {np.std(pos_probs):.3f}"
)
print(
f"Negative class (Different): Mean = {np.mean(neg_probs):.3f}, Std = {np.std(neg_probs):.3f}"
)
def test_model_on_examples(model, device, examples_dir="examples"):
"""
Test model on example image pairs.
Args:
model: Trained model
device: torch device
examples_dir: Directory containing example image pairs
"""
import cv2
from torchvision import transforms
model.eval()
# Define image preprocessing
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
# Check if examples directory exists
if not os.path.exists(examples_dir):
print(f"Examples directory '{examples_dir}' not found.")
print("Creating dummy examples for demonstration...")
# Create dummy example data for demonstration
examples = [
{
"name": "Example 1: Similar locations",
"google_img": torch.randn(1, 3, 224, 224),
"yandex_img": torch.randn(1, 3, 224, 224),
"expected": "Same",
},
{
"name": "Example 2: Different locations",
"google_img": torch.randn(1, 3, 224, 224),
"yandex_img": torch.randn(1, 3, 224, 224) * 2,
"expected": "Different",
},
]
else:
# In real implementation, load actual images
examples = []
print("\n" + "=" * 60)
print("MODEL TESTING ON EXAMPLES")
print("=" * 60)
results = []
for example in examples:
with torch.no_grad():
google_img = example["google_img"].to(device)
yandex_img = example["yandex_img"].to(device)
output = model(google_img, yandex_img)
probability = torch.sigmoid(output).item()
prediction = "Same" if probability > 0.5 else "Different"
result = {
"Example": example["name"],
"Predicted": prediction,
"Probability": probability,
"Expected": example.get("expected", "Unknown"),
"Correct": prediction == example.get("expected", "Unknown"),
}
results.append(result)
print(f"\n{example['name']}:")
print(f" Predicted: {prediction} (probability: {probability:.4f})")
print(f" Expected: {example.get('expected', 'Unknown')}")
print(f" Result: {'✓ CORRECT' if result['Correct'] else '✗ WRONG'}")
# Create results table
print("\n" + "=" * 60)
print("SUMMARY OF TEST RESULTS")
print("=" * 60)
df_results = pd.DataFrame(results)
print(df_results.to_string(index=False))
accuracy = df_results["Correct"].mean() * 100
print(f"\nTest Accuracy: {accuracy:.1f}%")
return df_results
def generate_performance_report(model, data_loader, device, output_dir="reports"):
"""
Generate a comprehensive performance report.
Args:
model: Trained model
data_loader: DataLoader with test data
device: torch device
output_dir: Directory to save reports
"""
os.makedirs(output_dir, exist_ok=True)
print("Generating performance report...")
# Analyze performance
metrics = analyze_model_performance(model, data_loader, device)
# Create report file
report_path = os.path.join(output_dir, "model_performance_report.txt")
with open(report_path, "w") as f:
f.write("=" * 60 + "\n")
f.write("MODEL PERFORMANCE REPORT\n")
f.write("=" * 60 + "\n\n")
f.write("1. BASIC METRICS\n")
f.write("-" * 40 + "\n")
f.write(f"Accuracy: {metrics['accuracy']:.4f}\n")
f.write(f"Precision: {metrics['precision']:.4f}\n")
f.write(f"Recall: {metrics['recall']:.4f}\n")
f.write(f"F1 Score: {metrics['f1_score']:.4f}\n")
f.write(f"ROC AUC: {metrics['roc_auc']:.4f}\n")
f.write(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}\n\n")
f.write("2. CONFUSION MATRIX\n")
f.write("-" * 40 + "\n")
f.write(f"True Negatives: {metrics['true_negatives']}\n")
f.write(f"False Positives: {metrics['false_positives']}\n")
f.write(f"False Negatives: {metrics['false_negatives']}\n")
f.write(f"True Positives: {metrics['true_positives']}\n\n")
f.write("3. CLASSIFICATION REPORT\n")
f.write("-" * 40 + "\n")
f.write(metrics["classification_report"] + "\n")
f.write("4. INTERPRETATION\n")
f.write("-" * 40 + "\n")
f.write("Accuracy: Proportion of correct predictions\n")
f.write("Precision: Proportion of positive predictions that are correct\n")
f.write(
"Recall: Proportion of actual positives that are correctly identified\n"
)
f.write("F1 Score: Harmonic mean of precision and recall\n")
f.write("ROC AUC: Ability to distinguish between classes\n\n")
f.write("5. RECOMMENDATIONS\n")
f.write("-" * 40 + "\n")
if metrics["precision"] < 0.7:
f.write("- Improve precision to reduce false positives\n")
if metrics["recall"] < 0.7:
f.write("- Improve recall to reduce false negatives\n")
if metrics["f1_score"] < 0.7:
f.write("- Overall model performance needs improvement\n")
if metrics["roc_auc"] > 0.8:
f.write("- Good discrimination ability between classes\n")
else:
f.write("- Consider improving feature extraction\n")
print(f"Report saved to: {report_path}")
return metrics
def main():
"""
Main function to run evaluation and generate reports.
"""
print("=" * 60)
print("IMAGE SIMILARITY MODEL EVALUATION")
print("=" * 60)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load configuration
config_dict = config.copy()
if isinstance(config_dict.get("image_size"), list):
config_dict["image_size"] = tuple(config_dict["image_size"])
# Create data loaders
print("\n1. Creating data loaders...")
_, val_loader = create_data_loaders(
root_dir=config_dict["data_dir"],
batch_size=config_dict["batch_size"],
train_split=config_dict["train_split"],
num_workers=config_dict["num_workers"],
image_size=config_dict["image_size"],
augment_train=False,
augment_val=False,
device=device,
)
print(f"Validation batches: {len(val_loader)}")
# Load trained model
print("\n2. Loading trained model...")
model = create_similarity_model(
model_type="cnn",
input_size=config_dict["image_size"][0]
if isinstance(config_dict["image_size"], (tuple, list))
else config_dict["image_size"],
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
# Load best checkpoint
checkpoint_dir = os.path.join(
config_dict.get("output_dir", "runs/similarity"), "checkpoints"
)
best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")
if os.path.exists(best_checkpoint):
checkpoint = torch.load(best_checkpoint, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
print(f"Loaded best model from epoch {checkpoint['epoch']}")
print(f"Best validation loss: {checkpoint['val_loss']:.4f}")
else:
print("Warning: Best model checkpoint not found!")
print("Using randomly initialized model for demonstration.")
model = model.to(device)
# Plot training metrics
print("\n3. Plotting training metrics...")
plot_training_metrics(config_dict.get("output_dir", "runs/similarity"))
# Analyze model performance
print("\n4. Analyzing model performance...")
metrics = analyze_model_performance(model, val_loader, device)
# Display results
print("\n" + "=" * 60)
print("PERFORMANCE METRICS")
print("=" * 60)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1 Score: {metrics['f1_score']:.4f}")
print(f"ROC AUC: {metrics['roc_auc']:.4f}")
print(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}")
# Plot confusion matrix
print("\n5. Plotting confusion matrix...")
plot_confusion_matrix(metrics["confusion_matrix"])
# Plot ROC curve
print("\n6. Plotting ROC curve...")
# For demonstration, we need to get probabilities again
model.eval()
all_probabilities = []
all_targets = []
with torch.no_grad():
for batch in val_loader:
google_img = batch["google_img"].to(device)
yandex_img = batch["yandex_img"].to(device)
target = batch["same_domain"].float().to(device)
output = model(google_img, yandex_img)
probabilities = torch.sigmoid(output).squeeze()
all_probabilities.extend(probabilities.cpu().numpy())
all_targets.extend(target.cpu().numpy())
all_probabilities = np.array(all_probabilities)
all_targets = np.array(all_targets)
fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
roc_auc = auc(fpr, tpr)
plot_roc_curve(fpr, tpr, roc_auc)
# Plot probability distribution
print("\n7. Plotting probability distribution...")
plot_probability_distribution(all_probabilities, all_targets)
# Test on examples
print("\n8. Testing on examples...")
test_model_on_examples(model, device)
# Generate comprehensive report
print("\n9. Generating performance report...")
generate_performance_report(model, val_loader, device)
print("\n" + "=" * 60)
print("EVALUATION COMPLETED SUCCESSFULLY!")
print("=" * 60)
print("\nNext steps:")
print("1. Check the generated plots in the runs/similarity directory")
print("2. Review the performance report in the reports directory")
print("3. Consider adjusting the decision threshold if needed")
print("4. Retrain with different hyperparameters if performance is poor")
if __name__ == "__main__":
main()

View File

@@ -1,238 +1,132 @@
from typing import Optional, Tuple
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class SimilarityCNN(nn.Module):
"""
CNN model for similarity estimation between two images.
Модель для оценки схожести двух изображений на базе предобученного бэкбона.
Takes two images as input and outputs a similarity score between 0 and 1.
Интерфейс совместим с исходной:
- forward(img1, img2) -> тензор (B, 1) со скором в [0, 1]
- predict_similarity(img1, img2) -> тензор (B, 1) без градиентов
"""
def __init__(
self,
input_channels: int = 3,
hidden_channels: int = 64,
num_blocks: int = 4,
backbone_name: str = "resnet18",
pretrained: bool = True,
dropout_rate: float = 0.3,
use_batch_norm: bool = True,
):
super().__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.num_blocks = num_blocks
self.backbone_name = backbone_name
self.pretrained = pretrained
self.dropout_rate = dropout_rate
self.use_batch_norm = use_batch_norm
self.encoder = self._build_encoder()
# 1. Создаём бэкбон и берём фичи до последнего FC
backbone = self._create_backbone(backbone_name, pretrained)
self.fusion_layers = self._build_fusion_layers()
# Для ResNet18 выход фичей = 512
self.feature_dim = backbone.fc.in_features
# Заменяем classification head на Identity, чтобы получать только признаки
backbone.fc = nn.Identity()
self.backbone = backbone
self.regression_head = self._build_regression_head()
self._initialize_weights()
def _build_encoder(self) -> nn.Module:
layers = []
in_channels = self.input_channels
out_channels = self.hidden_channels
layers.append(
nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3)
)
if self.use_batch_norm:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
for i in range(self.num_blocks):
block_in_channels = out_channels
block_out_channels = out_channels * 2 if i < 2 else out_channels
layers.append(
ResidualBlock(
in_channels=block_in_channels,
out_channels=block_out_channels,
stride=1 if i == 0 else 2,
dropout_rate=self.dropout_rate,
use_batch_norm=self.use_batch_norm,
)
)
if i < 2:
out_channels = block_out_channels
return nn.Sequential(*layers)
def _build_fusion_layers(self) -> nn.Module:
fused_channels = self.hidden_channels * 8
# 2. Голова для сравнения двух векторов признаков
# Вход: [f1, f2, |f1 - f2|, f1 * f2] => 4 * feature_dim
compare_input_dim = self.feature_dim * 4
layers = [
nn.Conv2d(
fused_channels, self.hidden_channels * 4, kernel_size=3, padding=1
),
nn.BatchNorm2d(self.hidden_channels * 4)
if self.use_batch_norm
else nn.Identity(),
nn.Linear(compare_input_dim, 512),
nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
nn.Conv2d(
self.hidden_channels * 4,
self.hidden_channels * 2,
kernel_size=3,
padding=1,
),
nn.BatchNorm2d(self.hidden_channels * 2)
if self.use_batch_norm
else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
nn.AdaptiveAvgPool2d((1, 1)),
]
nn.Dropout(dropout_rate),
return nn.Sequential(*layers)
def _build_regression_head(self) -> nn.Module:
input_features = self.hidden_channels * 2
layers = [
nn.Flatten(),
nn.Linear(input_features, 512),
nn.BatchNorm1d(512) if self.use_batch_norm else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout_rate),
nn.Linear(512, 256),
nn.BatchNorm1d(256) if self.use_batch_norm else nn.Identity(),
nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout_rate),
nn.Linear(256, 128),
nn.BatchNorm1d(128) if self.use_batch_norm else nn.Identity(),
nn.ReLU(inplace=True),
nn.Dropout(self.dropout_rate),
nn.Linear(128, 1),
nn.Sigmoid(),
nn.Dropout(dropout_rate),
nn.Linear(256, 1),
nn.Sigmoid(), # выход в [0, 1]
]
self.head = nn.Sequential(*layers)
return nn.Sequential(*layers)
def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
name = name.lower()
if name == "resnet18":
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
elif name == "resnet34":
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)
else:
raise ValueError(f"Unsupported backbone: {name}")
# Если у тебя не 3 канала, можно добавить адаптер 1x1 conv перед model.conv1
if self.input_channels != 3:
old_conv = model.conv1
model.conv1 = nn.Conv2d(
self.input_channels,
old_conv.out_channels,
kernel_size=old_conv.kernel_size,
stride=old_conv.stride,
padding=old_conv.padding,
bias=old_conv.bias is not None,
)
return model
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
"""
Прогоняет одно изображение через бэкбон и возвращает вектор признаков (B, feature_dim).
Для ResNet: это эквивалентно model.forward(x), когда fc = Identity.
"""
return self.backbone(x) # (B, feature_dim)
def forward(
self,
img1: torch.Tensor,
img2: torch.Tensor,
) -> torch.Tensor:
features1 = self.encoder(img1)
features2 = self.encoder(img2)
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
"""
img1, img2: (B, C, H, W) -> similarity: (B, 1)
"""
f1 = self._extract_features(img1) # (B, D)
f2 = self._extract_features(img2) # (B, D)
combined_features = torch.cat([features1, features2], dim=1)
fused_features = self.fusion_layers(combined_features)
similarity = self.regression_head(fused_features)
# Вектор сравнения
diff = torch.abs(f1 - f2)
prod = f1 * f2
combined = torch.cat([f1, f2, diff, prod], dim=1) # (B, 4D)
similarity = self.head(combined) # (B, 1) в [0, 1]
return similarity
def predict_similarity(
self,
img1: torch.Tensor,
img2: torch.Tensor,
) -> torch.Tensor:
original_training = self.training
def predict_similarity(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
"""
Инференс без градиентов, интерфейс как у исходной модели.
"""
was_training = self.training
self.eval()
with torch.no_grad():
similarity = self.forward(img1, img2)
if original_training:
self.train()
return similarity
class ResidualBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
dropout_rate: float = 0.3,
use_batch_norm: bool = True,
):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
self.relu1 = nn.ReLU(inplace=True)
self.dropout1 = (
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity()
self.relu2 = nn.ReLU(inplace=True)
self.dropout2 = (
nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()
)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(out_channels) if use_batch_norm else nn.Identity(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = self.shortcut(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.dropout1(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity
out = self.relu2(out)
out = self.dropout2(out)
return out
sim = self.forward(img1, img2)
if was_training:
self.train()
return sim
class SimilarityLoss(nn.Module):
"""
Оставляю тот же интерфейс loss, что и в твоём коде.
Если таргет бинарный (0/1), BCELoss подходит.
"""
def __init__(self):
super().__init__()
self.criterion = nn.BCELoss()
def forward(
self,
pred_similarity: torch.Tensor,
target_same: torch.Tensor,
) -> torch.Tensor:
def forward(self, pred_similarity: torch.Tensor, target_same: torch.Tensor) -> torch.Tensor:
return self.criterion(pred_similarity, target_same)
def compute_metrics(
@@ -267,11 +161,14 @@ class SimilarityLoss(nn.Module):
def create_similarity_model(
model_type: str = "cnn",
model_type: str = "backbone",
input_size: Tuple[int, int] = (256, 256),
**kwargs,
) -> nn.Module:
if model_type == "cnn":
"""
Аналог вашей фабрики, но с новым типом модели.
"""
if model_type == "backbone":
return SimilarityCNN(**kwargs)
else:
raise ValueError(f"Unknown model type: {model_type}")
@@ -283,8 +180,8 @@ if __name__ == "__main__":
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
backbone_name="resnet18",
pretrained=True,
dropout_rate=0.3,
use_batch_norm=True,
).to(device)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,256 @@
"""
ПРОСТОЕ ОБЪЯСНЕНИЕ РЕЗУЛЬТАТОВ ОБУЧЕНИЯ
========================================
Этот файл объясняет результаты обучения модели простыми словами,
как будто ты студент, который только начал изучать машинное обучение.
Представь, что train.py - это предыдущая ячейка в твоем блокноте,
где ты обучил модель. Теперь давай посмотрим, что у нас получилось!
"""
print("=" * 70)
print("ПРОСТОЕ ОБЪЯСНЕНИЕ РЕЗУЛЬТАТОВ МОДЕЛИ")
print("=" * 70)
print()
# -------------------------------------------------------------------
# ЧАСТЬ 1: ЧТО МЫ СДЕЛАЛИ?
# -------------------------------------------------------------------
print("1. ЧТО МЫ СДЕЛАЛИ?")
print("-" * 40)
print("Мы создали модель, которая смотрит на две картинки и говорит:")
print(" - 'ДА' - если это одно и то же место (с Google и Яндекс карт)")
print(" - 'НЕТ' - если это разные места")
print()
print("Модель училась на тысячах пар картинок!")
print("Сначала она делала много ошибок, но потом научилась.")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 2: КАК МЫ ИЗМЕРЯЕМ УСПЕХ?
# -------------------------------------------------------------------
print("2. КАК МЫ ИЗМЕРЯЕМ УСПЕХ?")
print("-" * 40)
print("Мы проверяем модель на новых картинках, которых она не видела.")
print("Считаем, сколько раз она угадала правильно.")
print()
print("Есть 4 возможных исхода:")
print(" 1. ✅ Истинно-положительный (True Positive - TP):")
print(" Модель сказала 'ДА' и это правда 'ДА'")
print()
print(" 2. ❌ Ложно-положительный (False Positive - FP):")
print(" Модель сказала 'ДА', но на самом деле 'НЕТ'")
print(" (Ошибка типа I: приняла разные места за одинаковые)")
print()
print(" 3. ❌ Ложно-отрицательный (False Negative - FN):")
print(" Модель сказала 'НЕТ', но на самом деле 'ДА'")
print(" (Ошибка типа II: не узнала одинаковые места)")
print()
print(" 4. ✅ Истинно-отрицательный (True Negative - TN):")
print(" Модель сказала 'НЕТ' и это правда 'НЕТ'")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 3: ПРОСТЫЕ МЕТРИКИ
# -------------------------------------------------------------------
print("3. ПРОСТЫЕ МЕТРИКИ (ЧТО ОНИ ЗНАЧАТ?)")
print("-" * 40)
# Примерные результаты (в реальности будут другие)
accuracy = 0.82 # 82%
precision = 0.78 # 78%
recall = 0.85 # 85%
f1_score = 0.81 # 81%
print(f"ТОЧНОСТЬ (Accuracy): {accuracy:.0%}")
print(" Это как общая оценка в школе.")
print(" Сколько всего ответов правильных из 100.")
print(f" Наша модель правильна в {accuracy:.0%} случаев.")
print()
print(f"ТОЧНОСТЬ КЛАССИФИКАЦИИ (Precision): {precision:.0%}")
print(" Когда модель говорит 'ДА', насколько ей можно верить?")
print(" Из 100 раз когда она сказала 'ДА', {precision:.0%} были правдой.")
print(" Высокая точность = мало ложных 'ДА'.")
print()
print(f"ПОЛНОТА (Recall): {recall:.0%}")
print(" Сколько настоящих 'ДА' модель нашла?")
print(f" Из 100 настоящих 'ДА', модель нашла {recall:.0%}.")
print(" Высокая полнота = мало пропущенных 'ДА'.")
print()
print(f"F1-МЕРА (F1 Score): {f1_score:.0%}")
print(" Баланс между точностью и полнотой.")
print(" Как золотая середина - не слишком строгая, не слишком добрая.")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 4: ТАБЛИЦА РЕЗУЛЬТАТОВ (ПРОСТАЯ)
# -------------------------------------------------------------------
print("4. ТАБЛИЦА РЕЗУЛЬТАТОВ")
print("-" * 40)
print("Давай представим, что мы протестировали модель на 1000 пар картинок:")
print()
# Простая таблица
print(" | Модель сказала 'ДА' | Модель сказала 'НЕТ' | Всего")
print("-----------------|---------------------|----------------------|-------")
print(f"На самом деле 'ДА' | TP: 425 | FN: 75 | 500")
print(f"На самом деле 'НЕТ' | FP: 95 | TN: 405 | 500")
print("-----------------|---------------------|----------------------|-------")
print(f"Всего | 520 | 480 | 1000")
print()
print("Расчеты:")
print(f" Точность = (TP + TN) / Всего = (425 + 405) / 1000 = {accuracy:.0%}")
print(f" Точность классификации = TP / (TP + FP) = 425 / 520 = {precision:.0%}")
print(f" Полнота = TP / (TP + FN) = 425 / 500 = {recall:.0%}")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 5: КАК ИНТЕРПРЕТИРОВАТЬ РЕЗУЛЬТАТЫ?
# -------------------------------------------------------------------
print("5. ЧТО ЭТО ЗНАЧИТ ДЛЯ НАШЕЙ ЗАДАЧИ?")
print("-" * 40)
if precision > 0.75:
print("✅ ХОРОШО: Когда модель говорит 'это одно место',")
print(" ей можно доверять ({precision:.0%} случаев она права).")
else:
print("⚠ МОЖНО ЛУЧШЕ: Модель иногда путает разные места с одинаковыми.")
if recall > 0.75:
print("✅ ХОРОШО: Модель находит большинство одинаковых мест")
print(f" ({recall:.0%} настоящих 'одинаковых' мест она находит).")
else:
print("⚠ МОЖНО ЛУЧШЕ: Модель пропускает много одинаковых мест.")
print()
print("ДЛЯ АВТОПИЛОТА:")
print(" - Ложные 'ДА' (FP): Может думать, что мы в нужном месте,")
print(" когда это не так → опасно!")
print(" - Ложные 'НЕТ' (FN): Не узнает нужное место → менее опасно,")
print(" но машина может проехать мимо.")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 6: ГРАФИКИ (ЧТО МЫ ВИДИМ?)
# -------------------------------------------------------------------
print("6. КАКИЕ ГРАФИКИ МЫ ПОЛУЧАЕМ?")
print("-" * 40)
print("После обучения мы строим 4 основных графика:")
print()
print("1. 📉 ГРАФИК ОШИБОК (Loss):")
print(" - Синяя линия: ошибки на обучающих данных")
print(" - Красная линия: ошибки на проверочных данных")
print(" - ХОРОШО: обе линии идут вниз и близки друг к другу")
print(" - ПЛОХО: линии далеко друг от друга (переобучение)")
print()
print("2. 📈 ГРАФИК ТОЧНОСТИ (Accuracy):")
print(" - Показывает, как растет точность со временем")
print(" - Должен расти и стабилизироваться")
print()
print("3. 🎯 МАТРИЦА ОШИБОК (Confusion Matrix):")
print(" - Квадратная таблица 2x2")
print(" - Показывает все 4 типа ответов (TP, FP, FN, TN)")
print(" - Идеально: все числа на диагонали, нули вне диагонали")
print()
print("4. 📊 ROC-КРИВАЯ:")
print(" - Показывает, насколько хорошо модель отличает 'ДА' от 'НЕТ'")
print(" - Чем больше площадь под кривой, тем лучше")
print(" - Идеально: площадь = 1.0 (100%)")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 7: ЧТО ДЕЛАТЬ ДАЛЬШЕ?
# -------------------------------------------------------------------
print("7. ЧТО ДЕЛАТЬ, ЕСЛИ РЕЗУЛЬТАТЫ ПЛОХИЕ?")
print("-" * 40)
print("Если точность меньше 70%:")
print("1. 🎯 ПРОБЛЕМА: Модель плохо учится")
print(" РЕШЕНИЕ:")
print(" - Учить дольше (увеличить количество эпох)")
print(" - Изменить скорость обучения (learning rate)")
print(" - Добавить больше данных для обучения")
print()
print("2. 🎯 ПРОБЛЕМА: Модель запоминает, а не учится (переобучение)")
print(" РЕШЕНИЕ:")
print(" - Добавить регуляризацию (dropout)")
print(" - Использовать augmentation (искажать картинки)")
print(" - Упростить модель (меньше слоев)")
print()
print("3. 🎯 ПРОБЛЕМА: Много ложных 'ДА' (FP)")
print(" РЕШЕНИЕ:")
print(" - Повысить порог принятия решения (например, 0.7 вместо 0.5)")
print(" - Добавить больше примеров 'разных' мест")
print()
print("4. 🎯 ПРОБЛЕМА: Много ложных 'НЕТ' (FN)")
print(" РЕШЕНИЕ:")
print(" - Понизить порог принятия решения (например, 0.3 вместо 0.5)")
print(" - Добавить больше примеров 'одинаковых' мест")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 8: ПРАКТИЧЕСКИЙ ПРИМЕР
# -------------------------------------------------------------------
print("8. ПРАКТИЧЕСКИЙ ПРИМЕР: КАК ИСПОЛЬЗОВАТЬ МОДЕЛЬ")
print("-" * 40)
print("После обучения модель можно использовать так:")
print()
print("```python")
print("# 1. Загружаем обученную модель")
print("model = load_trained_model('best_model.pt')")
print()
print("# 2. Берем две картинки")
print("google_img = load_image('google_map.png')")
print("yandex_img = load_image('yandex_map.png')")
print()
print("# 3. Спрашиваем у модели")
print("similarity_score = model.predict(google_img, yandex_img)")
print()
print("# 4. Интерпретируем результат")
print("if similarity_score > 0.5:")
print(" print('✅ Это похоже на одно и то же место!')")
print("else:")
print(" print('❌ Это разные места')")
print("```")
print()
print(f"Порог 0.5 можно менять:")
print(f" - Порог 0.7: более строгая модель (меньше ложных 'ДА')")
print(f" - Порог 0.3: более добрая модель (меньше ложных 'НЕТ')")
print()
# -------------------------------------------------------------------
# ЧАСТЬ 9: ЗАКЛЮЧЕНИЕ
# -------------------------------------------------------------------
print("9. ЧТО МЫ УЗНАЛИ?")
print("-" * 40)
print("✅ Модель учится сравнивать картинки")
print("✅ Мы можем измерить, насколько она хороша")
print("✅ Есть разные метрики для разных целей")
print("✅ Графики помогают понять процесс обучения")
print("✅ Можно улучшить модель, если результаты плохие")
print()
print("=" * 70)
print("🎉 ВОТ И ВСЁ! ТЕПЕРЬ ТЫ ЗНАЕШЬ, КАК ОЦЕНИВАТЬ МОДЕЛЬ!")
print("=" * 70)
print()
print("Следующие шаги:")
print("1. Запусти evaluation.py чтобы увидеть реальные графики")
print("2. Посмотри на матрицу ошибок - какие ошибки чаще?")
print("3. Попробуй изменить порог принятия решений")
print("4. Если нужно - переобучи модель с другими параметрами")

View File

@@ -0,0 +1,917 @@
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# =============================================================================
# TRAINING LOOP WITH VISUALIZATION
# =============================================================================
class SimilarityTrainer:
def __init__(
self,
model: nn.Module,
trainloader: DataLoader,
valloader: DataLoader,
device: torch.device,
config: dict,
):
self.model = model.to(device)
self.trainloader = trainloader
self.valloader = valloader
self.device = device
self.config = config
self.criterion = SimilarityLoss()
self.optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 2e-4),
betas=(config.get('beta1', 0.5), config.get('beta2', 0.999))
)
self.writer = None
self.best_val_loss = float('inf')
self.epochs_without_improvement = 0
# Для хранения истории метрик
self.history = {
'train_loss': [],
'val_loss': [],
'val_accuracy': [],
'val_precision': [],
'val_recall': [],
'val_f1': [],
'learning_rate': []
}
def train_epoch(self, epoch: int) -> dict:
"""Обучение на одной эпохе"""
self.model.train()
total_loss = 0
total_samples = 0
all_metrics = []
pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}')
for batch_idx, batch in enumerate(pbar):
google_img = batch['google_img'].to(self.device)
yandex_img = batch['yandex_img'].to(self.device)
target = batch['same_domain'].float().to(self.device).unsqueeze(1)
self.optimizer.zero_grad()
# Forward pass
output = self.model(google_img, yandex_img)
loss = self.criterion(output, target)
# Backward pass
loss.backward()
# Gradient clipping
if self.config.get('grad_clip', None):
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config['grad_clip']
)
self.optimizer.step()
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
# Compute metrics
if batch_idx % self.config.get('log_interval', 10) == 0:
metrics = self.criterion.compute_metrics(output, target)
all_metrics.append(metrics)
pbar.set_postfix({
'loss': f"{loss.item():.4f}",
'acc': f"{metrics['accuracy']:.4f}"
})
if self.writer:
global_step = epoch * len(self.trainloader) + batch_idx
self.writer.add_scalar('train/loss', loss.item(), global_step)
self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step)
avg_loss = total_loss / total_samples
# Average metrics
if all_metrics:
avg_metrics = {
key: sum(m[key] for m in all_metrics) / len(all_metrics)
for key in all_metrics[0].keys()
}
else:
avg_metrics = {}
return {'loss': avg_loss, **avg_metrics}
def validate(self) -> dict:
"""Валидация модели"""
self.model.eval()
total_loss = 0
total_samples = 0
all_metrics = []
# Для ROC и confusion matrix
all_predictions = []
all_targets = []
with torch.no_grad():
for batch in tqdm(self.valloader, desc='Validation'):
google_img = batch['google_img'].to(self.device)
yandex_img = batch['yandex_img'].to(self.device)
target = batch['same_domain'].float().to(self.device).unsqueeze(1)
output = self.model(google_img, yandex_img)
loss = self.criterion(output, target)
total_loss += loss.item() * google_img.size(0)
total_samples += google_img.size(0)
metrics = self.criterion.compute_metrics(output, target)
all_metrics.append(metrics)
all_predictions.append(output.cpu())
all_targets.append(target.cpu())
avg_loss = total_loss / total_samples
avg_metrics = {
key: sum(m[key] for m in all_metrics) / len(all_metrics)
for key in all_metrics[0].keys()
}
# Concatenate all predictions and targets
all_predictions = torch.cat(all_predictions, dim=0)
all_targets = torch.cat(all_targets, dim=0)
return {
'loss': avg_loss,
**avg_metrics,
'predictions': all_predictions,
'targets': all_targets
}
def train(self, num_epochs: int):
"""Основной цикл обучения"""
log_dir = os.path.join(self.config.get('output_dir', 'runs/similarity'))
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(log_dir)
print(f'\n{"="*70}')
print(f'Starting training for {num_epochs} epochs')
print(f'Logging to {log_dir}')
print(f'{"="*70}\n')
start_time = time.time()
for epoch in range(1, num_epochs + 1):
epoch_start = time.time()
print(f'\n--- Epoch {epoch}/{num_epochs} ---')
# Train
train_metrics = self.train_epoch(epoch)
# Validate
val_metrics = self.validate()
# Store history
self.history['train_loss'].append(train_metrics['loss'])
self.history['val_loss'].append(val_metrics['loss'])
self.history['val_accuracy'].append(val_metrics['accuracy'])
self.history['val_precision'].append(val_metrics['precision'])
self.history['val_recall'].append(val_metrics['recall'])
self.history['val_f1'].append(val_metrics['f1'])
self.history['learning_rate'].append(
self.optimizer.param_groups[0]['lr']
)
# Print metrics
print(f'\nTrain Loss: {train_metrics["loss"]:.4f}')
print(f'Val Loss: {val_metrics["loss"]:.4f}')
print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')
print(f'Val Precision: {val_metrics["precision"]:.4f}')
print(f'Val Recall: {val_metrics["recall"]:.4f}')
print(f'Val F1: {val_metrics["f1"]:.4f}')
epoch_time = time.time() - epoch_start
print(f'Epoch time: {epoch_time:.2f}s')
# TensorBoard logging
if self.writer:
self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch)
self.writer.add_scalar('epoch/val_accuracy', val_metrics['accuracy'], epoch)
self.writer.add_scalar('epoch/val_precision', val_metrics['precision'], epoch)
self.writer.add_scalar('epoch/val_recall', val_metrics['recall'], epoch)
self.writer.add_scalar('epoch/val_f1', val_metrics['f1'], epoch)
# Save checkpoint
if val_metrics['loss'] < self.best_val_loss:
self.best_val_loss = val_metrics['loss']
self.epochs_without_improvement = 0
self.save_checkpoint(epoch, val_metrics['loss'], is_best=True)
print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}')
else:
self.epochs_without_improvement += 1
if epoch % self.config.get('save_interval', 5) == 0:
self.save_checkpoint(epoch, val_metrics['loss'], is_best=False)
# Early stopping
patience = self.config.get('early_stopping_patience', 20)
if self.epochs_without_improvement >= patience:
print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement')
break
total_time = time.time() - start_time
print(f'\n{"="*70}')
print(f'Training completed in {total_time/60:.2f} minutes')
print(f'Best validation loss: {self.best_val_loss:.4f}')
print(f'{"="*70}\n')
self.writer.close()
return self.history
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
"""Сохранение чекпоинта модели"""
checkpoint_dir = os.path.join(
self.config.get('output_dir', 'runs/similarity'),
'checkpoints'
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'val_loss': val_loss,
'config': self.config,
'history': self.history
}
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
torch.save(checkpoint, checkpoint_path)
if is_best:
best_path = os.path.join(checkpoint_dir, 'best_model.pt')
torch.save(checkpoint, best_path)
def load_checkpoint(self, checkpoint_path: str):
"""Загрузка чекпоинта"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'history' in checkpoint:
self.history = checkpoint['history']
return checkpoint['epoch'], checkpoint['val_loss']
# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================
def plot_training_history(history: dict, save_path: str = None):
"""Построение графиков обучения"""
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Training History - Siamese Network для корреляции снимков',
fontsize=16, fontweight='bold')
epochs = range(1, len(history['train_loss']) + 1)
# Loss
axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Loss Curves')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
# Accuracy
axes[0, 1].plot(epochs, history['val_accuracy'], 'g-', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('Validation Accuracy')
axes[0, 1].grid(True, alpha=0.3)
axes[0, 1].set_ylim([0, 1])
# F1 Score
axes[0, 2].plot(epochs, history['val_f1'], 'm-', linewidth=2)
axes[0, 2].set_xlabel('Epoch')
axes[0, 2].set_ylabel('F1 Score')
axes[0, 2].set_title('Validation F1 Score')
axes[0, 2].grid(True, alpha=0.3)
axes[0, 2].set_ylim([0, 1])
# Precision
axes[1, 0].plot(epochs, history['val_precision'], 'c-', linewidth=2)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Precision')
axes[1, 0].set_title('Validation Precision')
axes[1, 0].grid(True, alpha=0.3)
axes[1, 0].set_ylim([0, 1])
# Recall
axes[1, 1].plot(epochs, history['val_recall'], 'y-', linewidth=2)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Recall')
axes[1, 1].set_title('Validation Recall')
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_ylim([0, 1])
# Learning Rate
axes[1, 2].plot(epochs, history['learning_rate'], 'k-', linewidth=2)
axes[1, 2].set_xlabel('Epoch')
axes[1, 2].set_ylabel('Learning Rate')
axes[1, 2].set_title('Learning Rate Schedule')
axes[1, 2].grid(True, alpha=0.3)
axes[1, 2].set_yscale('log')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Training history plot saved to {save_path}')
plt.show()
def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor, save_path: str = None):
"""Построение ROC кривой"""
from sklearn.metrics import roc_curve, auc
predictions_np = predictions.numpy().flatten()
targets_np = targets.numpy().flatten()
fpr, tpr, thresholds = roc_curve(targets_np, predictions_np)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(10, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2,
label=f'ROC curve (AUC = {roc_auc:.3f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold')
plt.legend(loc="lower right", fontsize=12)
plt.grid(True, alpha=0.3)
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'ROC curve saved to {save_path}')
plt.show()
return roc_auc
def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor,
threshold: float = 0.5, save_path: str = None):
"""Построение матрицы ошибок"""
from sklearn.metrics import confusion_matrix
predictions_binary = (predictions.numpy().flatten() >= threshold).astype(int)
targets_binary = (targets.numpy().flatten() >= 0.5).astype(int)
cm = confusion_matrix(targets_binary, predictions_binary)
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold')
plt.colorbar()
classes = ['Different Domains', 'Same Domain']
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
plt.yticks(tick_marks, classes, fontsize=12)
# Add text annotations
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, format(cm[i, j], 'd'),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black",
fontsize=16, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Confusion matrix saved to {save_path}')
plt.show()
def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor,
save_path: str = None):
"""Распределение предсказанных значений схожести"""
predictions_np = predictions.numpy().flatten()
targets_np = targets.numpy().flatten()
same_domain = predictions_np[targets_np >= 0.5]
diff_domain = predictions_np[targets_np < 0.5]
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.hist(same_domain, bins=50, alpha=0.7, color='green', edgecolor='black', label='Same Domain')
plt.hist(diff_domain, bins=50, alpha=0.7, color='red', edgecolor='black', label='Different Domains')
plt.xlabel('Predicted Similarity Score', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.subplot(1, 2, 2)
plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same'])
plt.ylabel('Similarity Score', fontsize=12)
plt.xlabel('Domain Match', fontsize=12)
plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Similarity distribution plot saved to {save_path}')
plt.show()
# Print statistics
print(f'\n--- Similarity Score Statistics ---')
print(f'Same Domain:')
print(f' Mean: {same_domain.mean():.4f}')
print(f' Std: {same_domain.std():.4f}')
print(f' Min: {same_domain.min():.4f}')
print(f' Max: {same_domain.max():.4f}')
print(f'\nDifferent Domains:')
print(f' Mean: {diff_domain.mean():.4f}')
print(f' Std: {diff_domain.std():.4f}')
print(f' Min: {diff_domain.min():.4f}')
print(f' Max: {diff_domain.max():.4f}')
def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device,
num_samples: int = 8, save_path: str = None):
"""Визуализация примеров предсказаний"""
model.eval()
# Get random samples
indices = np.random.choice(len(dataset), num_samples, replace=False)
fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))
if num_samples == 1:
axes = axes.reshape(1, -1)
fig.suptitle('Sample Predictions - Siamese Network для корреляции карт',
fontsize=16, fontweight='bold')
with torch.no_grad():
for idx, sample_idx in enumerate(indices):
sample = dataset[sample_idx]
google_img = sample['google_img'].unsqueeze(0).to(device)
yandex_img = sample['yandex_img'].unsqueeze(0).to(device)
true_label = sample['same_domain'].item()
# Predict
pred_similarity = model(google_img, yandex_img).item()
pred_label = int(pred_similarity >= 0.5)
# Denormalize images for visualization
google_np = google_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
yandex_np = yandex_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
# Denormalize (assuming ImageNet normalization)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
google_np = std * google_np + mean
yandex_np = std * yandex_np + mean
google_np = np.clip(google_np, 0, 1)
yandex_np = np.clip(yandex_np, 0, 1)
# Plot Google image
axes[idx, 0].imshow(google_np)
axes[idx, 0].set_title('Google Map', fontsize=12, fontweight='bold')
axes[idx, 0].axis('off')
# Plot Yandex image
axes[idx, 1].imshow(yandex_np)
axes[idx, 1].set_title('Yandex Map', fontsize=12, fontweight='bold')
axes[idx, 1].axis('off')
# Plot prediction info
axes[idx, 2].axis('off')
# Determine color based on correctness
is_correct = (pred_label == true_label)
color = 'green' if is_correct else 'red'
result = '✓ Correct' if is_correct else '✗ Incorrect'
info_text = f"""
Prediction: {pred_similarity:.4f}
Predicted Label: {'Same' if pred_label == 1 else 'Different'}
True Label: {'Same' if true_label == 1 else 'Different'}
{result}
"""
axes[idx, 2].text(0.5, 0.5, info_text,
ha='center', va='center',
fontsize=12,
bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
transform=axes[idx, 2].transAxes)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Sample predictions saved to {save_path}')
plt.show()
def visualize_feature_space(model: nn.Module, dataloader: DataLoader,
device: torch.device, max_samples: int = 500,
save_path: str = None):
"""Визуализация пространства признаков с помощью t-SNE"""
from sklearn.manifold import TSNE
model.eval()
all_features_google = []
all_features_yandex = []
all_labels = []
with torch.no_grad():
for i, batch in enumerate(tqdm(dataloader, desc='Extracting features')):
if i * dataloader.batch_size >= max_samples:
break
google_img = batch['google_img'].to(device)
yandex_img = batch['yandex_img'].to(device)
labels = batch['same_domain'].cpu().numpy()
# Extract features
features_google = model.extract_features(google_img).cpu().numpy()
features_yandex = model.extract_features(yandex_img).cpu().numpy()
all_features_google.append(features_google)
all_features_yandex.append(features_yandex)
all_labels.append(labels)
all_features_google = np.concatenate(all_features_google, axis=0)
all_features_yandex = np.concatenate(all_features_yandex, axis=0)
all_labels = np.concatenate(all_labels, axis=0)
# Combine features
all_features = np.concatenate([all_features_google, all_features_yandex], axis=0)
all_labels = np.concatenate([all_labels, all_labels], axis=0)
print(f'\nApplying t-SNE to {all_features.shape[0]} samples...')
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
features_2d = tsne.fit_transform(all_features)
# Split back into Google and Yandex
n_samples = len(all_labels) // 2
features_google_2d = features_2d[:n_samples]
features_yandex_2d = features_2d[n_samples:]
labels = all_labels[:n_samples]
# Plot
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
# Google features
for label in [0, 1]:
mask = labels == label
axes[0].scatter(
features_google_2d[mask, 0],
features_google_2d[mask, 1],
c='green' if label == 1 else 'red',
label='Same Domain' if label == 1 else 'Different Domains',
alpha=0.6,
s=50
)
axes[0].set_title('Google Maps Features (t-SNE)', fontsize=14, fontweight='bold')
axes[0].set_xlabel('t-SNE Component 1', fontsize=12)
axes[0].set_ylabel('t-SNE Component 2', fontsize=12)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
# Yandex features
for label in [0, 1]:
mask = labels == label
axes[1].scatter(
features_yandex_2d[mask, 0],
features_yandex_2d[mask, 1],
c='green' if label == 1 else 'red',
label='Same Domain' if label == 1 else 'Different Domains',
alpha=0.6,
s=50
)
axes[1].set_title('Yandex Maps Features (t-SNE)', fontsize=14, fontweight='bold')
axes[1].set_xlabel('t-SNE Component 1', fontsize=12)
axes[1].set_ylabel('t-SNE Component 2', fontsize=12)
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Feature space visualization saved to {save_path}')
plt.show()
def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader,
device: torch.device, num_samples: int = 20,
save_path: str = None):
"""Создание тепловой карты корреляций между снимками"""
model.eval()
# Collect samples
google_images = []
yandex_images = []
labels = []
with torch.no_grad():
for i, batch in enumerate(dataloader):
if len(google_images) >= num_samples:
break
google_img = batch['google_img'].to(device)
yandex_img = batch['yandex_img'].to(device)
label = batch['same_domain']
google_images.append(google_img[:1])
yandex_images.append(yandex_img[:1])
labels.append(label[:1].item())
google_images = torch.cat(google_images[:num_samples], dim=0)
yandex_images = torch.cat(yandex_images[:num_samples], dim=0)
# Compute similarity matrix
similarity_matrix = np.zeros((num_samples, num_samples))
with torch.no_grad():
for i in tqdm(range(num_samples), desc='Computing correlations'):
for j in range(num_samples):
google_i = google_images[i:i+1]
yandex_j = yandex_images[j:j+1]
similarity = model(google_i, yandex_j).item()
similarity_matrix[i, j] = similarity
# Plot heatmap
plt.figure(figsize=(14, 12))
im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04)
plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)',
fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Yandex Map Index', fontsize=12)
plt.ylabel('Google Map Index', fontsize=12)
# Add grid
plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
# Add text annotations for diagonal (same pairs)
for i in range(min(num_samples, 10)): # Annotate first 10 for readability
if labels[i] == 1: # True match
plt.text(i, i, '', ha='center', va='center',
color='white', fontsize=12, fontweight='bold')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f'Correlation heatmap saved to {save_path}')
plt.show()
# Print statistics
diagonal = np.diag(similarity_matrix)
off_diagonal = similarity_matrix[~np.eye(num_samples, dtype=bool)]
print(f'\n--- Correlation Statistics ---')
print(f'Diagonal (matched pairs):')
print(f' Mean: {diagonal.mean():.4f}')
print(f' Std: {diagonal.std():.4f}')
print(f'\nOff-diagonal (mismatched pairs):')
print(f' Mean: {off_diagonal.mean():.4f}')
print(f' Std: {off_diagonal.std():.4f}')
# =============================================================================
# MAIN TRAINING SCRIPT
# =============================================================================
def main():
"""Основная функция обучения"""
# Configuration
config_dict = config.copy()
if isinstance(config_dict.get('image_size'), list):
config_dict['image_size'] = tuple(config_dict['image_size'])
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\n{"="*70}')
print(f'Siamese Network Training for Map Correlation')
print(f'Обучение сиамской сети для корреляции снимков')
print(f'{"="*70}')
print(f'Using device: {device}')
if torch.cuda.is_available():
print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
# Create data loaders
print(f'\n{"="*70}')
print('Creating data loaders...')
print(f'{"="*70}')
train_loader, val_loader = create_data_loaders(
root_dir=config_dict['data_dir'],
batch_size=config_dict['batch_size'],
train_split=config_dict['train_split'],
num_workers=config_dict['num_workers'],
image_size=config_dict['image_size'],
augment_train=True,
augment_val=False,
device=device
)
print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')
print(f'Train samples: {len(train_loader.dataset)}')
print(f'Val samples: {len(val_loader.dataset)}')
# Create model
print(f'\n{"="*70}')
print('Creating model...')
print(f'{"="*70}')
model = create_similarity_model(
model_type='backbone',
input_size=config_dict['image_size'][0] if isinstance(config_dict['image_size'], (tuple, list)) else config_dict['image_size'],
input_channels=3,
backbone_name='resnet18',
pretrained=True,
dropout_rate=0.3,
use_batch_norm=True
)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
# Create trainer
trainer = SimilarityTrainer(
model=model,
trainloader=train_loader,
valloader=val_loader,
device=device,
config=config_dict
)
# Train model
print(f'\n{"="*70}')
print('Starting training...')
print(f'{"="*70}')
history = trainer.train(config_dict['epochs'])
# =============================================================================
# VISUALIZATION AND RESULTS
# =============================================================================
print(f'\n{"="*70}')
print('Generating visualizations...')
print(f'{"="*70}')
output_dir = config_dict.get('output_dir', 'runs/similarity')
vis_dir = os.path.join(output_dir, 'visualizations')
os.makedirs(vis_dir, exist_ok=True)
# 1. Training history
print('\n1. Plotting training history...')
plot_training_history(
history,
save_path=os.path.join(vis_dir, 'training_history.png')
)
# 2. Validation metrics
print('\n2. Computing validation predictions...')
trainer.model.eval()
val_predictions = []
val_targets = []
with torch.no_grad():
for batch in tqdm(val_loader, desc='Validation'):
google_img = batch['google_img'].to(device)
yandex_img = batch['yandex_img'].to(device)
target = batch['same_domain'].float().unsqueeze(1)
output = trainer.model(google_img, yandex_img)
val_predictions.append(output.cpu())
val_targets.append(target.cpu())
val_predictions = torch.cat(val_predictions, dim=0)
val_targets = torch.cat(val_targets, dim=0)
# 3. ROC curve
print('\n3. Plotting ROC curve...')
roc_auc = plot_roc_curve(
val_predictions,
val_targets,
save_path=os.path.join(vis_dir, 'roc_curve.png')
)
print(f'ROC AUC Score: {roc_auc:.4f}')
# 4. Confusion matrix
print('\n4. Plotting confusion matrix...')
plot_confusion_matrix(
val_predictions,
val_targets,
threshold=0.5,
save_path=os.path.join(vis_dir, 'confusion_matrix.png')
)
# 5. Similarity distribution
print('\n5. Plotting similarity distribution...')
plot_similarity_distribution(
val_predictions,
val_targets,
save_path=os.path.join(vis_dir, 'similarity_distribution.png')
)
# 6. Sample predictions
print('\n6. Visualizing sample predictions...')
visualize_sample_predictions(
trainer.model,
val_loader.dataset,
device,
num_samples=8,
save_path=os.path.join(vis_dir, 'sample_predictions.png')
)
# 7. Feature space visualization
print('\n7. Visualizing feature space (t-SNE)...')
visualize_feature_space(
trainer.model,
val_loader,
device,
max_samples=500,
save_path=os.path.join(vis_dir, 'feature_space_tsne.png')
)
# 8. Correlation heatmap
print('\n8. Generating correlation heatmap...')
generate_correlation_heatmap(
trainer.model,
val_loader,
device,
num_samples=20,
save_path=os.path.join(vis_dir, 'correlation_heatmap.png')
)
# =============================================================================
# FINAL RESULTS SUMMARY
# =============================================================================
print(f'\n{"="*70}')
print('FINAL RESULTS SUMMARY')
print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ')
print(f'{"="*70}')
print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}')
print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}')
print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}')
print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}')
print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}')
print(f'ROC AUC Score: {roc_auc:.4f}')
print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}')
print(f'Visualizations saved to: {vis_dir}')
print(f'\n{"="*70}')
print('Training and visualization completed successfully!')
print('Обучение и визуализация завершены успешно!')
print(f'{"="*70}\n')
if __name__ == '__main__':
main()

View File

@@ -2,7 +2,6 @@
Training script for image similarity estimation.
"""
import argparse
import os
import time
from datetime import datetime
@@ -10,7 +9,7 @@ from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import create_data_loaders
from dataloader import config, create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@@ -191,51 +190,23 @@ class SimilarityTrainer:
def main():
parser = argparse.ArgumentParser(description="Train similarity estimation model")
parser.add_argument(
"--data_dir",
type=str,
default=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--image_size", type=int, default=256)
parser.add_argument("--train_split", type=float, default=0.8)
parser.add_argument("--output_dir", type=str, default="runs/similarity")
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
# Use config from dataloader.py
config_dict = config.copy()
args = parser.parse_args()
# Ensure image_size is tuple
if isinstance(config_dict.get("image_size"), list):
config_dict["image_size"] = tuple(config_dict["image_size"])
config = {
"data_dir": args.data_dir,
"batch_size": args.batch_size,
"epochs": args.epochs,
"learning_rate": args.learning_rate,
"image_size": (args.image_size, args.image_size),
"train_split": args.train_split,
"output_dir": args.output_dir,
"num_workers": args.num_workers,
"log_interval": 10,
"save_interval": 5,
"early_stopping_patience": 20,
"beta1": 0.5,
"beta2": 0.999,
}
device = torch.device(args.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Creating data loaders...")
train_loader, val_loader = create_data_loaders(
root_dir=config["data_dir"],
batch_size=config["batch_size"],
train_split=config["train_split"],
num_workers=config["num_workers"],
image_size=config["image_size"],
root_dir=config_dict["data_dir"],
batch_size=config_dict["batch_size"],
train_split=config_dict["train_split"],
num_workers=config_dict["num_workers"],
image_size=config_dict["image_size"],
augment_train=True,
augment_val=False,
device=device,
@@ -247,7 +218,9 @@ def main():
print("Creating model...")
model = create_similarity_model(
model_type="cnn",
input_size=config["image_size"],
input_size=config_dict["image_size"][0]
if isinstance(config_dict["image_size"], (tuple, list))
else config_dict["image_size"],
input_channels=3,
hidden_channels=64,
num_blocks=4,
@@ -262,11 +235,11 @@ def main():
train_loader=train_loader,
val_loader=val_loader,
device=device,
config=config,
config=config_dict,
)
print("Starting training...")
trainer.train(config["epochs"])
trainer.train(config_dict["epochs"])
print("Training completed!")