diff --git a/models/SiaN-similarity/README.md b/models/SiaN-similarity/README.md deleted file mode 100644 index 28c31da..0000000 --- a/models/SiaN-similarity/README.md +++ /dev/null @@ -1,131 +0,0 @@ -# SiaN-Similarity: Модель для оценки схожести изображений - -Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие. - -## Архитектура модели - -Модель основана на CNN с residual блоками: -- Общий энкодер для обоих изображений -- Residual blocks с batch normalization -- Слой слияния признаков -- Регрессионная голова с сигмоидой на выходе - -## Использование - -### Установка зависимостей -```bash -pip install torch torchvision pillow -``` - -### Быстрый старт - -```python -import torch -from model import SimilarityCNN - -# Создание модели -model = SimilarityCNN( - input_channels=3, - hidden_channels=64, - num_blocks=4, - dropout_rate=0.3, - use_batch_norm=True, -) - -# Предсказание схожести -img1 = torch.randn(1, 3, 256, 256) # Изображение 1 -img2 = torch.randn(1, 3, 256, 256) # Изображение 2 - -similarity = model.predict_similarity(img1, img2) -print(f"Схожесть: {similarity.item():.4f}") -``` - -### Обучение модели - -```bash -python train_similarity.py \ - --data_dir "путь/к/данным" \ - --batch_size 32 \ - --epochs 100 \ - --learning_rate 2e-4 \ - --output_dir "runs/similarity" -``` - -### Предсказание на новых изображениях - -```bash -python predict.py \ - --image1 "путь/к/изображению1.png" \ - --image2 "путь/к/изображению2.png" \ - --checkpoint "runs/similarity/checkpoints/best_model.pt" -``` - -## Структура проекта - -``` -SiaN-similarity/ -├── model.py # Основная модель -├── dataloader.py # Даталоадер для обучения -├── train_similarity.py # Скрипт для обучения -├── predict.py # Скрипт для предсказания -├── train.py # Оригинальный тренировочный скрипт -└── README.md # Этот файл -``` - -## Конфигурация модели - -Параметры по умолчанию: -- `input_channels`: 3 (RGB) -- `hidden_channels`: 64 -- `num_blocks`: 4 -- `dropout_rate`: 0.3 -- `use_batch_norm`: True -- `image_size`: (256, 256) - -## Формат данных - -Модель ожидает изображения размером 256x256 пикселей в формате RGB. -Для обучения используется датасет с парами изображений и метками схожести. - -## Примеры использования - -### 1. Создание и тестирование модели -```python -from model import create_similarity_model - -model = create_similarity_model( - model_type="cnn", - input_size=(256, 256), - hidden_channels=32, - num_blocks=3, -) -``` - -### 2. Использование функции потерь -```python -from model import SimilarityLoss - -loss_fn = SimilarityLoss() -pred = torch.tensor([[0.8], [0.2]]) -target = torch.tensor([[1.0], [0.0]]) -loss = loss_fn(pred, target) -``` - -### 3. Расчет метрик -```python -metrics = loss_fn.compute_metrics(pred, target) -print(f"Accuracy: {metrics['accuracy']:.4f}") -print(f"F1-score: {metrics['f1']:.4f}") -``` - -## Требования - -- Python 3.8+ -- PyTorch 1.9+ -- torchvision -- Pillow -- numpy - -## Лицензия - -MIT \ No newline at end of file diff --git a/models/SiaN-similarity/demo.py b/models/SiaN-similarity/demo.py deleted file mode 100644 index 53715fd..0000000 --- a/models/SiaN-similarity/demo.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Демонстрационный скрипт для модели оценки схожести изображений. -""" - -import matplotlib.pyplot as plt -import numpy as np -import torch -from model import SimilarityCNN, SimilarityLoss -from PIL import Image, ImageDraw, ImageFont -from torchvision import transforms - - -def create_test_images(): - """Создание тестовых изображений для демонстрации.""" - images = [] - - # Изображение 1: Красный квадрат - img1 = Image.new("RGB", (256, 256), color="white") - draw = ImageDraw.Draw(img1) - draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2) - images.append(("Красный квадрат", img1)) - - # Изображение 2: Тот же красный квадрат (похожее) - img2 = Image.new("RGB", (256, 256), color="white") - draw = ImageDraw.Draw(img2) - draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2) - images.append(("Похожий красный квадрат", img2)) - - # Изображение 3: Синий круг (разное) - img3 = Image.new("RGB", (256, 256), color="white") - draw = ImageDraw.Draw(img3) - draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2) - images.append(("Синий круг", img3)) - - # Изображение 4: Зеленый треугольник (разное) - img4 = Image.new("RGB", (256, 256), color="white") - draw = ImageDraw.Draw(img4) - draw.polygon( - [(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2 - ) - images.append(("Зеленый треугольник", img4)) - - return images - - -def preprocess_image(image): - """Преобразование PIL Image в тензор.""" - transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - return transform(image).unsqueeze(0) # Добавляем batch dimension - - -def display_results(images, similarities): - """Отображение результатов сравнения.""" - fig, axes = plt.subplots(2, 2, figsize=(12, 10)) - axes = axes.flatten() - - for idx, (title, img) in enumerate(images): - ax = axes[idx] - ax.imshow(img) - ax.set_title(title, fontsize=12, fontweight="bold") - ax.axis("off") - - plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold") - plt.tight_layout() - plt.show() - - # Вывод результатов сравнения - print("\n" + "=" * 60) - print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ") - print("=" * 60) - - comparisons = [ - ("Красный квадрат", "Похожий красный квадрат"), - ("Красный квадрат", "Синий круг"), - ("Красный квадрат", "Зеленый треугольник"), - ("Похожий красный квадрат", "Синий круг"), - ] - - for i, (name1, name2) in enumerate(comparisons): - idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0] - idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0] - - sim = similarities[idx1, idx2] - interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ" - - print(f"\n{name1} vs {name2}:") - print(f" Схожесть: {sim:.4f}") - print(f" Интерпретация: {interpretation}") - print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}") - - -def test_loss_function(): - """Тестирование функции потерь.""" - print("\n" + "=" * 60) - print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ") - print("=" * 60) - - loss_fn = SimilarityLoss() - - # Тестовые данные - predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]]) - targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]]) - - # Расчет потерь - loss = loss_fn(predictions, targets) - print(f"\nПотери: {loss.item():.4f}") - - # Расчет метрик - metrics = loss_fn.compute_metrics(predictions, targets) - print("\nМетрики:") - for key, value in metrics.items(): - print(f" {key}: {value:.4f}") - - -def main(): - """Основная функция демонстрации.""" - print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ") - print("=" * 60) - - # Создание модели - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"\nУстройство: {device}") - - model = SimilarityCNN( - input_channels=3, - hidden_channels=64, - num_blocks=4, - dropout_rate=0.3, - use_batch_norm=True, - ).to(device) - - print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}") - - # Создание тестовых изображений - print("\nСоздание тестовых изображений...") - test_images = create_test_images() - - # Преобразование изображений в тензоры - tensors = [] - for name, img in test_images: - tensor = preprocess_image(img).to(device) - tensors.append(tensor) - - # Расчет схожести между всеми парами изображений - print("\nРасчет схожести между изображениями...") - n_images = len(test_images) - similarity_matrix = np.zeros((n_images, n_images)) - - model.eval() - with torch.no_grad(): - for i in range(n_images): - for j in range(n_images): - if i <= j: # Рассчитываем только верхний треугольник - sim = model.predict_similarity(tensors[i], tensors[j]) - similarity_matrix[i, j] = sim.item() - similarity_matrix[j, i] = sim.item() # Симметричная матрица - - # Отображение результатов - display_results(test_images, similarity_matrix) - - # Тестирование функции потерь - test_loss_function() - - # Дополнительная информация - print("\n" + "=" * 60) - print("ИНФОРМАЦИЯ О МОДЕЛИ") - print("=" * 60) - print("\nАрхитектура модели:") - print("-" * 40) - print("Вход: два изображения 256x256x3") - print("Энкодер: CNN с residual блоками") - print("Слой слияния: объединение признаков") - print("Выход: значение схожести [0, 1]") - print("\nИнтерпретация результатов:") - print("- 0.8-1.0: Очень похожи") - print("- 0.6-0.8: Похожи") - print("- 0.4-0.6: Нейтрально") - print("- 0.2-0.4: Разные") - print("- 0.0-0.2: Совершенно разные") - - print("\n" + "=" * 60) - print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/models/SiaN-similarity/demo_evaluation.ipynb.py b/models/SiaN-similarity/demo_evaluation.ipynb.py deleted file mode 100644 index e54966d..0000000 --- a/models/SiaN-similarity/demo_evaluation.ipynb.py +++ /dev/null @@ -1,340 +0,0 @@ -""" -Demo Evaluation Notebook-style File -==================================== - -This file demonstrates how to use the evaluation functions from evaluation.py -in a notebook-like style. You can run this file directly to see all the plots -and analysis. - -Think of this as the next cell in your notebook after training! -""" - -import os -import sys - -# Add the current directory to the path so we can import our modules -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -# Import our evaluation module -import matplotlib.pyplot as plt -import numpy as np - -# Import other necessary modules -import torch -from dataloader import config, create_data_loaders -from evaluation import ( - analyze_model_performance, - generate_performance_report, - plot_confusion_matrix, - plot_probability_distribution, - plot_roc_curve, - plot_training_metrics, - test_model_on_examples, -) -from model import create_similarity_model - -print("=" * 70) -print("DEMO: EVALUATING IMAGE SIMILARITY MODEL") -print("=" * 70) -print("\nThis demo shows you how to analyze your trained model.") -print("Think of this as the 'results' section of your notebook!\n") - -# ============================================================================ -# STEP 1: SETUP -# ============================================================================ -print("STEP 1: Setting up the environment") -print("-" * 40) - -# Check if GPU is available -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print(f"✓ Using device: {device}") - -# Load configuration -config_dict = config.copy() -if isinstance(config_dict.get("image_size"), list): - config_dict["image_size"] = tuple(config_dict["image_size"]) - -print(f"✓ Image size: {config_dict['image_size']}") -print(f"✓ Batch size: {config_dict['batch_size']}") - -# ============================================================================ -# STEP 2: LOAD DATA -# ============================================================================ -print("\nSTEP 2: Loading validation data") -print("-" * 40) - -# Create validation data loader -_, val_loader = create_data_loaders( - root_dir=config_dict["data_dir"], - batch_size=config_dict["batch_size"], - train_split=config_dict["train_split"], - num_workers=config_dict["num_workers"], - image_size=config_dict["image_size"], - augment_train=False, - augment_val=False, - device=device, -) - -print(f"✓ Validation batches loaded: {len(val_loader)}") -print(f"✓ Each batch has {config_dict['batch_size']} image pairs") - -# ============================================================================ -# STEP 3: LOAD TRAINED MODEL -# ============================================================================ -print("\nSTEP 3: Loading the trained model") -print("-" * 40) - -# Create model architecture -model = create_similarity_model( - model_type="cnn", - input_size=config_dict["image_size"][0], - input_channels=3, - hidden_channels=64, - num_blocks=4, - dropout_rate=0.3, - use_batch_norm=True, -) - -# Try to load the best checkpoint -checkpoint_dir = os.path.join( - config_dict.get("output_dir", "runs/similarity"), "checkpoints" -) -best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt") - -if os.path.exists(best_checkpoint): - checkpoint = torch.load(best_checkpoint, map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) - print(f"✓ Loaded best model from epoch {checkpoint['epoch']}") - print(f"✓ Best validation loss: {checkpoint['val_loss']:.4f}") -else: - print("⚠ Warning: Best model checkpoint not found!") - print(" Using randomly initialized model for demonstration.") - print(" (This is normal if you haven't trained the model yet)") - -model = model.to(device) -print(f"✓ Model moved to {device}") - -# Count parameters -total_params = sum(p.numel() for p in model.parameters()) -trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) -print(f"✓ Total parameters: {total_params:,}") -print(f"✓ Trainable parameters: {trainable_params:,}") - -# ============================================================================ -# STEP 4: PLOT TRAINING METRICS -# ============================================================================ -print("\nSTEP 4: Plotting training metrics") -print("-" * 40) -print("This shows how the model learned over time:") - -# This will show 4 plots: -# 1. Training and validation loss -# 2. Training and validation accuracy -# 3. Overfitting indicator -# 4. Learning rate schedule -plot_training_metrics(config_dict.get("output_dir", "runs/similarity")) - -print("✓ Training metrics plotted!") -print(" Look for 'training_metrics.png' in your runs directory") - -# ============================================================================ -# STEP 5: ANALYZE MODEL PERFORMANCE -# ============================================================================ -print("\nSTEP 5: Analyzing model performance on validation set") -print("-" * 40) -print("Calculating metrics like accuracy, precision, recall, F1 score...") - -# Analyze the model -metrics = analyze_model_performance(model, val_loader, device, threshold=0.5) - -print("\n📊 PERFORMANCE METRICS:") -print(" Accuracy: {:.2%}".format(metrics["accuracy"])) -print(" Precision: {:.2%}".format(metrics["precision"])) -print(" Recall: {:.2%}".format(metrics["recall"])) -print(" F1 Score: {:.2%}".format(metrics["f1_score"])) -print(" ROC AUC: {:.4f}".format(metrics["roc_auc"])) - -# ============================================================================ -# STEP 6: SHOW CONFUSION MATRIX -# ============================================================================ -print("\nSTEP 6: Confusion Matrix") -print("-" * 40) -print("This shows how many predictions were correct/wrong:") - -plot_confusion_matrix(metrics["confusion_matrix"]) - -# ============================================================================ -# STEP 7: ROC CURVE -# ============================================================================ -print("\nSTEP 7: ROC Curve") -print("-" * 40) -print("This shows how well the model distinguishes between classes:") - -# Get probabilities for ROC curve -model.eval() -all_probabilities = [] -all_targets = [] - -with torch.no_grad(): - for batch in val_loader: - google_img = batch["google_img"].to(device) - yandex_img = batch["yandex_img"].to(device) - target = batch["same_domain"].float().to(device) - - output = model(google_img, yandex_img) - probabilities = torch.sigmoid(output).squeeze() - - all_probabilities.extend(probabilities.cpu().numpy()) - all_targets.extend(target.cpu().numpy()) - -all_probabilities = np.array(all_probabilities) -all_targets = np.array(all_targets) - -from sklearn.metrics import auc, roc_curve - -fpr, tpr, _ = roc_curve(all_targets, all_probabilities) -roc_auc = auc(fpr, tpr) - -plot_roc_curve(fpr, tpr, roc_auc) - -# ============================================================================ -# STEP 8: PROBABILITY DISTRIBUTION -# ============================================================================ -print("\nSTEP 8: Probability Distribution") -print("-" * 40) -print("This shows how confident the model is for different classes:") - -plot_probability_distribution(all_probabilities, all_targets) - -# ============================================================================ -# STEP 9: TEST ON EXAMPLE IMAGES -# ============================================================================ -print("\nSTEP 9: Testing on example images") -print("-" * 40) -print("Let's see how the model performs on some examples:") - -test_results = test_model_on_examples(model, device) - -# ============================================================================ -# STEP 10: GENERATE REPORT -# ============================================================================ -print("\nSTEP 10: Generating performance report") -print("-" * 40) -print("Creating a detailed report with all metrics...") - -final_metrics = generate_performance_report(model, val_loader, device) - -print("\n" + "=" * 70) -print("🎉 DEMO COMPLETED SUCCESSFULLY!") -print("=" * 70) - -print("\n📁 What was created:") -print(" 1. Training metrics plots (saved to runs/similarity/)") -print(" 2. Confusion matrix visualization") -print(" 3. ROC curve plot") -print(" 4. Probability distribution plot") -print(" 5. Performance report (saved to reports/)") - -print("\n🔍 Key things to check in your model:") -print(" ✓ Accuracy should be above 70% for a good model") -print(" ✓ Precision: High = few false positives") -print(" ✓ Recall: High = few false negatives") -print(" ✓ ROC AUC: Above 0.8 = good discrimination") - -print("\n🔄 If results are poor, try:") -print(" 1. Train for more epochs") -print(" 2. Adjust learning rate") -print(" 3. Use more training data") -print(" 4. Try different model architecture") - -print( - "\n💡 Pro tip: The optimal threshold is {:.3f}".format( - final_metrics["optimal_threshold"] - ) -) -print(" You can use this instead of 0.5 for better results!") - -# ============================================================================ -# BONUS: QUICK DIAGNOSTICS TABLE -# ============================================================================ -print("\n" + "=" * 70) -print("BONUS: Quick Diagnostics Table") -print("=" * 70) - -# Create a simple table of what each metric means -diagnostics = [ - ["Metric", "Value", "What it means", "Is it good?"], - ["-" * 15, "-" * 10, "-" * 30, "-" * 15], - ["Accuracy", f"{metrics['accuracy']:.2%}", "Overall correctness", ">70% is good"], - ["Precision", f"{metrics['precision']:.2%}", "Few false positives", ">70% is good"], - ["Recall", f"{metrics['recall']:.2%}", "Few false negatives", ">70% is good"], - [ - "F1 Score", - f"{metrics['f1_score']:.2%}", - "Balance of precision/recall", - ">70% is good", - ], - ["ROC AUC", f"{metrics['roc_auc']:.4f}", "Discrimination ability", ">0.8 is good"], -] - -for row in diagnostics: - print("{:<15} {:<10} {:<30} {:<15}".format(*row)) - -print("\n" + "=" * 70) -print("To run this again, just execute: python demo_evaluation.ipynb.py") -print("=" * 70) - -# ============================================================================ -# EXTRA: SAVE PREDICTIONS FOR FURTHER ANALYSIS -# ============================================================================ -print("\n💾 Saving predictions for further analysis...") - -# Get all predictions -model.eval() -all_predictions = [] -all_targets = [] -all_probabilities = [] -image_indices = [] - -with torch.no_grad(): - for batch_idx, batch in enumerate(val_loader): - google_img = batch["google_img"].to(device) - yandex_img = batch["yandex_img"].to(device) - target = batch["same_domain"].float().to(device) - - output = model(google_img, yandex_img) - probabilities = torch.sigmoid(output).squeeze() - predictions = (probabilities > 0.5).float() - - all_predictions.extend(predictions.cpu().numpy()) - all_targets.extend(target.cpu().numpy()) - all_probabilities.extend(probabilities.cpu().numpy()) - image_indices.extend( - range(batch_idx * len(target), (batch_idx + 1) * len(target)) - ) - -# Save to CSV for further analysis -import pandas as pd - -predictions_df = pd.DataFrame( - { - "image_pair_index": image_indices, - "true_label": all_targets, - "predicted_label": all_predictions, - "probability": all_probabilities, - "correct": np.array(all_targets) == np.array(all_predictions), - } -) - -predictions_path = os.path.join( - config_dict.get("output_dir", "runs/similarity"), "predictions_analysis.csv" -) -predictions_df.to_csv(predictions_path, index=False) -print(f"✓ Predictions saved to: {predictions_path}") -print(f"✓ Total predictions: {len(predictions_df)}") -print( - f"✓ Correct predictions: {predictions_df['correct'].sum()} ({predictions_df['correct'].mean():.2%})" -) - -print("\n" + "🎯 You can now analyze individual predictions in the CSV file!") -print(" Look for patterns in the mistakes your model makes.") diff --git a/models/SiaN-similarity/evaluation.py b/models/SiaN-similarity/evaluation.py deleted file mode 100644 index 1afd7a7..0000000 --- a/models/SiaN-similarity/evaluation.py +++ /dev/null @@ -1,663 +0,0 @@ -""" -Evaluation and visualization for image similarity model. -This file contains code for plotting training metrics, analyzing model performance, -and testing the trained model. -""" - -import os - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns -import torch -import torch.nn as nn -from dataloader import config, create_data_loaders -from model import create_similarity_model -from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve -from torch.utils.data import DataLoader -from train import SimilarityTrainer - -# Set style for plots -plt.style.use("seaborn-v0_8-darkgrid") -sns.set_palette("husl") - - -def plot_training_metrics(log_dir="runs/similarity"): - """ - Plot training and validation metrics from TensorBoard logs or saved metrics. - - Args: - log_dir: Directory containing training logs - """ - # In a real scenario, we would read from TensorBoard logs - # For this example, we'll create simulated data to show what plots would look like - - # Simulated training data (in reality, you would load this from logs) - epochs = list(range(1, 51)) - - # Simulated metrics - train_loss = [0.8 - 0.015 * i + np.random.normal(0, 0.02) for i in range(50)] - val_loss = [0.75 - 0.012 * i + np.random.normal(0, 0.03) for i in range(50)] - train_acc = [0.55 + 0.008 * i + np.random.normal(0, 0.01) for i in range(50)] - val_acc = [0.6 + 0.006 * i + np.random.normal(0, 0.015) for i in range(50)] - - # Create figure with subplots - fig, axes = plt.subplots(2, 2, figsize=(14, 10)) - - # Plot 1: Training and Validation Loss - axes[0, 0].plot(epochs, train_loss, "b-", linewidth=2, label="Training Loss") - axes[0, 0].plot(epochs, val_loss, "r-", linewidth=2, label="Validation Loss") - axes[0, 0].set_xlabel("Epoch") - axes[0, 0].set_ylabel("Loss") - axes[0, 0].set_title("Training and Validation Loss") - axes[0, 0].legend() - axes[0, 0].grid(True, alpha=0.3) - - # Plot 2: Training and Validation Accuracy - axes[0, 1].plot(epochs, train_acc, "b-", linewidth=2, label="Training Accuracy") - axes[0, 1].plot(epochs, val_acc, "r-", linewidth=2, label="Validation Accuracy") - axes[0, 1].set_xlabel("Epoch") - axes[0, 1].set_ylabel("Accuracy") - axes[0, 1].set_title("Training and Validation Accuracy") - axes[0, 1].legend() - axes[0, 1].grid(True, alpha=0.3) - - # Plot 3: Loss difference (train - val) - loss_diff = [t - v for t, v in zip(train_loss, val_loss)] - axes[1, 0].plot(epochs, loss_diff, "g-", linewidth=2) - axes[1, 0].axhline(y=0, color="r", linestyle="--", alpha=0.5) - axes[1, 0].fill_between( - epochs, - 0, - loss_diff, - where=np.array(loss_diff) > 0, - alpha=0.3, - color="red", - label="Overfitting (train > val)", - ) - axes[1, 0].fill_between( - epochs, - 0, - loss_diff, - where=np.array(loss_diff) < 0, - alpha=0.3, - color="green", - label="Underfitting (train < val)", - ) - axes[1, 0].set_xlabel("Epoch") - axes[1, 0].set_ylabel("Loss Difference") - axes[1, 0].set_title("Train Loss - Val Loss (Overfitting Indicator)") - axes[1, 0].legend() - axes[1, 0].grid(True, alpha=0.3) - - # Plot 4: Learning rate schedule (if available) - axes[1, 1].plot( - epochs, [0.0002 * (0.95**i) for i in range(50)], "purple-", linewidth=2 - ) - axes[1, 1].set_xlabel("Epoch") - axes[1, 1].set_ylabel("Learning Rate") - axes[1, 1].set_title("Learning Rate Schedule") - axes[1, 1].set_yscale("log") - axes[1, 1].grid(True, alpha=0.3) - - plt.tight_layout() - plt.savefig( - os.path.join(log_dir, "training_metrics.png"), dpi=150, bbox_inches="tight" - ) - plt.show() - - print( - "Training metrics plots saved to:", - os.path.join(log_dir, "training_metrics.png"), - ) - - -def analyze_model_performance(model, data_loader, device, threshold=0.5): - """ - Analyze model performance on a dataset. - - Args: - model: Trained model - data_loader: DataLoader with test/validation data - device: torch device - threshold: Decision threshold for binary classification - - Returns: - Dictionary with performance metrics - """ - model.eval() - - all_predictions = [] - all_targets = [] - all_probabilities = [] - - with torch.no_grad(): - for batch in data_loader: - google_img = batch["google_img"].to(device) - yandex_img = batch["yandex_img"].to(device) - target = batch["same_domain"].float().to(device) - - output = model(google_img, yandex_img) - probabilities = torch.sigmoid(output).squeeze() - - predictions = (probabilities > threshold).float() - - all_predictions.extend(predictions.cpu().numpy()) - all_targets.extend(target.cpu().numpy()) - all_probabilities.extend(probabilities.cpu().numpy()) - - # Convert to numpy arrays - all_predictions = np.array(all_predictions) - all_targets = np.array(all_targets) - all_probabilities = np.array(all_probabilities) - - # Calculate confusion matrix - cm = confusion_matrix(all_targets, all_predictions) - - # Calculate metrics - tn, fp, fn, tp = cm.ravel() - - accuracy = (tp + tn) / (tp + tn + fp + fn) - precision = tp / (tp + fp) if (tp + fp) > 0 else 0 - recall = tp / (tp + fn) if (tp + fn) > 0 else 0 - f1_score = ( - 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 - ) - - # Create classification report - report = classification_report( - all_targets, all_predictions, target_names=["Different", "Same"] - ) - - # Calculate ROC curve - fpr, tpr, thresholds = roc_curve(all_targets, all_probabilities) - roc_auc = auc(fpr, tpr) - - # Find optimal threshold (Youden's J statistic) - youden_j = tpr - fpr - optimal_idx = np.argmax(youden_j) - optimal_threshold = thresholds[optimal_idx] - - metrics = { - "confusion_matrix": cm, - "accuracy": accuracy, - "precision": precision, - "recall": recall, - "f1_score": f1_score, - "roc_auc": roc_auc, - "optimal_threshold": optimal_threshold, - "classification_report": report, - "true_negatives": tn, - "false_positives": fp, - "false_negatives": fn, - "true_positives": tp, - } - - return metrics - - -def plot_confusion_matrix(cm, class_names=["Different", "Same"]): - """ - Plot confusion matrix with annotations. - - Args: - cm: Confusion matrix - class_names: List of class names - """ - plt.figure(figsize=(8, 6)) - - # Create heatmap - sns.heatmap( - cm, - annot=True, - fmt="d", - cmap="Blues", - xticklabels=class_names, - yticklabels=class_names, - ) - - plt.title("Confusion Matrix") - plt.ylabel("True Label") - plt.xlabel("Predicted Label") - - # Add text annotations - tn, fp, fn, tp = cm.ravel() - - plt.text( - 0.5, -0.15, f"True Negatives: {tn}", ha="center", transform=plt.gca().transAxes - ) - plt.text( - 0.5, -0.20, f"False Positives: {fp}", ha="center", transform=plt.gca().transAxes - ) - plt.text( - 0.5, -0.25, f"False Negatives: {fn}", ha="center", transform=plt.gca().transAxes - ) - plt.text( - 0.5, -0.30, f"True Positives: {tp}", ha="center", transform=plt.gca().transAxes - ) - - plt.tight_layout() - plt.show() - - # Create a summary table - print("\n" + "=" * 50) - print("CONFUSION MATRIX SUMMARY") - print("=" * 50) - - summary_data = { - "Metric": [ - "True Negatives", - "False Positives", - "False Negatives", - "True Positives", - ], - "Count": [tn, fp, fn, tp], - "Description": [ - "Correctly predicted as different", - "Incorrectly predicted as same (Type I error)", - "Incorrectly predicted as different (Type II error)", - "Correctly predicted as same", - ], - } - - df = pd.DataFrame(summary_data) - print(df.to_string(index=False)) - print("=" * 50) - - -def plot_roc_curve(fpr, tpr, roc_auc): - """ - Plot ROC curve. - - Args: - fpr: False positive rates - tpr: True positive rates - roc_auc: Area under ROC curve - """ - plt.figure(figsize=(8, 6)) - - plt.plot( - fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.3f})" - ) - plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random") - - plt.xlim([0.0, 1.0]) - plt.ylim([0.0, 1.05]) - plt.xlabel("False Positive Rate") - plt.ylabel("True Positive Rate") - plt.title("Receiver Operating Characteristic (ROC) Curve") - plt.legend(loc="lower right") - plt.grid(True, alpha=0.3) - - plt.tight_layout() - plt.show() - - print(f"ROC AUC Score: {roc_auc:.4f}") - print("AUC Interpretation:") - print("0.90-1.00 = Excellent") - print("0.80-0.90 = Good") - print("0.70-0.80 = Fair") - print("0.60-0.70 = Poor") - print("0.50-0.60 = Fail") - - -def plot_probability_distribution(all_probabilities, all_targets): - """ - Plot probability distribution for positive and negative classes. - - Args: - all_probabilities: List of predicted probabilities - all_targets: List of true labels - """ - # Separate probabilities by true class - pos_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 1] - neg_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 0] - - plt.figure(figsize=(10, 6)) - - # Plot histograms - plt.hist( - pos_probs, - bins=30, - alpha=0.5, - color="green", - label="Same Domain (Positive)", - density=True, - ) - plt.hist( - neg_probs, - bins=30, - alpha=0.5, - color="red", - label="Different Domain (Negative)", - density=True, - ) - - # Add vertical line at threshold 0.5 - plt.axvline( - x=0.5, - color="black", - linestyle="--", - linewidth=2, - label="Decision Threshold (0.5)", - ) - - plt.xlabel("Predicted Probability") - plt.ylabel("Density") - plt.title("Probability Distribution by True Class") - plt.legend() - plt.grid(True, alpha=0.3) - - plt.tight_layout() - plt.show() - - # Print statistics - print("\nProbability Statistics:") - print( - f"Positive class (Same): Mean = {np.mean(pos_probs):.3f}, Std = {np.std(pos_probs):.3f}" - ) - print( - f"Negative class (Different): Mean = {np.mean(neg_probs):.3f}, Std = {np.std(neg_probs):.3f}" - ) - - -def test_model_on_examples(model, device, examples_dir="examples"): - """ - Test model on example image pairs. - - Args: - model: Trained model - device: torch device - examples_dir: Directory containing example image pairs - """ - import cv2 - from torchvision import transforms - - model.eval() - - # Define image preprocessing - transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - - # Check if examples directory exists - if not os.path.exists(examples_dir): - print(f"Examples directory '{examples_dir}' not found.") - print("Creating dummy examples for demonstration...") - - # Create dummy example data for demonstration - examples = [ - { - "name": "Example 1: Similar locations", - "google_img": torch.randn(1, 3, 224, 224), - "yandex_img": torch.randn(1, 3, 224, 224), - "expected": "Same", - }, - { - "name": "Example 2: Different locations", - "google_img": torch.randn(1, 3, 224, 224), - "yandex_img": torch.randn(1, 3, 224, 224) * 2, - "expected": "Different", - }, - ] - else: - # In real implementation, load actual images - examples = [] - - print("\n" + "=" * 60) - print("MODEL TESTING ON EXAMPLES") - print("=" * 60) - - results = [] - - for example in examples: - with torch.no_grad(): - google_img = example["google_img"].to(device) - yandex_img = example["yandex_img"].to(device) - - output = model(google_img, yandex_img) - probability = torch.sigmoid(output).item() - prediction = "Same" if probability > 0.5 else "Different" - - result = { - "Example": example["name"], - "Predicted": prediction, - "Probability": probability, - "Expected": example.get("expected", "Unknown"), - "Correct": prediction == example.get("expected", "Unknown"), - } - - results.append(result) - - print(f"\n{example['name']}:") - print(f" Predicted: {prediction} (probability: {probability:.4f})") - print(f" Expected: {example.get('expected', 'Unknown')}") - print(f" Result: {'✓ CORRECT' if result['Correct'] else '✗ WRONG'}") - - # Create results table - print("\n" + "=" * 60) - print("SUMMARY OF TEST RESULTS") - print("=" * 60) - - df_results = pd.DataFrame(results) - print(df_results.to_string(index=False)) - - accuracy = df_results["Correct"].mean() * 100 - print(f"\nTest Accuracy: {accuracy:.1f}%") - - return df_results - - -def generate_performance_report(model, data_loader, device, output_dir="reports"): - """ - Generate a comprehensive performance report. - - Args: - model: Trained model - data_loader: DataLoader with test data - device: torch device - output_dir: Directory to save reports - """ - os.makedirs(output_dir, exist_ok=True) - - print("Generating performance report...") - - # Analyze performance - metrics = analyze_model_performance(model, data_loader, device) - - # Create report file - report_path = os.path.join(output_dir, "model_performance_report.txt") - - with open(report_path, "w") as f: - f.write("=" * 60 + "\n") - f.write("MODEL PERFORMANCE REPORT\n") - f.write("=" * 60 + "\n\n") - - f.write("1. BASIC METRICS\n") - f.write("-" * 40 + "\n") - f.write(f"Accuracy: {metrics['accuracy']:.4f}\n") - f.write(f"Precision: {metrics['precision']:.4f}\n") - f.write(f"Recall: {metrics['recall']:.4f}\n") - f.write(f"F1 Score: {metrics['f1_score']:.4f}\n") - f.write(f"ROC AUC: {metrics['roc_auc']:.4f}\n") - f.write(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}\n\n") - - f.write("2. CONFUSION MATRIX\n") - f.write("-" * 40 + "\n") - f.write(f"True Negatives: {metrics['true_negatives']}\n") - f.write(f"False Positives: {metrics['false_positives']}\n") - f.write(f"False Negatives: {metrics['false_negatives']}\n") - f.write(f"True Positives: {metrics['true_positives']}\n\n") - - f.write("3. CLASSIFICATION REPORT\n") - f.write("-" * 40 + "\n") - f.write(metrics["classification_report"] + "\n") - - f.write("4. INTERPRETATION\n") - f.write("-" * 40 + "\n") - f.write("Accuracy: Proportion of correct predictions\n") - f.write("Precision: Proportion of positive predictions that are correct\n") - f.write( - "Recall: Proportion of actual positives that are correctly identified\n" - ) - f.write("F1 Score: Harmonic mean of precision and recall\n") - f.write("ROC AUC: Ability to distinguish between classes\n\n") - - f.write("5. RECOMMENDATIONS\n") - f.write("-" * 40 + "\n") - if metrics["precision"] < 0.7: - f.write("- Improve precision to reduce false positives\n") - if metrics["recall"] < 0.7: - f.write("- Improve recall to reduce false negatives\n") - if metrics["f1_score"] < 0.7: - f.write("- Overall model performance needs improvement\n") - if metrics["roc_auc"] > 0.8: - f.write("- Good discrimination ability between classes\n") - else: - f.write("- Consider improving feature extraction\n") - - print(f"Report saved to: {report_path}") - - return metrics - - -def main(): - """ - Main function to run evaluation and generate reports. - """ - print("=" * 60) - print("IMAGE SIMILARITY MODEL EVALUATION") - print("=" * 60) - - # Set device - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # Load configuration - config_dict = config.copy() - if isinstance(config_dict.get("image_size"), list): - config_dict["image_size"] = tuple(config_dict["image_size"]) - - # Create data loaders - print("\n1. Creating data loaders...") - _, val_loader = create_data_loaders( - root_dir=config_dict["data_dir"], - batch_size=config_dict["batch_size"], - train_split=config_dict["train_split"], - num_workers=config_dict["num_workers"], - image_size=config_dict["image_size"], - augment_train=False, - augment_val=False, - device=device, - ) - - print(f"Validation batches: {len(val_loader)}") - - # Load trained model - print("\n2. Loading trained model...") - model = create_similarity_model( - model_type="cnn", - input_size=config_dict["image_size"][0] - if isinstance(config_dict["image_size"], (tuple, list)) - else config_dict["image_size"], - input_channels=3, - hidden_channels=64, - num_blocks=4, - dropout_rate=0.3, - use_batch_norm=True, - ) - - # Load best checkpoint - checkpoint_dir = os.path.join( - config_dict.get("output_dir", "runs/similarity"), "checkpoints" - ) - best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt") - - if os.path.exists(best_checkpoint): - checkpoint = torch.load(best_checkpoint, map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) - print(f"Loaded best model from epoch {checkpoint['epoch']}") - print(f"Best validation loss: {checkpoint['val_loss']:.4f}") - else: - print("Warning: Best model checkpoint not found!") - print("Using randomly initialized model for demonstration.") - - model = model.to(device) - - # Plot training metrics - print("\n3. Plotting training metrics...") - plot_training_metrics(config_dict.get("output_dir", "runs/similarity")) - - # Analyze model performance - print("\n4. Analyzing model performance...") - metrics = analyze_model_performance(model, val_loader, device) - - # Display results - print("\n" + "=" * 60) - print("PERFORMANCE METRICS") - print("=" * 60) - print(f"Accuracy: {metrics['accuracy']:.4f}") - print(f"Precision: {metrics['precision']:.4f}") - print(f"Recall: {metrics['recall']:.4f}") - print(f"F1 Score: {metrics['f1_score']:.4f}") - print(f"ROC AUC: {metrics['roc_auc']:.4f}") - print(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}") - - # Plot confusion matrix - print("\n5. Plotting confusion matrix...") - plot_confusion_matrix(metrics["confusion_matrix"]) - - # Plot ROC curve - print("\n6. Plotting ROC curve...") - # For demonstration, we need to get probabilities again - model.eval() - all_probabilities = [] - all_targets = [] - - with torch.no_grad(): - for batch in val_loader: - google_img = batch["google_img"].to(device) - yandex_img = batch["yandex_img"].to(device) - target = batch["same_domain"].float().to(device) - - output = model(google_img, yandex_img) - probabilities = torch.sigmoid(output).squeeze() - - all_probabilities.extend(probabilities.cpu().numpy()) - all_targets.extend(target.cpu().numpy()) - - all_probabilities = np.array(all_probabilities) - all_targets = np.array(all_targets) - - fpr, tpr, _ = roc_curve(all_targets, all_probabilities) - roc_auc = auc(fpr, tpr) - plot_roc_curve(fpr, tpr, roc_auc) - - # Plot probability distribution - print("\n7. Plotting probability distribution...") - plot_probability_distribution(all_probabilities, all_targets) - - # Test on examples - print("\n8. Testing on examples...") - test_model_on_examples(model, device) - - # Generate comprehensive report - print("\n9. Generating performance report...") - generate_performance_report(model, val_loader, device) - - print("\n" + "=" * 60) - print("EVALUATION COMPLETED SUCCESSFULLY!") - print("=" * 60) - print("\nNext steps:") - print("1. Check the generated plots in the runs/similarity directory") - print("2. Review the performance report in the reports directory") - print("3. Consider adjusting the decision threshold if needed") - print("4. Retrain with different hyperparameters if performance is poor") - - -if __name__ == "__main__": - main() diff --git a/models/SiaN-similarity/example.py b/models/SiaN-similarity/example.py deleted file mode 100644 index d17cda7..0000000 --- a/models/SiaN-similarity/example.py +++ /dev/null @@ -1,216 +0,0 @@ -""" -Пример использования модели оценки схожести с даталоадером. -""" - -import torch -from dataloader import YaGoDataset, create_data_loaders -from model import SimilarityCNN, SimilarityLoss - - -def main(): - """Основной пример использования.""" - print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ") - print("=" * 60) - - # Конфигурация - config = { - "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", - "batch_size": 4, - "image_size": (256, 256), - "train_split": 0.8, - "num_workers": 0, - } - - # Устройство - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Устройство: {device}") - - # 1. Создание датасета - print("\n1. СОЗДАНИЕ ДАТАСЕТА") - print("-" * 40) - - dataset = YaGoDataset( - root_dir=config["data_dir"], - augment=False, - image_size=config["image_size"], - ) - - print(f"Размер датасета: {len(dataset)} пар изображений") - - # Получение примера из датасета - sample = dataset[0] - print(f"\nПример из датасета:") - print(f" Google image shape: {sample['google_img'].shape}") - print(f" Yandex image shape: {sample['yandex_img'].shape}") - print(f" Same domain: {sample['same_domain']}") - print(f" Index: {sample['idx'].item()}") - - # 2. Создание даталоадеров - print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ") - print("-" * 40) - - train_loader, val_loader = create_data_loaders( - root_dir=config["data_dir"], - batch_size=config["batch_size"], - train_split=config["train_split"], - num_workers=config["num_workers"], - image_size=config["image_size"], - augment_train=True, - augment_val=False, - device=device, - ) - - print(f"Train batches: {len(train_loader)}") - print(f"Val batches: {len(val_loader)}") - - # 3. Создание модели - print("\n3. СОЗДАНИЕ МОДЕЛИ") - print("-" * 40) - - model = SimilarityCNN( - input_channels=3, - hidden_channels=64, - num_blocks=4, - dropout_rate=0.3, - use_batch_norm=True, - ).to(device) - - print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}") - - # 4. Тестирование на одном батче - print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ") - print("-" * 40) - - # Получаем батч из train_loader - for batch in train_loader: - google_img = batch["google_img"].to(device) - yandex_img = batch["yandex_img"].to(device) - same_domain = batch["same_domain"].float().to(device).unsqueeze(1) - - print(f"Batch size: {google_img.shape[0]}") - print(f"Image shape: {google_img.shape[1:]}") - print(f"Same domain labels: {same_domain.squeeze().tolist()}") - - # Предсказание схожести - with torch.no_grad(): - predictions = model.predict_similarity(google_img, yandex_img) - print(f"\nПредсказания схожести:") - for i in range(len(predictions)): - print( - f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})" - ) - - # Расчет потерь - loss_fn = SimilarityLoss().to(device) - loss = loss_fn(predictions, same_domain) - print(f"\nПотери на батче: {loss.item():.4f}") - - # Расчет метрик - metrics = loss_fn.compute_metrics(predictions, same_domain) - print("\nМетрики на батче:") - for key, value in metrics.items(): - print(f" {key}: {value:.4f}") - - break # Только первый батч - - # 5. Обучение на одном эпохе (демонстрация) - print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ") - print("-" * 40) - - optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) - model.train() - - total_loss = 0 - total_samples = 0 - - for batch_idx, batch in enumerate(train_loader): - if batch_idx >= 3: # Ограничиваем 3 батчами для демонстрации - break - - google_img = batch["google_img"].to(device) - yandex_img = batch["yandex_img"].to(device) - same_domain = batch["same_domain"].float().to(device).unsqueeze(1) - - optimizer.zero_grad() - - predictions = model(google_img, yandex_img) - loss = loss_fn(predictions, same_domain) - - loss.backward() - optimizer.step() - - batch_loss = loss.item() * google_img.size(0) - total_loss += batch_loss - total_samples += google_img.size(0) - - print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}") - - avg_loss = total_loss / total_samples - print(f"\nСредние потери за 3 батча: {avg_loss:.4f}") - - # 6. Валидация - print("\n6. ВАЛИДАЦИЯ") - print("-" * 40) - - model.eval() - val_loss = 0 - val_samples = 0 - - with torch.no_grad(): - for batch_idx, batch in enumerate(val_loader): - if batch_idx >= 2: # Ограничиваем 2 батчами для демонстрации - break - - google_img = batch["google_img"].to(device) - yandex_img = batch["yandex_img"].to(device) - same_domain = batch["same_domain"].float().to(device).unsqueeze(1) - - predictions = model.predict_similarity(google_img, yandex_img) - loss = loss_fn(predictions, same_domain) - - val_loss += loss.item() * google_img.size(0) - val_samples += google_img.size(0) - - print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}") - - avg_val_loss = val_loss / val_samples - print(f"\nСредние потери на валидации: {avg_val_loss:.4f}") - - # 7. Пример использования для отдельных изображений - print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ") - print("-" * 40) - - # Берем два примера из датасета - sample1 = dataset[0] - sample2 = dataset[1] - - # Подготавливаем тензоры - img1_1 = sample1["google_img"].unsqueeze(0).to(device) - img1_2 = sample1["yandex_img"].unsqueeze(0).to(device) - - img2_1 = sample2["google_img"].unsqueeze(0).to(device) - img2_2 = sample2["yandex_img"].unsqueeze(0).to(device) - - # Предсказания - with torch.no_grad(): - # Сравнение пар из одного домена - sim_same1 = model.predict_similarity(img1_1, img1_2) - sim_same2 = model.predict_similarity(img2_1, img2_2) - - # Сравнение пар из разных доменов - sim_diff1 = model.predict_similarity(img1_1, img2_2) - sim_diff2 = model.predict_similarity(img2_1, img1_2) - - print("Сравнение пар изображений:") - print(f" Пара 1 (один домен): {sim_same1.item():.4f}") - print(f" Пара 2 (один домен): {sim_same2.item():.4f}") - print(f" Разные домены 1: {sim_diff1.item():.4f}") - print(f" Разные домены 2: {sim_diff2.item():.4f}") - - print("\n" + "=" * 60) - print("ПРИМЕР ЗАВЕРШЕН") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/models/SiaN-similarity/simple_results_explanation.py b/models/SiaN-similarity/simple_results_explanation.py deleted file mode 100644 index 2dcec8a..0000000 --- a/models/SiaN-similarity/simple_results_explanation.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -ПРОСТОЕ ОБЪЯСНЕНИЕ РЕЗУЛЬТАТОВ ОБУЧЕНИЯ -======================================== - -Этот файл объясняет результаты обучения модели простыми словами, -как будто ты студент, который только начал изучать машинное обучение. - -Представь, что train.py - это предыдущая ячейка в твоем блокноте, -где ты обучил модель. Теперь давай посмотрим, что у нас получилось! -""" - -print("=" * 70) -print("ПРОСТОЕ ОБЪЯСНЕНИЕ РЕЗУЛЬТАТОВ МОДЕЛИ") -print("=" * 70) -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 1: ЧТО МЫ СДЕЛАЛИ? -# ------------------------------------------------------------------- -print("1. ЧТО МЫ СДЕЛАЛИ?") -print("-" * 40) -print("Мы создали модель, которая смотрит на две картинки и говорит:") -print(" - 'ДА' - если это одно и то же место (с Google и Яндекс карт)") -print(" - 'НЕТ' - если это разные места") -print() -print("Модель училась на тысячах пар картинок!") -print("Сначала она делала много ошибок, но потом научилась.") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 2: КАК МЫ ИЗМЕРЯЕМ УСПЕХ? -# ------------------------------------------------------------------- -print("2. КАК МЫ ИЗМЕРЯЕМ УСПЕХ?") -print("-" * 40) -print("Мы проверяем модель на новых картинках, которых она не видела.") -print("Считаем, сколько раз она угадала правильно.") -print() - -print("Есть 4 возможных исхода:") -print(" 1. ✅ Истинно-положительный (True Positive - TP):") -print(" Модель сказала 'ДА' и это правда 'ДА'") -print() -print(" 2. ❌ Ложно-положительный (False Positive - FP):") -print(" Модель сказала 'ДА', но на самом деле 'НЕТ'") -print(" (Ошибка типа I: приняла разные места за одинаковые)") -print() -print(" 3. ❌ Ложно-отрицательный (False Negative - FN):") -print(" Модель сказала 'НЕТ', но на самом деле 'ДА'") -print(" (Ошибка типа II: не узнала одинаковые места)") -print() -print(" 4. ✅ Истинно-отрицательный (True Negative - TN):") -print(" Модель сказала 'НЕТ' и это правда 'НЕТ'") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 3: ПРОСТЫЕ МЕТРИКИ -# ------------------------------------------------------------------- -print("3. ПРОСТЫЕ МЕТРИКИ (ЧТО ОНИ ЗНАЧАТ?)") -print("-" * 40) - -# Примерные результаты (в реальности будут другие) -accuracy = 0.82 # 82% -precision = 0.78 # 78% -recall = 0.85 # 85% -f1_score = 0.81 # 81% - -print(f"ТОЧНОСТЬ (Accuracy): {accuracy:.0%}") -print(" Это как общая оценка в школе.") -print(" Сколько всего ответов правильных из 100.") -print(f" Наша модель правильна в {accuracy:.0%} случаев.") -print() - -print(f"ТОЧНОСТЬ КЛАССИФИКАЦИИ (Precision): {precision:.0%}") -print(" Когда модель говорит 'ДА', насколько ей можно верить?") -print(" Из 100 раз когда она сказала 'ДА', {precision:.0%} были правдой.") -print(" Высокая точность = мало ложных 'ДА'.") -print() - -print(f"ПОЛНОТА (Recall): {recall:.0%}") -print(" Сколько настоящих 'ДА' модель нашла?") -print(f" Из 100 настоящих 'ДА', модель нашла {recall:.0%}.") -print(" Высокая полнота = мало пропущенных 'ДА'.") -print() - -print(f"F1-МЕРА (F1 Score): {f1_score:.0%}") -print(" Баланс между точностью и полнотой.") -print(" Как золотая середина - не слишком строгая, не слишком добрая.") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 4: ТАБЛИЦА РЕЗУЛЬТАТОВ (ПРОСТАЯ) -# ------------------------------------------------------------------- -print("4. ТАБЛИЦА РЕЗУЛЬТАТОВ") -print("-" * 40) -print("Давай представим, что мы протестировали модель на 1000 пар картинок:") -print() - -# Простая таблица -print(" | Модель сказала 'ДА' | Модель сказала 'НЕТ' | Всего") -print("-----------------|---------------------|----------------------|-------") -print(f"На самом деле 'ДА' | TP: 425 | FN: 75 | 500") -print(f"На самом деле 'НЕТ' | FP: 95 | TN: 405 | 500") -print("-----------------|---------------------|----------------------|-------") -print(f"Всего | 520 | 480 | 1000") -print() - -print("Расчеты:") -print(f" Точность = (TP + TN) / Всего = (425 + 405) / 1000 = {accuracy:.0%}") -print(f" Точность классификации = TP / (TP + FP) = 425 / 520 = {precision:.0%}") -print(f" Полнота = TP / (TP + FN) = 425 / 500 = {recall:.0%}") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 5: КАК ИНТЕРПРЕТИРОВАТЬ РЕЗУЛЬТАТЫ? -# ------------------------------------------------------------------- -print("5. ЧТО ЭТО ЗНАЧИТ ДЛЯ НАШЕЙ ЗАДАЧИ?") -print("-" * 40) - -if precision > 0.75: - print("✅ ХОРОШО: Когда модель говорит 'это одно место',") - print(" ей можно доверять ({precision:.0%} случаев она права).") -else: - print("⚠ МОЖНО ЛУЧШЕ: Модель иногда путает разные места с одинаковыми.") - -if recall > 0.75: - print("✅ ХОРОШО: Модель находит большинство одинаковых мест") - print(f" ({recall:.0%} настоящих 'одинаковых' мест она находит).") -else: - print("⚠ МОЖНО ЛУЧШЕ: Модель пропускает много одинаковых мест.") - -print() -print("ДЛЯ АВТОПИЛОТА:") -print(" - Ложные 'ДА' (FP): Может думать, что мы в нужном месте,") -print(" когда это не так → опасно!") -print(" - Ложные 'НЕТ' (FN): Не узнает нужное место → менее опасно,") -print(" но машина может проехать мимо.") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 6: ГРАФИКИ (ЧТО МЫ ВИДИМ?) -# ------------------------------------------------------------------- -print("6. КАКИЕ ГРАФИКИ МЫ ПОЛУЧАЕМ?") -print("-" * 40) -print("После обучения мы строим 4 основных графика:") -print() - -print("1. 📉 ГРАФИК ОШИБОК (Loss):") -print(" - Синяя линия: ошибки на обучающих данных") -print(" - Красная линия: ошибки на проверочных данных") -print(" - ХОРОШО: обе линии идут вниз и близки друг к другу") -print(" - ПЛОХО: линии далеко друг от друга (переобучение)") -print() - -print("2. 📈 ГРАФИК ТОЧНОСТИ (Accuracy):") -print(" - Показывает, как растет точность со временем") -print(" - Должен расти и стабилизироваться") -print() - -print("3. 🎯 МАТРИЦА ОШИБОК (Confusion Matrix):") -print(" - Квадратная таблица 2x2") -print(" - Показывает все 4 типа ответов (TP, FP, FN, TN)") -print(" - Идеально: все числа на диагонали, нули вне диагонали") -print() - -print("4. 📊 ROC-КРИВАЯ:") -print(" - Показывает, насколько хорошо модель отличает 'ДА' от 'НЕТ'") -print(" - Чем больше площадь под кривой, тем лучше") -print(" - Идеально: площадь = 1.0 (100%)") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 7: ЧТО ДЕЛАТЬ ДАЛЬШЕ? -# ------------------------------------------------------------------- -print("7. ЧТО ДЕЛАТЬ, ЕСЛИ РЕЗУЛЬТАТЫ ПЛОХИЕ?") -print("-" * 40) -print("Если точность меньше 70%:") - -print("1. 🎯 ПРОБЛЕМА: Модель плохо учится") -print(" РЕШЕНИЕ:") -print(" - Учить дольше (увеличить количество эпох)") -print(" - Изменить скорость обучения (learning rate)") -print(" - Добавить больше данных для обучения") -print() - -print("2. 🎯 ПРОБЛЕМА: Модель запоминает, а не учится (переобучение)") -print(" РЕШЕНИЕ:") -print(" - Добавить регуляризацию (dropout)") -print(" - Использовать augmentation (искажать картинки)") -print(" - Упростить модель (меньше слоев)") -print() - -print("3. 🎯 ПРОБЛЕМА: Много ложных 'ДА' (FP)") -print(" РЕШЕНИЕ:") -print(" - Повысить порог принятия решения (например, 0.7 вместо 0.5)") -print(" - Добавить больше примеров 'разных' мест") -print() - -print("4. 🎯 ПРОБЛЕМА: Много ложных 'НЕТ' (FN)") -print(" РЕШЕНИЕ:") -print(" - Понизить порог принятия решения (например, 0.3 вместо 0.5)") -print(" - Добавить больше примеров 'одинаковых' мест") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 8: ПРАКТИЧЕСКИЙ ПРИМЕР -# ------------------------------------------------------------------- -print("8. ПРАКТИЧЕСКИЙ ПРИМЕР: КАК ИСПОЛЬЗОВАТЬ МОДЕЛЬ") -print("-" * 40) -print("После обучения модель можно использовать так:") -print() - -print("```python") -print("# 1. Загружаем обученную модель") -print("model = load_trained_model('best_model.pt')") -print() -print("# 2. Берем две картинки") -print("google_img = load_image('google_map.png')") -print("yandex_img = load_image('yandex_map.png')") -print() -print("# 3. Спрашиваем у модели") -print("similarity_score = model.predict(google_img, yandex_img)") -print() -print("# 4. Интерпретируем результат") -print("if similarity_score > 0.5:") -print(" print('✅ Это похоже на одно и то же место!')") -print("else:") -print(" print('❌ Это разные места')") -print("```") -print() - -print(f"Порог 0.5 можно менять:") -print(f" - Порог 0.7: более строгая модель (меньше ложных 'ДА')") -print(f" - Порог 0.3: более добрая модель (меньше ложных 'НЕТ')") -print() - -# ------------------------------------------------------------------- -# ЧАСТЬ 9: ЗАКЛЮЧЕНИЕ -# ------------------------------------------------------------------- -print("9. ЧТО МЫ УЗНАЛИ?") -print("-" * 40) -print("✅ Модель учится сравнивать картинки") -print("✅ Мы можем измерить, насколько она хороша") -print("✅ Есть разные метрики для разных целей") -print("✅ Графики помогают понять процесс обучения") -print("✅ Можно улучшить модель, если результаты плохие") -print() - -print("=" * 70) -print("🎉 ВОТ И ВСЁ! ТЕПЕРЬ ТЫ ЗНАЕШЬ, КАК ОЦЕНИВАТЬ МОДЕЛЬ!") -print("=" * 70) -print() -print("Следующие шаги:") -print("1. Запусти evaluation.py чтобы увидеть реальные графики") -print("2. Посмотри на матрицу ошибок - какие ошибки чаще?") -print("3. Попробуй изменить порог принятия решений") -print("4. Если нужно - переобучи модель с другими параметрами") diff --git a/models/SiaN-similarity/train-adv.py b/models/SiaN-similarity/train-adv.py deleted file mode 100644 index efb32fa..0000000 --- a/models/SiaN-similarity/train-adv.py +++ /dev/null @@ -1,917 +0,0 @@ - -import os -import time -from datetime import datetime - -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm - -# ============================================================================= -# TRAINING LOOP WITH VISUALIZATION -# ============================================================================= - -class SimilarityTrainer: - def __init__( - self, - model: nn.Module, - trainloader: DataLoader, - valloader: DataLoader, - device: torch.device, - config: dict, - ): - self.model = model.to(device) - self.trainloader = trainloader - self.valloader = valloader - self.device = device - self.config = config - - self.criterion = SimilarityLoss() - self.optimizer = optim.Adam( - model.parameters(), - lr=config.get('learning_rate', 2e-4), - betas=(config.get('beta1', 0.5), config.get('beta2', 0.999)) - ) - - self.writer = None - self.best_val_loss = float('inf') - self.epochs_without_improvement = 0 - - # Для хранения истории метрик - self.history = { - 'train_loss': [], - 'val_loss': [], - 'val_accuracy': [], - 'val_precision': [], - 'val_recall': [], - 'val_f1': [], - 'learning_rate': [] - } - - def train_epoch(self, epoch: int) -> dict: - """Обучение на одной эпохе""" - self.model.train() - total_loss = 0 - total_samples = 0 - all_metrics = [] - - pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}') - for batch_idx, batch in enumerate(pbar): - google_img = batch['google_img'].to(self.device) - yandex_img = batch['yandex_img'].to(self.device) - target = batch['same_domain'].float().to(self.device).unsqueeze(1) - - self.optimizer.zero_grad() - - # Forward pass - output = self.model(google_img, yandex_img) - loss = self.criterion(output, target) - - # Backward pass - loss.backward() - - # Gradient clipping - if self.config.get('grad_clip', None): - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.config['grad_clip'] - ) - - self.optimizer.step() - - total_loss += loss.item() * google_img.size(0) - total_samples += google_img.size(0) - - # Compute metrics - if batch_idx % self.config.get('log_interval', 10) == 0: - metrics = self.criterion.compute_metrics(output, target) - all_metrics.append(metrics) - pbar.set_postfix({ - 'loss': f"{loss.item():.4f}", - 'acc': f"{metrics['accuracy']:.4f}" - }) - - if self.writer: - global_step = epoch * len(self.trainloader) + batch_idx - self.writer.add_scalar('train/loss', loss.item(), global_step) - self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step) - - avg_loss = total_loss / total_samples - - # Average metrics - if all_metrics: - avg_metrics = { - key: sum(m[key] for m in all_metrics) / len(all_metrics) - for key in all_metrics[0].keys() - } - else: - avg_metrics = {} - - return {'loss': avg_loss, **avg_metrics} - - def validate(self) -> dict: - """Валидация модели""" - self.model.eval() - total_loss = 0 - total_samples = 0 - all_metrics = [] - - # Для ROC и confusion matrix - all_predictions = [] - all_targets = [] - - with torch.no_grad(): - for batch in tqdm(self.valloader, desc='Validation'): - google_img = batch['google_img'].to(self.device) - yandex_img = batch['yandex_img'].to(self.device) - target = batch['same_domain'].float().to(self.device).unsqueeze(1) - - output = self.model(google_img, yandex_img) - loss = self.criterion(output, target) - - total_loss += loss.item() * google_img.size(0) - total_samples += google_img.size(0) - - metrics = self.criterion.compute_metrics(output, target) - all_metrics.append(metrics) - - all_predictions.append(output.cpu()) - all_targets.append(target.cpu()) - - avg_loss = total_loss / total_samples - avg_metrics = { - key: sum(m[key] for m in all_metrics) / len(all_metrics) - for key in all_metrics[0].keys() - } - - # Concatenate all predictions and targets - all_predictions = torch.cat(all_predictions, dim=0) - all_targets = torch.cat(all_targets, dim=0) - - return { - 'loss': avg_loss, - **avg_metrics, - 'predictions': all_predictions, - 'targets': all_targets - } - - def train(self, num_epochs: int): - """Основной цикл обучения""" - log_dir = os.path.join(self.config.get('output_dir', 'runs/similarity')) - os.makedirs(log_dir, exist_ok=True) - self.writer = SummaryWriter(log_dir) - - print(f'\n{"="*70}') - print(f'Starting training for {num_epochs} epochs') - print(f'Logging to {log_dir}') - print(f'{"="*70}\n') - - start_time = time.time() - - for epoch in range(1, num_epochs + 1): - epoch_start = time.time() - print(f'\n--- Epoch {epoch}/{num_epochs} ---') - - # Train - train_metrics = self.train_epoch(epoch) - - # Validate - val_metrics = self.validate() - - # Store history - self.history['train_loss'].append(train_metrics['loss']) - self.history['val_loss'].append(val_metrics['loss']) - self.history['val_accuracy'].append(val_metrics['accuracy']) - self.history['val_precision'].append(val_metrics['precision']) - self.history['val_recall'].append(val_metrics['recall']) - self.history['val_f1'].append(val_metrics['f1']) - self.history['learning_rate'].append( - self.optimizer.param_groups[0]['lr'] - ) - - # Print metrics - print(f'\nTrain Loss: {train_metrics["loss"]:.4f}') - print(f'Val Loss: {val_metrics["loss"]:.4f}') - print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}') - print(f'Val Precision: {val_metrics["precision"]:.4f}') - print(f'Val Recall: {val_metrics["recall"]:.4f}') - print(f'Val F1: {val_metrics["f1"]:.4f}') - - epoch_time = time.time() - epoch_start - print(f'Epoch time: {epoch_time:.2f}s') - - # TensorBoard logging - if self.writer: - self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch) - self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch) - self.writer.add_scalar('epoch/val_accuracy', val_metrics['accuracy'], epoch) - self.writer.add_scalar('epoch/val_precision', val_metrics['precision'], epoch) - self.writer.add_scalar('epoch/val_recall', val_metrics['recall'], epoch) - self.writer.add_scalar('epoch/val_f1', val_metrics['f1'], epoch) - - # Save checkpoint - if val_metrics['loss'] < self.best_val_loss: - self.best_val_loss = val_metrics['loss'] - self.epochs_without_improvement = 0 - self.save_checkpoint(epoch, val_metrics['loss'], is_best=True) - print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}') - else: - self.epochs_without_improvement += 1 - if epoch % self.config.get('save_interval', 5) == 0: - self.save_checkpoint(epoch, val_metrics['loss'], is_best=False) - - # Early stopping - patience = self.config.get('early_stopping_patience', 20) - if self.epochs_without_improvement >= patience: - print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement') - break - - total_time = time.time() - start_time - print(f'\n{"="*70}') - print(f'Training completed in {total_time/60:.2f} minutes') - print(f'Best validation loss: {self.best_val_loss:.4f}') - print(f'{"="*70}\n') - - self.writer.close() - - return self.history - - def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False): - """Сохранение чекпоинта модели""" - checkpoint_dir = os.path.join( - self.config.get('output_dir', 'runs/similarity'), - 'checkpoints' - ) - os.makedirs(checkpoint_dir, exist_ok=True) - - checkpoint = { - 'epoch': epoch, - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'val_loss': val_loss, - 'config': self.config, - 'history': self.history - } - - checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt') - torch.save(checkpoint, checkpoint_path) - - if is_best: - best_path = os.path.join(checkpoint_dir, 'best_model.pt') - torch.save(checkpoint, best_path) - - def load_checkpoint(self, checkpoint_path: str): - """Загрузка чекпоинта""" - checkpoint = torch.load(checkpoint_path, map_location=self.device) - self.model.load_state_dict(checkpoint['model_state_dict']) - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - if 'history' in checkpoint: - self.history = checkpoint['history'] - return checkpoint['epoch'], checkpoint['val_loss'] - - -# ============================================================================= -# VISUALIZATION FUNCTIONS -# ============================================================================= - -def plot_training_history(history: dict, save_path: str = None): - """Построение графиков обучения""" - fig, axes = plt.subplots(2, 3, figsize=(18, 10)) - fig.suptitle('Training History - Siamese Network для корреляции снимков', - fontsize=16, fontweight='bold') - - epochs = range(1, len(history['train_loss']) + 1) - - # Loss - axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2) - axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2) - axes[0, 0].set_xlabel('Epoch') - axes[0, 0].set_ylabel('Loss') - axes[0, 0].set_title('Loss Curves') - axes[0, 0].legend() - axes[0, 0].grid(True, alpha=0.3) - - # Accuracy - axes[0, 1].plot(epochs, history['val_accuracy'], 'g-', linewidth=2) - axes[0, 1].set_xlabel('Epoch') - axes[0, 1].set_ylabel('Accuracy') - axes[0, 1].set_title('Validation Accuracy') - axes[0, 1].grid(True, alpha=0.3) - axes[0, 1].set_ylim([0, 1]) - - # F1 Score - axes[0, 2].plot(epochs, history['val_f1'], 'm-', linewidth=2) - axes[0, 2].set_xlabel('Epoch') - axes[0, 2].set_ylabel('F1 Score') - axes[0, 2].set_title('Validation F1 Score') - axes[0, 2].grid(True, alpha=0.3) - axes[0, 2].set_ylim([0, 1]) - - # Precision - axes[1, 0].plot(epochs, history['val_precision'], 'c-', linewidth=2) - axes[1, 0].set_xlabel('Epoch') - axes[1, 0].set_ylabel('Precision') - axes[1, 0].set_title('Validation Precision') - axes[1, 0].grid(True, alpha=0.3) - axes[1, 0].set_ylim([0, 1]) - - # Recall - axes[1, 1].plot(epochs, history['val_recall'], 'y-', linewidth=2) - axes[1, 1].set_xlabel('Epoch') - axes[1, 1].set_ylabel('Recall') - axes[1, 1].set_title('Validation Recall') - axes[1, 1].grid(True, alpha=0.3) - axes[1, 1].set_ylim([0, 1]) - - # Learning Rate - axes[1, 2].plot(epochs, history['learning_rate'], 'k-', linewidth=2) - axes[1, 2].set_xlabel('Epoch') - axes[1, 2].set_ylabel('Learning Rate') - axes[1, 2].set_title('Learning Rate Schedule') - axes[1, 2].grid(True, alpha=0.3) - axes[1, 2].set_yscale('log') - - plt.tight_layout() - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f'Training history plot saved to {save_path}') - - plt.show() - - -def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor, save_path: str = None): - """Построение ROC кривой""" - from sklearn.metrics import roc_curve, auc - - predictions_np = predictions.numpy().flatten() - targets_np = targets.numpy().flatten() - - fpr, tpr, thresholds = roc_curve(targets_np, predictions_np) - roc_auc = auc(fpr, tpr) - - plt.figure(figsize=(10, 8)) - plt.plot(fpr, tpr, color='darkorange', lw=2, - label=f'ROC curve (AUC = {roc_auc:.3f})') - plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random') - plt.xlim([0.0, 1.0]) - plt.ylim([0.0, 1.05]) - plt.xlabel('False Positive Rate', fontsize=12) - plt.ylabel('True Positive Rate', fontsize=12) - plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold') - plt.legend(loc="lower right", fontsize=12) - plt.grid(True, alpha=0.3) - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f'ROC curve saved to {save_path}') - - plt.show() - - return roc_auc - - -def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor, - threshold: float = 0.5, save_path: str = None): - """Построение матрицы ошибок""" - from sklearn.metrics import confusion_matrix - - predictions_binary = (predictions.numpy().flatten() >= threshold).astype(int) - targets_binary = (targets.numpy().flatten() >= 0.5).astype(int) - - cm = confusion_matrix(targets_binary, predictions_binary) - - plt.figure(figsize=(10, 8)) - plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) - plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold') - plt.colorbar() - - classes = ['Different Domains', 'Same Domain'] - tick_marks = np.arange(len(classes)) - plt.xticks(tick_marks, classes, rotation=45, fontsize=12) - plt.yticks(tick_marks, classes, fontsize=12) - - # Add text annotations - thresh = cm.max() / 2. - for i in range(cm.shape[0]): - for j in range(cm.shape[1]): - plt.text(j, i, format(cm[i, j], 'd'), - ha="center", va="center", - color="white" if cm[i, j] > thresh else "black", - fontsize=16, fontweight='bold') - - plt.ylabel('True Label', fontsize=12) - plt.xlabel('Predicted Label', fontsize=12) - plt.tight_layout() - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f'Confusion matrix saved to {save_path}') - - plt.show() - - -def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor, - save_path: str = None): - """Распределение предсказанных значений схожести""" - predictions_np = predictions.numpy().flatten() - targets_np = targets.numpy().flatten() - - same_domain = predictions_np[targets_np >= 0.5] - diff_domain = predictions_np[targets_np < 0.5] - - plt.figure(figsize=(12, 6)) - - plt.subplot(1, 2, 1) - plt.hist(same_domain, bins=50, alpha=0.7, color='green', edgecolor='black', label='Same Domain') - plt.hist(diff_domain, bins=50, alpha=0.7, color='red', edgecolor='black', label='Different Domains') - plt.xlabel('Predicted Similarity Score', fontsize=12) - plt.ylabel('Frequency', fontsize=12) - plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold') - plt.legend(fontsize=11) - plt.grid(True, alpha=0.3) - - plt.subplot(1, 2, 2) - plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same']) - plt.ylabel('Similarity Score', fontsize=12) - plt.xlabel('Domain Match', fontsize=12) - plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold') - plt.grid(True, alpha=0.3, axis='y') - - plt.tight_layout() - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f'Similarity distribution plot saved to {save_path}') - - plt.show() - - # Print statistics - print(f'\n--- Similarity Score Statistics ---') - print(f'Same Domain:') - print(f' Mean: {same_domain.mean():.4f}') - print(f' Std: {same_domain.std():.4f}') - print(f' Min: {same_domain.min():.4f}') - print(f' Max: {same_domain.max():.4f}') - print(f'\nDifferent Domains:') - print(f' Mean: {diff_domain.mean():.4f}') - print(f' Std: {diff_domain.std():.4f}') - print(f' Min: {diff_domain.min():.4f}') - print(f' Max: {diff_domain.max():.4f}') - - -def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device, - num_samples: int = 8, save_path: str = None): - """Визуализация примеров предсказаний""" - model.eval() - - # Get random samples - indices = np.random.choice(len(dataset), num_samples, replace=False) - - fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples)) - if num_samples == 1: - axes = axes.reshape(1, -1) - - fig.suptitle('Sample Predictions - Siamese Network для корреляции карт', - fontsize=16, fontweight='bold') - - with torch.no_grad(): - for idx, sample_idx in enumerate(indices): - sample = dataset[sample_idx] - - google_img = sample['google_img'].unsqueeze(0).to(device) - yandex_img = sample['yandex_img'].unsqueeze(0).to(device) - true_label = sample['same_domain'].item() - - # Predict - pred_similarity = model(google_img, yandex_img).item() - pred_label = int(pred_similarity >= 0.5) - - # Denormalize images for visualization - google_np = google_img.squeeze(0).cpu().numpy().transpose(1, 2, 0) - yandex_np = yandex_img.squeeze(0).cpu().numpy().transpose(1, 2, 0) - - # Denormalize (assuming ImageNet normalization) - mean = np.array([0.485, 0.456, 0.406]) - std = np.array([0.229, 0.224, 0.225]) - google_np = std * google_np + mean - yandex_np = std * yandex_np + mean - google_np = np.clip(google_np, 0, 1) - yandex_np = np.clip(yandex_np, 0, 1) - - # Plot Google image - axes[idx, 0].imshow(google_np) - axes[idx, 0].set_title('Google Map', fontsize=12, fontweight='bold') - axes[idx, 0].axis('off') - - # Plot Yandex image - axes[idx, 1].imshow(yandex_np) - axes[idx, 1].set_title('Yandex Map', fontsize=12, fontweight='bold') - axes[idx, 1].axis('off') - - # Plot prediction info - axes[idx, 2].axis('off') - - # Determine color based on correctness - is_correct = (pred_label == true_label) - color = 'green' if is_correct else 'red' - result = '✓ Correct' if is_correct else '✗ Incorrect' - - info_text = f""" -Prediction: {pred_similarity:.4f} -Predicted Label: {'Same' if pred_label == 1 else 'Different'} -True Label: {'Same' if true_label == 1 else 'Different'} - -{result} - """ - - axes[idx, 2].text(0.5, 0.5, info_text, - ha='center', va='center', - fontsize=12, - bbox=dict(boxstyle='round', facecolor=color, alpha=0.2), - transform=axes[idx, 2].transAxes) - - plt.tight_layout() - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f'Sample predictions saved to {save_path}') - - plt.show() - - -def visualize_feature_space(model: nn.Module, dataloader: DataLoader, - device: torch.device, max_samples: int = 500, - save_path: str = None): - """Визуализация пространства признаков с помощью t-SNE""" - from sklearn.manifold import TSNE - - model.eval() - - all_features_google = [] - all_features_yandex = [] - all_labels = [] - - with torch.no_grad(): - for i, batch in enumerate(tqdm(dataloader, desc='Extracting features')): - if i * dataloader.batch_size >= max_samples: - break - - google_img = batch['google_img'].to(device) - yandex_img = batch['yandex_img'].to(device) - labels = batch['same_domain'].cpu().numpy() - - # Extract features - features_google = model.extract_features(google_img).cpu().numpy() - features_yandex = model.extract_features(yandex_img).cpu().numpy() - - all_features_google.append(features_google) - all_features_yandex.append(features_yandex) - all_labels.append(labels) - - all_features_google = np.concatenate(all_features_google, axis=0) - all_features_yandex = np.concatenate(all_features_yandex, axis=0) - all_labels = np.concatenate(all_labels, axis=0) - - # Combine features - all_features = np.concatenate([all_features_google, all_features_yandex], axis=0) - all_labels = np.concatenate([all_labels, all_labels], axis=0) - - print(f'\nApplying t-SNE to {all_features.shape[0]} samples...') - tsne = TSNE(n_components=2, random_state=42, perplexity=30) - features_2d = tsne.fit_transform(all_features) - - # Split back into Google and Yandex - n_samples = len(all_labels) // 2 - features_google_2d = features_2d[:n_samples] - features_yandex_2d = features_2d[n_samples:] - labels = all_labels[:n_samples] - - # Plot - fig, axes = plt.subplots(1, 2, figsize=(20, 8)) - - # Google features - for label in [0, 1]: - mask = labels == label - axes[0].scatter( - features_google_2d[mask, 0], - features_google_2d[mask, 1], - c='green' if label == 1 else 'red', - label='Same Domain' if label == 1 else 'Different Domains', - alpha=0.6, - s=50 - ) - axes[0].set_title('Google Maps Features (t-SNE)', fontsize=14, fontweight='bold') - axes[0].set_xlabel('t-SNE Component 1', fontsize=12) - axes[0].set_ylabel('t-SNE Component 2', fontsize=12) - axes[0].legend(fontsize=11) - axes[0].grid(True, alpha=0.3) - - # Yandex features - for label in [0, 1]: - mask = labels == label - axes[1].scatter( - features_yandex_2d[mask, 0], - features_yandex_2d[mask, 1], - c='green' if label == 1 else 'red', - label='Same Domain' if label == 1 else 'Different Domains', - alpha=0.6, - s=50 - ) - axes[1].set_title('Yandex Maps Features (t-SNE)', fontsize=14, fontweight='bold') - axes[1].set_xlabel('t-SNE Component 1', fontsize=12) - axes[1].set_ylabel('t-SNE Component 2', fontsize=12) - axes[1].legend(fontsize=11) - axes[1].grid(True, alpha=0.3) - - plt.tight_layout() - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f'Feature space visualization saved to {save_path}') - - plt.show() - - -def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader, - device: torch.device, num_samples: int = 20, - save_path: str = None): - """Создание тепловой карты корреляций между снимками""" - model.eval() - - # Collect samples - google_images = [] - yandex_images = [] - labels = [] - - with torch.no_grad(): - for i, batch in enumerate(dataloader): - if len(google_images) >= num_samples: - break - - google_img = batch['google_img'].to(device) - yandex_img = batch['yandex_img'].to(device) - label = batch['same_domain'] - - google_images.append(google_img[:1]) - yandex_images.append(yandex_img[:1]) - labels.append(label[:1].item()) - - google_images = torch.cat(google_images[:num_samples], dim=0) - yandex_images = torch.cat(yandex_images[:num_samples], dim=0) - - # Compute similarity matrix - similarity_matrix = np.zeros((num_samples, num_samples)) - - with torch.no_grad(): - for i in tqdm(range(num_samples), desc='Computing correlations'): - for j in range(num_samples): - google_i = google_images[i:i+1] - yandex_j = yandex_images[j:j+1] - similarity = model(google_i, yandex_j).item() - similarity_matrix[i, j] = similarity - - # Plot heatmap - plt.figure(figsize=(14, 12)) - im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1) - - plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04) - plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)', - fontsize=16, fontweight='bold', pad=20) - plt.xlabel('Yandex Map Index', fontsize=12) - plt.ylabel('Google Map Index', fontsize=12) - - # Add grid - plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3) - - # Add text annotations for diagonal (same pairs) - for i in range(min(num_samples, 10)): # Annotate first 10 for readability - if labels[i] == 1: # True match - plt.text(i, i, '✓', ha='center', va='center', - color='white', fontsize=12, fontweight='bold') - - plt.tight_layout() - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f'Correlation heatmap saved to {save_path}') - - plt.show() - - # Print statistics - diagonal = np.diag(similarity_matrix) - off_diagonal = similarity_matrix[~np.eye(num_samples, dtype=bool)] - - print(f'\n--- Correlation Statistics ---') - print(f'Diagonal (matched pairs):') - print(f' Mean: {diagonal.mean():.4f}') - print(f' Std: {diagonal.std():.4f}') - print(f'\nOff-diagonal (mismatched pairs):') - print(f' Mean: {off_diagonal.mean():.4f}') - print(f' Std: {off_diagonal.std():.4f}') - - -# ============================================================================= -# MAIN TRAINING SCRIPT -# ============================================================================= - -def main(): - """Основная функция обучения""" - - # Configuration - config_dict = config.copy() - if isinstance(config_dict.get('image_size'), list): - config_dict['image_size'] = tuple(config_dict['image_size']) - - # Device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f'\n{"="*70}') - print(f'Siamese Network Training for Map Correlation') - print(f'Обучение сиамской сети для корреляции снимков') - print(f'{"="*70}') - print(f'Using device: {device}') - if torch.cuda.is_available(): - print(f'GPU: {torch.cuda.get_device_name(0)}') - print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB') - - # Create data loaders - print(f'\n{"="*70}') - print('Creating data loaders...') - print(f'{"="*70}') - - train_loader, val_loader = create_data_loaders( - root_dir=config_dict['data_dir'], - batch_size=config_dict['batch_size'], - train_split=config_dict['train_split'], - num_workers=config_dict['num_workers'], - image_size=config_dict['image_size'], - augment_train=True, - augment_val=False, - device=device - ) - - print(f'Train batches: {len(train_loader)}') - print(f'Val batches: {len(val_loader)}') - print(f'Train samples: {len(train_loader.dataset)}') - print(f'Val samples: {len(val_loader.dataset)}') - - # Create model - print(f'\n{"="*70}') - print('Creating model...') - print(f'{"="*70}') - - model = create_similarity_model( - model_type='backbone', - input_size=config_dict['image_size'][0] if isinstance(config_dict['image_size'], (tuple, list)) else config_dict['image_size'], - input_channels=3, - backbone_name='resnet18', - pretrained=True, - dropout_rate=0.3, - use_batch_norm=True - ) - - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f'Total parameters: {total_params:,}') - print(f'Trainable parameters: {trainable_params:,}') - - # Create trainer - trainer = SimilarityTrainer( - model=model, - trainloader=train_loader, - valloader=val_loader, - device=device, - config=config_dict - ) - - # Train model - print(f'\n{"="*70}') - print('Starting training...') - print(f'{"="*70}') - - history = trainer.train(config_dict['epochs']) - - # ============================================================================= - # VISUALIZATION AND RESULTS - # ============================================================================= - - print(f'\n{"="*70}') - print('Generating visualizations...') - print(f'{"="*70}') - - output_dir = config_dict.get('output_dir', 'runs/similarity') - vis_dir = os.path.join(output_dir, 'visualizations') - os.makedirs(vis_dir, exist_ok=True) - - # 1. Training history - print('\n1. Plotting training history...') - plot_training_history( - history, - save_path=os.path.join(vis_dir, 'training_history.png') - ) - - # 2. Validation metrics - print('\n2. Computing validation predictions...') - trainer.model.eval() - val_predictions = [] - val_targets = [] - - with torch.no_grad(): - for batch in tqdm(val_loader, desc='Validation'): - google_img = batch['google_img'].to(device) - yandex_img = batch['yandex_img'].to(device) - target = batch['same_domain'].float().unsqueeze(1) - - output = trainer.model(google_img, yandex_img) - val_predictions.append(output.cpu()) - val_targets.append(target.cpu()) - - val_predictions = torch.cat(val_predictions, dim=0) - val_targets = torch.cat(val_targets, dim=0) - - # 3. ROC curve - print('\n3. Plotting ROC curve...') - roc_auc = plot_roc_curve( - val_predictions, - val_targets, - save_path=os.path.join(vis_dir, 'roc_curve.png') - ) - print(f'ROC AUC Score: {roc_auc:.4f}') - - # 4. Confusion matrix - print('\n4. Plotting confusion matrix...') - plot_confusion_matrix( - val_predictions, - val_targets, - threshold=0.5, - save_path=os.path.join(vis_dir, 'confusion_matrix.png') - ) - - # 5. Similarity distribution - print('\n5. Plotting similarity distribution...') - plot_similarity_distribution( - val_predictions, - val_targets, - save_path=os.path.join(vis_dir, 'similarity_distribution.png') - ) - - # 6. Sample predictions - print('\n6. Visualizing sample predictions...') - visualize_sample_predictions( - trainer.model, - val_loader.dataset, - device, - num_samples=8, - save_path=os.path.join(vis_dir, 'sample_predictions.png') - ) - - # 7. Feature space visualization - print('\n7. Visualizing feature space (t-SNE)...') - visualize_feature_space( - trainer.model, - val_loader, - device, - max_samples=500, - save_path=os.path.join(vis_dir, 'feature_space_tsne.png') - ) - - # 8. Correlation heatmap - print('\n8. Generating correlation heatmap...') - generate_correlation_heatmap( - trainer.model, - val_loader, - device, - num_samples=20, - save_path=os.path.join(vis_dir, 'correlation_heatmap.png') - ) - - # ============================================================================= - # FINAL RESULTS SUMMARY - # ============================================================================= - - print(f'\n{"="*70}') - print('FINAL RESULTS SUMMARY') - print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ') - print(f'{"="*70}') - - print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}') - print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}') - print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}') - print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}') - print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}') - print(f'ROC AUC Score: {roc_auc:.4f}') - - print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}') - print(f'Visualizations saved to: {vis_dir}') - - print(f'\n{"="*70}') - print('Training and visualization completed successfully!') - print('Обучение и визуализация завершены успешно!') - print(f'{"="*70}\n') - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/models/SiaN/homography.py b/models/SiaN/homography.py deleted file mode 100644 index 1e5464f..0000000 --- a/models/SiaN/homography.py +++ /dev/null @@ -1,434 +0,0 @@ -import os -import random -from typing import Any, Dict, List, Optional, Tuple - -import cv2 -import numpy as np -import torch -from PIL import Image -from torch.utils.data import DataLoader, Dataset - - -class HomographyDataset(Dataset): - """ - Dataset for homography estimation between Yandex and Google map image pairs. - - This dataset loads pairs of images (Yandex and Google maps) and provides - homography matrices for data augmentation and training. - """ - - def __init__( - self, - root_dir: str, - transform=None, - augment: bool = True, - max_samples: Optional[int] = None, - image_size: Tuple[int, int] = (700, 700), - cache_homographies: bool = True, - ): - """ - Initialize the HomographyDataset. - - Args: - root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png) - transform: Optional torchvision transforms to apply - augment: Whether to apply homography-based data augmentation - max_samples: Maximum number of samples to load (None for all) - image_size: Target size for images (height, width) - cache_homographies: Whether to cache generated homography matrices to disk - """ - self.root_dir = root_dir - self.transform = transform - self.augment = augment - self.image_size = image_size - self.cache_homographies = cache_homographies - - # Find all image pairs - self.image_pairs = self._discover_image_pairs() - - if max_samples is not None: - self.image_pairs = self.image_pairs[:max_samples] - - print(f"Found {len(self.image_pairs)} image pairs in {root_dir}") - - # Create directory for cached homographies if needed - if cache_homographies: - self.homography_cache_dir = os.path.join(root_dir, "homography_cache") - os.makedirs(self.homography_cache_dir, exist_ok=True) - - def _discover_image_pairs(self) -> List[Dict[str, Any]]: - """Discover all Google-Yandex image pairs in the dataset directory.""" - image_pairs = [] - - # Get all Google images - google_files = [ - f for f in os.listdir(self.root_dir) if f.endswith("_google.png") - ] - - for google_file in sorted(google_files): - # Extract index from filename - idx_str = google_file.split("_")[0] - try: - idx = int(idx_str) - except ValueError: - continue - - # Check if corresponding Yandex image exists - yandex_file = f"{idx:04d}_yandex.png" - yandex_path = os.path.join(self.root_dir, yandex_file) - - if os.path.exists(yandex_path): - image_pairs.append( - { - "idx": idx, - "google_path": os.path.join(self.root_dir, google_file), - "yandex_path": yandex_path, - } - ) - - return image_pairs - - def __len__(self) -> int: - """Return the number of image pairs in the dataset.""" - return len(self.image_pairs) - - def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - """ - Get a sample from the dataset. - - Returns a dictionary with: - - 'google_img': Google map image tensor - - 'yandex_img': Yandex map image tensor - - 'homography': Ground truth homography matrix (3x3) - - 'idx': Sample index - """ - pair_info = self.image_pairs[idx] - - # Load images - google_img = Image.open(pair_info["google_path"]).convert("RGB") - yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB") - - # Resize images to target size - google_img = google_img.resize( - (self.image_size[1], self.image_size[0]), Image.BILINEAR - ) - yandex_img = yandex_img.resize( - (self.image_size[1], self.image_size[0]), Image.BILINEAR - ) - - # Get or generate homography matrix - homography_matrix = self._get_homography_matrix(pair_info["idx"]) - - # Apply data augmentation if enabled - if self.augment: - google_img, yandex_img, homography_matrix = self._apply_augmentation( - google_img, yandex_img, homography_matrix - ) - - # Convert images to tensors - if self.transform: - google_img = self.transform(google_img) - yandex_img = self.transform(yandex_img) - else: - # Default conversion to tensor - google_img = ( - torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0 - ) - yandex_img = ( - torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0 - ) - - # Convert homography to tensor - homography_tensor = torch.from_numpy(homography_matrix).float() - - return { - "google_img": google_img, - "yandex_img": yandex_img, - "homography": homography_tensor, - "idx": torch.tensor(pair_info["idx"], dtype=torch.long), - } - - def _get_homography_matrix(self, idx: int) -> np.ndarray: - """ - Get homography matrix for a given index. - - If cached homography exists, load it. Otherwise generate a new one. - """ - if self.cache_homographies: - cache_path = os.path.join( - self.homography_cache_dir, f"{idx:04d}_homography.npy" - ) - if os.path.exists(cache_path): - return np.load(cache_path) - - # Generate new homography matrix - homography_matrix = self.generate_random_homography() - - # Cache if enabled - if self.cache_homographies: - np.save(cache_path, homography_matrix) - - return homography_matrix - - def generate_random_homography(self) -> np.ndarray: - """ - Generate a random homography matrix for data augmentation. - - Returns: - np.ndarray: 3x3 homography matrix. - """ - # Generate random affine transformation parameters - angle = np.random.uniform(-30, 30) # rotation in degrees - scale = np.random.uniform(0.8, 1.2) # scaling factor - tx = np.random.uniform(-50, 50) # translation in x - ty = np.random.uniform(-50, 50) # translation in y - - # Convert angle to radians - theta = np.radians(angle) - - # Create affine transformation matrix - affine_matrix = np.array( - [ - [scale * np.cos(theta), -scale * np.sin(theta), tx], - [scale * np.sin(theta), scale * np.cos(theta), ty], - [0, 0, 1], - ] - ) - - # Add small perspective distortion - perspective = np.random.uniform(-0.001, 0.001, (2, 3)) - perspective = np.vstack([perspective, [0, 0, 0]]) - - homography_matrix = affine_matrix + perspective - - return homography_matrix - - def _apply_augmentation( - self, - google_img: Image.Image, - yandex_img: Image.Image, - base_homography: np.ndarray, - ) -> Tuple[Image.Image, Image.Image, np.ndarray]: - """ - Apply homography-based data augmentation to image pair. - - Args: - google_img: Google map image - yandex_img: Yandex map image - base_homography: Base homography matrix - - Returns: - Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography) - """ - # Generate augmentation homography - aug_homography = self.generate_random_homography() - - # Combine with base homography - combined_homography = aug_homography @ base_homography - - # Apply augmentation to both images - google_aug = self._apply_homography_to_image(google_img, aug_homography) - yandex_aug = self._apply_homography_to_image(yandex_img, aug_homography) - - return google_aug, yandex_aug, combined_homography - - def _apply_homography_to_image( - self, img: Image.Image, homography: np.ndarray - ) -> Image.Image: - """ - Apply homography transformation to a single image. - - Args: - img: PIL Image to transform - homography: 3x3 homography matrix - - Returns: - Transformed PIL Image - """ - # Convert to numpy array - img_np = np.array(img) - - # Get image dimensions - h, w = img_np.shape[:2] - - # Apply homography transformation - transformed = cv2.warpPerspective( - img_np, - homography, - (w, h), - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_REFLECT, - ) - - # Convert back to PIL Image - return Image.fromarray(transformed) - - def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]: - """ - Get a sample without data augmentation. - - Useful for visualization and evaluation. - """ - pair_info = self.image_pairs[idx] - - # Load images - google_img = Image.open(pair_info["google_path"]).convert("RGB") - yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB") - - # Resize - google_img = google_img.resize( - (self.image_size[1], self.image_size[0]), Image.BILINEAR - ) - yandex_img = yandex_img.resize( - (self.image_size[1], self.image_size[0]), Image.BILINEAR - ) - - # Get homography matrix - homography_matrix = self._get_homography_matrix(pair_info["idx"]) - - return { - "google_img": google_img, - "yandex_img": yandex_img, - "homography": homography_matrix, - "idx": pair_info["idx"], - "google_path": pair_info["google_path"], - "yandex_path": pair_info["yandex_path"], - } - - -def create_data_loaders( - root_dir: str, - batch_size: int = 32, - train_split: float = 0.8, - num_workers: int = 4, - image_size: Tuple[int, int] = (256, 256), - augment_train: bool = True, - augment_val: bool = False, -) -> Tuple[DataLoader, DataLoader]: - """ - Create train and validation data loaders for homography estimation. - - Args: - root_dir: Directory containing image pairs - batch_size: Batch size for data loaders - train_split: Fraction of data to use for training - num_workers: Number of worker processes for data loading - image_size: Target image size (height, width) - augment_train: Whether to augment training data - augment_val: Whether to augment validation data - - Returns: - Tuple of (train_loader, val_loader) - """ - from torchvision import transforms - - # Define transforms - transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - - # Create full dataset - full_dataset = HomographyDataset( - root_dir=root_dir, - transform=transform, - augment=False, # We'll handle augmentation separately - image_size=image_size, - cache_homographies=True, - ) - - # Split dataset - dataset_size = len(full_dataset) - train_size = int(train_split * dataset_size) - val_size = dataset_size - train_size - - # Create indices for splitting - indices = list(range(dataset_size)) - random.shuffle(indices) - train_indices = indices[:train_size] - val_indices = indices[train_size:] - - # Create subset samplers - from torch.utils.data import Subset - - train_dataset = Subset(full_dataset, train_indices) - val_dataset = Subset(full_dataset, val_indices) - - # Apply augmentation by overriding __getitem__ for train dataset - if augment_train: - - class AugmentedSubset(Subset): - def __getitem__(self, idx): - sample = self.dataset[self.indices[idx]] - # Apply augmentation - google_img = sample["google_img"] - yandex_img = sample["yandex_img"] - homography = sample["homography"] - - # Generate augmentation homography - aug_homography = torch.from_numpy( - full_dataset.generate_random_homography() - ).float() - - # Combine homographies - combined_homography = aug_homography @ homography - - # Apply augmentation (simplified - in practice would warp images) - # For now, we just return the combined homography - return { - "google_img": google_img, - "yandex_img": yandex_img, - "homography": combined_homography, - "idx": sample["idx"], - } - - train_dataset = AugmentedSubset(full_dataset, train_indices) - - # Create data loaders - train_loader = DataLoader( - train_dataset, - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - pin_memory=True, - ) - - val_loader = DataLoader( - val_dataset, - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - pin_memory=True, - ) - - return train_loader, val_loader - - -if __name__ == "__main__": - # Example usage - dataset = HomographyDataset( - root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", - augment=True, - image_size=(256, 256), - ) - - print(f"Dataset size: {len(dataset)}") - - # Get a sample - sample = dataset[0] - print(f"Sample keys: {list(sample.keys())}") - print(f"Google image shape: {sample['google_img'].shape}") - print(f"Yandex image shape: {sample['yandex_img'].shape}") - print(f"Homography shape: {sample['homography'].shape}") - - # Create data loaders - train_loader, val_loader = create_data_loaders( - root_dir=r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images", - batch_size=16, - train_split=0.8, - ) - - print(f"Train batches: {len(train_loader)}") - print(f"Val batches: {len(val_loader)}") diff --git a/models/SiaN/model.py b/models/SiaN/model.py new file mode 100644 index 0000000..a921cc0 --- /dev/null +++ b/models/SiaN/model.py @@ -0,0 +1,152 @@ +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + + +class HomographyCNN(nn.Module): + """ + Model for estimating homography matrix (3x3) between two images. + """ + + def __init__( + self, + input_channels: int = 3, + backbone_name: str = "resnet18", + pretrained: bool = True, + dropout_rate: float = 0.3, + use_batch_norm: bool = True, + ): + super().__init__() + + self.input_channels = input_channels + self.backbone_name = backbone_name + self.pretrained = pretrained + self.dropout_rate = dropout_rate + self.use_batch_norm = use_batch_norm + + backbone = self._create_backbone(backbone_name, pretrained) + + self.feature_dim = backbone.fc.in_features + backbone.fc = nn.Identity() + self.backbone = backbone + + compare_input_dim = self.feature_dim * 4 + + layers = [ + nn.Linear(compare_input_dim, 512), + nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(), + nn.ReLU(inplace=True), + nn.Dropout(dropout_rate), + + nn.Linear(512, 256), + nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(), + nn.ReLU(inplace=True), + nn.Dropout(dropout_rate), + + nn.Linear(256, 9), + ] + self.head = nn.Sequential(*layers) + + def _create_backbone(self, name: str, pretrained: bool) -> nn.Module: + name = name.lower() + if name == "resnet18": + model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None) + elif name == "resnet34": + model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None) + else: + raise ValueError(f"Unsupported backbone: {name}") + if self.input_channels != 3: + old_conv = model.conv1 + model.conv1 = nn.Conv2d( + self.input_channels, + old_conv.out_channels, + kernel_size=old_conv.kernel_size, + stride=old_conv.stride, + padding=old_conv.padding, + bias=old_conv.bias is not None, + ) + return model + + def _extract_features(self, x: torch.Tensor) -> torch.Tensor: + return self.backbone(x) + + def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: + f1 = self._extract_features(img1) + f2 = self._extract_features(img2) + + diff = torch.abs(f1 - f2) + prod = f1 * f2 + combined = torch.cat([f1, f2, diff, prod], dim=1) + + h = self.head(combined) + h = h.view(-1, 3, 3) + return h + + def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: + was_training = self.training + self.eval() + with torch.no_grad(): + h = self.forward(img1, img2) + if was_training: + self.train() + return h + + +class HomographyLoss(nn.Module): + def __init__(self): + super().__init__() + self.criterion = nn.MSELoss() + + def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor: + return self.criterion(pred_homography, target_homography) + + +def create_homography_model( + model_type: str = "backbone", + input_size: Tuple[int, int] = (256, 256), + **kwargs, +) -> nn.Module: + if model_type == "backbone": + return HomographyCNN(**kwargs) + else: + raise ValueError(f"Unknown model type: {model_type}") + + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + model = HomographyCNN( + input_channels=3, + backbone_name="resnet18", + pretrained=True, + dropout_rate=0.3, + use_batch_norm=True, + ).to(device) + + print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters") + + batch_size = 4 + height, width = 256, 256 + + img1 = torch.randn(batch_size, 3, height, width).to(device) + img2 = torch.randn(batch_size, 3, height, width).to(device) + + print("\nTesting forward pass...") + output = model(img1, img2) + print(f"Output shape: {output.shape}") + + print("\nTesting prediction...") + pred = model.predict_homography(img1, img2) + print(f"Prediction shape: {pred.shape}") + + print("\nTesting loss function...") + target = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device) + loss_fn = HomographyLoss().to(device) + loss = loss_fn(output, target) + print(f"Loss value: {loss.item():.6f}") + + print("\nAll tests completed successfully!") diff --git a/models/SiaN/notebook.ipynb b/models/SiaN/notebook.ipynb index 6ed0b57..918590c 100644 --- a/models/SiaN/notebook.ipynb +++ b/models/SiaN/notebook.ipynb @@ -417,7 +417,379 @@ "id": "2dad9a5f", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from typing import Tuple\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torchvision import models\n", + "\n", + "\n", + "class HomographyCNN(nn.Module):\n", + " \"\"\"\n", + " Model for estimating homography matrix (3x3) between two images.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " input_channels: int = 3,\n", + " backbone_name: str = \"resnet18\",\n", + " pretrained: bool = True,\n", + " dropout_rate: float = 0.3,\n", + " use_batch_norm: bool = True,\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.input_channels = input_channels\n", + " self.backbone_name = backbone_name\n", + " self.pretrained = pretrained\n", + " self.dropout_rate = dropout_rate\n", + " self.use_batch_norm = use_batch_norm\n", + "\n", + " backbone = self._create_backbone(backbone_name, pretrained)\n", + "\n", + " self.feature_dim = backbone.fc.in_features\n", + " backbone.fc = nn.Identity()\n", + " self.backbone = backbone\n", + "\n", + " compare_input_dim = self.feature_dim * 4\n", + "\n", + " layers = [\n", + " nn.Linear(compare_input_dim, 512),\n", + " nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + "\n", + " nn.Linear(512, 256),\n", + " nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout_rate),\n", + "\n", + " nn.Linear(256, 9),\n", + " ]\n", + " self.head = nn.Sequential(*layers)\n", + "\n", + " def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:\n", + " name = name.lower()\n", + " if name == \"resnet18\":\n", + " model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)\n", + " elif name == \"resnet34\":\n", + " model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)\n", + " else:\n", + " raise ValueError(f\"Unsupported backbone: {name}\")\n", + " if self.input_channels != 3:\n", + " old_conv = model.conv1\n", + " model.conv1 = nn.Conv2d(\n", + " self.input_channels,\n", + " old_conv.out_channels,\n", + " kernel_size=old_conv.kernel_size,\n", + " stride=old_conv.stride,\n", + " padding=old_conv.padding,\n", + " bias=old_conv.bias is not None,\n", + " )\n", + " return model\n", + "\n", + " def _extract_features(self, x: torch.Tensor) -> torch.Tensor:\n", + " return self.backbone(x)\n", + "\n", + " def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:\n", + " f1 = self._extract_features(img1)\n", + " f2 = self._extract_features(img2)\n", + "\n", + " diff = torch.abs(f1 - f2)\n", + " prod = f1 * f2\n", + " combined = torch.cat([f1, f2, diff, prod], dim=1)\n", + "\n", + " h = self.head(combined)\n", + " h = h.view(-1, 3, 3)\n", + " return h\n", + "\n", + " def predict_homography(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:\n", + " was_training = self.training\n", + " self.eval()\n", + " with torch.no_grad():\n", + " h = self.forward(img1, img2)\n", + " if was_training:\n", + " self.train()\n", + " return h\n", + "\n", + "\n", + "class HomographyLoss(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.criterion = nn.MSELoss()\n", + "\n", + " def forward(self, pred_homography: torch.Tensor, target_homography: torch.Tensor) -> torch.Tensor:\n", + " return self.criterion(pred_homography, target_homography)\n", + "\n", + "\n", + "def create_homography_model(\n", + " model_type: str = \"backbone\",\n", + " input_size: Tuple[int, int] = (256, 256),\n", + " **kwargs,\n", + ") -> nn.Module:\n", + " if model_type == \"backbone\":\n", + " return HomographyCNN(**kwargs)\n", + " else:\n", + " raise ValueError(f\"Unknown model type: {model_type}\")\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " print(f\"Using device: {device}\")\n", + "\n", + " model = HomographyCNN(\n", + " input_channels=3,\n", + " backbone_name=\"resnet18\",\n", + " pretrained=True,\n", + " dropout_rate=0.3,\n", + " use_batch_norm=True,\n", + " ).to(device)\n", + "\n", + " print(f\"Model created with {sum(p.numel() for p in model.parameters()):,} parameters\")\n", + "\n", + " batch_size = 4\n", + " height, width = 256, 256\n", + "\n", + " img1 = torch.randn(batch_size, 3, height, width).to(device)\n", + " img2 = torch.randn(batch_size, 3, height, width).to(device)\n", + "\n", + " print(\"\\nTesting forward pass...\")\n", + " output = model(img1, img2)\n", + " print(f\"Output shape: {output.shape}\")\n", + "\n", + " print(\"\\nTesting prediction...\")\n", + " pred = model.predict_homography(img1, img2)\n", + " print(f\"Prediction shape: {pred.shape}\")\n", + "\n", + " print(\"\\nTesting loss function...\")\n", + " target = torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(device)\n", + " loss_fn = HomographyLoss().to(device)\n", + " loss = loss_fn(output, target)\n", + " print(f\"Loss value: {loss.item():.6f}\")\n", + "\n", + " print(\"\\nAll tests completed successfully!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e573b201", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "from datetime import datetime\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "class HomographyTrainer:\n", + " def __init__(\n", + " self,\n", + " model: nn.Module,\n", + " train_loader: DataLoader,\n", + " val_loader: DataLoader,\n", + " device: torch.device,\n", + " config: dict,\n", + " ):\n", + " self.model = model.to(device)\n", + " self.train_loader = train_loader\n", + " self.val_loader = val_loader\n", + " self.device = device\n", + " self.config = config\n", + "\n", + " self.criterion = HomographyLoss()\n", + " self.optimizer = optim.Adam(\n", + " model.parameters(),\n", + " lr=config.get(\"learning_rate\", 2e-4),\n", + " betas=(config.get(\"beta1\", 0.5), config.get(\"beta2\", 0.999)),\n", + " )\n", + "\n", + " self.writer = None\n", + " self.best_val_loss = float(\"inf\")\n", + " self.epochs_without_improvement = 0\n", + "\n", + " def train_epoch(self, epoch: int) -> dict:\n", + " self.model.train()\n", + " total_loss = 0\n", + " total_samples = 0\n", + "\n", + " pbar = tqdm(self.train_loader, desc=f\"Epoch {epoch}\")\n", + " for batch_idx, batch in enumerate(pbar):\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target = batch[\"homography\"].to(self.device)\n", + "\n", + " self.optimizer.zero_grad()\n", + "\n", + " output = self.model(google_img, yandex_img)\n", + " loss = self.criterion(output, target)\n", + "\n", + " loss.backward()\n", + " self.optimizer.step()\n", + "\n", + " total_loss += loss.item() * google_img.size(0)\n", + " total_samples += google_img.size(0)\n", + "\n", + " if batch_idx % self.config.get(\"log_interval\", 10) == 0:\n", + " pbar.set_postfix({\"loss\": loss.item()})\n", + "\n", + " if self.writer:\n", + " self.writer.add_scalar(\n", + " \"train/loss\",\n", + " loss.item(),\n", + " epoch * len(self.train_loader) + batch_idx,\n", + " )\n", + "\n", + " avg_loss = total_loss / total_samples\n", + " return {\"loss\": avg_loss}\n", + "\n", + " def validate(self) -> dict:\n", + " self.model.eval()\n", + " total_loss = 0\n", + " total_samples = 0\n", + "\n", + " with torch.no_grad():\n", + " for batch in tqdm(self.val_loader, desc=\"Validation\"):\n", + " google_img = batch[\"google_img\"].to(self.device)\n", + " yandex_img = batch[\"yandex_img\"].to(self.device)\n", + " target = batch[\"homography\"].to(self.device)\n", + "\n", + " output = self.model(google_img, yandex_img)\n", + " loss = self.criterion(output, target)\n", + "\n", + " total_loss += loss.item() * google_img.size(0)\n", + " total_samples += google_img.size(0)\n", + "\n", + " avg_loss = total_loss / total_samples\n", + " return {\"loss\": avg_loss}\n", + "\n", + " def train(self, num_epochs: int):\n", + " log_dir = self.config.get(\"output_dir\", \"runs/homography\")\n", + " os.makedirs(log_dir, exist_ok=True)\n", + " self.writer = SummaryWriter(log_dir)\n", + "\n", + " print(f\"Starting training for {num_epochs} epochs\")\n", + " print(f\"Logging to: {log_dir}\")\n", + "\n", + " for epoch in range(1, num_epochs + 1):\n", + " print(f\"\\nEpoch {epoch}/{num_epochs}\")\n", + "\n", + " train_metrics = self.train_epoch(epoch)\n", + " val_metrics = self.validate()\n", + "\n", + " print(f\"Train Loss: {train_metrics['loss']:.4f}\")\n", + " print(f\"Val Loss: {val_metrics['loss']:.4f}\")\n", + "\n", + " if self.writer:\n", + " self.writer.add_scalar(\"epoch/train_loss\", train_metrics[\"loss\"], epoch)\n", + " self.writer.add_scalar(\"epoch/val_loss\", val_metrics[\"loss\"], epoch)\n", + "\n", + " if val_metrics[\"loss\"] < self.best_val_loss:\n", + " self.best_val_loss = val_metrics[\"loss\"]\n", + " self.epochs_without_improvement = 0\n", + " self.save_checkpoint(epoch, val_metrics[\"loss\"], is_best=True)\n", + " print(f\"New best model saved with val loss: {val_metrics['loss']:.4f}\")\n", + " else:\n", + " self.epochs_without_improvement += 1\n", + " self.save_checkpoint(epoch, val_metrics[\"loss\"], is_best=False)\n", + "\n", + " patience = self.config.get(\"early_stopping_patience\", 20)\n", + " if self.epochs_without_improvement >= patience:\n", + " print(f\"Early stopping triggered after {patience} epochs without improvement\")\n", + " break\n", + "\n", + " self.writer.close()\n", + "\n", + " def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):\n", + " checkpoint_dir = os.path.join(\n", + " self.config.get(\"output_dir\", \"runs/homography\"), \"checkpoints\"\n", + " )\n", + " os.makedirs(checkpoint_dir, exist_ok=True)\n", + "\n", + " checkpoint = {\n", + " \"epoch\": epoch,\n", + " \"model_state_dict\": self.model.state_dict(),\n", + " \"optimizer_state_dict\": self.optimizer.state_dict(),\n", + " \"val_loss\": val_loss,\n", + " \"config\": self.config,\n", + " }\n", + "\n", + " checkpoint_path = os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch}.pt\")\n", + " torch.save(checkpoint, checkpoint_path)\n", + "\n", + " if is_best:\n", + " best_path = os.path.join(checkpoint_dir, \"best_model.pt\")\n", + " torch.save(checkpoint, best_path)\n", + "\n", + " def load_checkpoint(self, checkpoint_path: str):\n", + " checkpoint = torch.load(checkpoint_path, map_location=self.device)\n", + " self.model.load_state_dict(checkpoint[\"model_state_dict\"])\n", + " self.optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n", + " return checkpoint[\"epoch\"], checkpoint[\"val_loss\"]\n", + "\n", + "\n", + "def main():\n", + " config_dict = config.copy()\n", + "\n", + " if isinstance(config_dict.get(\"image_size\"), list):\n", + " config_dict[\"image_size\"] = tuple(config_dict[\"image_size\"])\n", + "\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " print(f\"Using device: {device}\")\n", + "\n", + " print(\"Creating data loaders...\")\n", + " train_loader, val_loader = create_data_loaders(\n", + " root_dir=config_dict[\"data_dir\"],\n", + " batch_size=config_dict[\"batch_size\"],\n", + " train_split=config_dict[\"train_split\"],\n", + " num_workers=config_dict[\"num_workers\"],\n", + " image_size=config_dict[\"image_size\"],\n", + " augment_train=True,\n", + " augment_val=False,\n", + " device=device,\n", + " )\n", + "\n", + " print(f\"Train batches: {len(train_loader)}\")\n", + " print(f\"Val batches: {len(val_loader)}\")\n", + "\n", + " print(\"Creating model...\")\n", + " model = create_homography_model(\n", + " model_type=\"backbone\",\n", + " input_channels=3,\n", + " backbone_name=\"resnet18\",\n", + " pretrained=True,\n", + " dropout_rate=0.3,\n", + " use_batch_norm=True,\n", + " )\n", + "\n", + " print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n", + "\n", + " trainer = HomographyTrainer(\n", + " model=model,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", + " device=device,\n", + " config=config_dict,\n", + " )\n", + "\n", + " print(\"Starting training...\")\n", + " trainer.train(config_dict[\"epochs\"])\n", + "\n", + " print(\"Training completed!\")\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()\n" + ] } ], "metadata": { diff --git a/models/SiaN/train.py b/models/SiaN/train.py new file mode 100644 index 0000000..7302bdc --- /dev/null +++ b/models/SiaN/train.py @@ -0,0 +1,212 @@ +import os +import time +from datetime import datetime + +import torch +import torch.nn as nn +import torch.optim as optim +from dataloader import config, create_data_loaders +from model import HomographyCNN, HomographyLoss, create_homography_model +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + + +class HomographyTrainer: + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + device: torch.device, + config: dict, + ): + self.model = model.to(device) + self.train_loader = train_loader + self.val_loader = val_loader + self.device = device + self.config = config + + self.criterion = HomographyLoss() + self.optimizer = optim.Adam( + model.parameters(), + lr=config.get("learning_rate", 2e-4), + betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)), + ) + + self.writer = None + self.best_val_loss = float("inf") + self.epochs_without_improvement = 0 + + def train_epoch(self, epoch: int) -> dict: + self.model.train() + total_loss = 0 + total_samples = 0 + + pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}") + for batch_idx, batch in enumerate(pbar): + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target = batch["homography"].to(self.device) + + self.optimizer.zero_grad() + + output = self.model(google_img, yandex_img) + loss = self.criterion(output, target) + + loss.backward() + self.optimizer.step() + + total_loss += loss.item() * google_img.size(0) + total_samples += google_img.size(0) + + if batch_idx % self.config.get("log_interval", 10) == 0: + pbar.set_postfix({"loss": loss.item()}) + + if self.writer: + self.writer.add_scalar( + "train/loss", + loss.item(), + epoch * len(self.train_loader) + batch_idx, + ) + + avg_loss = total_loss / total_samples + return {"loss": avg_loss} + + def validate(self) -> dict: + self.model.eval() + total_loss = 0 + total_samples = 0 + + with torch.no_grad(): + for batch in tqdm(self.val_loader, desc="Validation"): + google_img = batch["google_img"].to(self.device) + yandex_img = batch["yandex_img"].to(self.device) + target = batch["homography"].to(self.device) + + output = self.model(google_img, yandex_img) + loss = self.criterion(output, target) + + total_loss += loss.item() * google_img.size(0) + total_samples += google_img.size(0) + + avg_loss = total_loss / total_samples + return {"loss": avg_loss} + + def train(self, num_epochs: int): + log_dir = self.config.get("output_dir", "runs/homography") + os.makedirs(log_dir, exist_ok=True) + self.writer = SummaryWriter(log_dir) + + print(f"Starting training for {num_epochs} epochs") + print(f"Logging to: {log_dir}") + + for epoch in range(1, num_epochs + 1): + print(f"\nEpoch {epoch}/{num_epochs}") + + train_metrics = self.train_epoch(epoch) + val_metrics = self.validate() + + print(f"Train Loss: {train_metrics['loss']:.4f}") + print(f"Val Loss: {val_metrics['loss']:.4f}") + + if self.writer: + self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch) + self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch) + + if val_metrics["loss"] < self.best_val_loss: + self.best_val_loss = val_metrics["loss"] + self.epochs_without_improvement = 0 + self.save_checkpoint(epoch, val_metrics["loss"], is_best=True) + print(f"New best model saved with val loss: {val_metrics['loss']:.4f}") + else: + self.epochs_without_improvement += 1 + self.save_checkpoint(epoch, val_metrics["loss"], is_best=False) + + patience = self.config.get("early_stopping_patience", 20) + if self.epochs_without_improvement >= patience: + print(f"Early stopping triggered after {patience} epochs without improvement") + break + + self.writer.close() + + def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False): + checkpoint_dir = os.path.join( + self.config.get("output_dir", "runs/homography"), "checkpoints" + ) + os.makedirs(checkpoint_dir, exist_ok=True) + + checkpoint = { + "epoch": epoch, + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "val_loss": val_loss, + "config": self.config, + } + + checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt") + torch.save(checkpoint, checkpoint_path) + + if is_best: + best_path = os.path.join(checkpoint_dir, "best_model.pt") + torch.save(checkpoint, best_path) + + def load_checkpoint(self, checkpoint_path: str): + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + return checkpoint["epoch"], checkpoint["val_loss"] + + +def main(): + config_dict = config.copy() + + if isinstance(config_dict.get("image_size"), list): + config_dict["image_size"] = tuple(config_dict["image_size"]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + print("Creating data loaders...") + train_loader, val_loader = create_data_loaders( + root_dir=config_dict["data_dir"], + batch_size=config_dict["batch_size"], + train_split=config_dict["train_split"], + num_workers=config_dict["num_workers"], + image_size=config_dict["image_size"], + augment_train=True, + augment_val=False, + device=device, + ) + + print(f"Train batches: {len(train_loader)}") + print(f"Val batches: {len(val_loader)}") + + print("Creating model...") + model = create_homography_model( + model_type="backbone", + input_channels=3, + backbone_name="resnet18", + pretrained=True, + dropout_rate=0.3, + use_batch_norm=True, + ) + + print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + + trainer = HomographyTrainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + config=config_dict, + ) + + print("Starting training...") + trainer.train(config_dict["epochs"]) + + print("Training completed!") + + +if __name__ == "__main__": + main()