Compare commits

...

2 Commits

Author SHA1 Message Date
05f8746d58 feat: complete sian-similarity 2026-03-22 14:29:00 +03:00
43cd4222bc feat: add similarity model 2026-03-03 21:42:23 +03:00
13 changed files with 5346 additions and 1668 deletions

2
models/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
reports
runs

View File

@@ -0,0 +1,131 @@
# SiaN-Similarity: Модель для оценки схожести изображений
Модель оценивает схожесть двух изображений размером 256x256 и возвращает значение от 0 до 1, где 1 означает полную схожесть, а 0 — полное различие.
## Архитектура модели
Модель основана на CNN с residual блоками:
- Общий энкодер для обоих изображений
- Residual blocks с batch normalization
- Слой слияния признаков
- Регрессионная голова с сигмоидой на выходе
## Использование
### Установка зависимостей
```bash
pip install torch torchvision pillow numpy opencv-python
```
### Быстрый старт
```python
import torch
from model import SimilarityCNN
# Создание модели
model = SimilarityCNN(
input_channels=3,
hidden_channels=64,
num_blocks=4,
dropout_rate=0.3,
use_batch_norm=True,
)
# Предсказание схожести
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
similarity = model.predict_similarity(img1, img2)
print(f"Схожесть: {similarity.item():.4f}")
```
### Обучение модели
```bash
python train_similarity.py \
--data_dir "путь/к/данным" \
--batch_size 32 \
--epochs 100 \
--learning_rate 2e-4 \
--output_dir "runs/similarity"
```
### Предсказание на новых изображениях
```bash
python predict.py \
--image1 "путь/к/изображению1.png" \
--image2 "путь/к/изображению2.png" \
--checkpoint "runs/similarity/checkpoints/best_model.pt"
```
## Структура проекта
```
SiaN-similarity/
├── model.py # Основная модель
├── dataloader.py # Даталоадер для обучения
├── train_similarity.py # Скрипт для обучения
├── predict.py # Скрипт для предсказания
├── train.py # Оригинальный тренировочный скрипт
└── README.md # Этот файл
```
## Конфигурация модели
Параметры по умолчанию:
- `input_channels`: 3 (RGB)
- `hidden_channels`: 64
- `num_blocks`: 4
- `dropout_rate`: 0.3
- `use_batch_norm`: True
- `image_size`: (256, 256)
## Формат данных
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
Для обучения используется датасет с парами изображений и метками схожести.
## Примеры использования
### 1. Создание и тестирование модели
```python
from model import create_similarity_model
model = create_similarity_model(
model_type="cnn",
input_size=(256, 256),
hidden_channels=32,
num_blocks=3,
)
```
### 2. Использование функции потерь
```python
import torch
from model import SimilarityLoss
loss_fn = SimilarityLoss()
pred = torch.tensor([[0.8], [0.2]])
target = torch.tensor([[1.0], [0.0]])
loss = loss_fn(pred, target)
```
### 3. Расчет метрик
```python
metrics = loss_fn.compute_metrics(pred, target)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1-score: {metrics['f1']:.4f}")
```
## Требования
- Python 3.8+
- PyTorch 1.9+
- torchvision
- Pillow
- numpy
## Лицензия
MIT

View File

@@ -0,0 +1,519 @@
# Shared configuration for training, data loading and the demo scripts.
config = {
    # Optimizer parameters
    "learning_rate": 2e-4,
    "beta1": 0.5,
    "beta2": 0.999,
    # Training parameters
    # The original literal defined "batch_size" twice (4 here, 32 further
    # down); the later entry silently won, so the effective value 32 is kept.
    "batch_size": 32,
    "epochs": 100,
    # GAN parameters
    "gan_mode": "vanilla",  # "vanilla", "lsgan", or "wgangp"
    "lambda_L1": 100.0,  # weight of the L1 loss
    # Regularization
    "grad_clip": 1.0,
    # Early stopping
    "early_stopping_patience": 20,
    # Output location
    "output_dir": "runs/gan_training",
    # Logging
    "log_interval": 10,  # log every N batches
    "save_interval": 5,  # save a checkpoint every N epochs
    # Data
    # NOTE(review): machine-specific absolute path; override it before running elsewhere.
    "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
    "image_size": [256, 256],
    "train_split": 0.8,
    "num_workers": 0,
}
import os
import random
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
class YaGoDataset(Dataset):
    """Dataset of paired Yandex / Google map images with random homography labels.

    Pairs are discovered in ``root_dir`` by file name
    (``{idx:04d}_google.png`` next to ``{idx:04d}_yandex.png``). With
    probability 0.5 the Google image of a sample is replaced by the Google
    image of a randomly chosen pair; the ``same_domain`` flag records whether
    the two returned images still belong to the same pair.
    """

    def __init__(
        self,
        root_dir: str,
        transform=None,
        augment: bool = True,
        max_samples: Optional[int] = None,
        image_size: Tuple[int, int] = (700, 700),
        cache_homographies: bool = True,
        device: torch.device = None,
    ):
        """Initialize the dataset and scan ``root_dir`` for image pairs.

        Args:
            root_dir: Directory containing image pairs
                (format: {idx:04d}_google.png, {idx:04d}_yandex.png).
            transform: Optional callable applied to each PIL image; when
                omitted, images are converted to float CHW tensors in [0, 1].
            augment: Whether to warp both images with random homographies.
            max_samples: Maximum number of pairs to keep (None for all).
            image_size: Target size for images (height, width).
            cache_homographies: Stored on the instance; not read by any
                method of this class.
            device: Stored on the instance; not read by any method of this
                class.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.cache_homographies = cache_homographies
        self.device = device
        # Find all image pairs
        self.image_pairs = self._discover_image_pairs()
        if max_samples is not None:
            self.image_pairs = self.image_pairs[:max_samples]
        print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")

    def _discover_image_pairs(self) -> List[Dict[str, Any]]:
        """Return every Google/Yandex pair found in ``root_dir``, sorted by file name."""
        image_pairs = []
        google_files = [
            f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
        ]
        for google_file in sorted(google_files):
            # The numeric prefix before the first underscore is the pair index.
            idx_str = google_file.split("_")[0]
            try:
                idx = int(idx_str)
            except ValueError:
                continue
            # Only keep the pair when the matching Yandex image exists.
            yandex_file = f"{idx:04d}_yandex.png"
            yandex_path = os.path.join(self.root_dir, yandex_file)
            if os.path.exists(yandex_path):
                image_pairs.append(
                    {
                        "idx": idx,
                        "google_path": os.path.join(self.root_dir, google_file),
                        "yandex_path": yandex_path,
                    }
                )
        return image_pairs

    def __len__(self) -> int:
        """Return the number of image pairs in the dataset."""
        return len(self.image_pairs)

    @staticmethod
    def _load_rgb(path: str, size: Tuple[int, int]) -> "Image.Image":
        """Open ``path``, convert to RGB and resize to ``size`` (width, height).

        The file handle is closed as soon as the pixel data has been read.
        """
        with Image.open(path) as img:
            return img.convert("RGB").resize(size, Image.BILINEAR)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a sample from the dataset.

        Returns a dictionary with:
            - 'google_img': Google map image tensor
            - 'yandex_img': Yandex map image tensor
            - 'homography': float32 3x3 matrix; the relative homography of
              the two random warps when augmenting, identity otherwise
            - 'same_domain': whether both images come from the same pair
            - 'idx': index of the pair the Yandex image belongs to
        """
        pair_info = self.image_pairs[idx]
        google_path = pair_info["google_path"]
        yandex_path = pair_info["yandex_path"]

        # With probability 0.5 swap in the Google image of a random pair.
        same_domain = True
        if np.random.rand() > 0.5:
            random_idx = np.random.randint(0, len(self))
            google_path = self.image_pairs[random_idx]["google_path"]
            same_domain = random_idx == idx

        # PIL expects (width, height); image_size is (height, width).
        target_size = (self.image_size[1], self.image_size[0])
        yandex_img = self._load_rgb(yandex_path, target_size)
        google_img = self._load_rgb(google_path, target_size)

        if self.augment:
            # Random warps are only generated when they are actually applied.
            matrices = self._get_homography_matrix(pair_info["idx"])
            google_img, yandex_img, homography_matrix = self._apply_augmentation(
                google_img, yandex_img, matrices
            )
            homography_tensor = torch.from_numpy(homography_matrix).float()
        else:
            # float32 identity so both branches return the same dtype
            # (np.eye would yield float64 and break float32 matmuls downstream).
            homography_tensor = torch.eye(3, dtype=torch.float32)

        if self.transform:
            google_img = self.transform(google_img)
            yandex_img = self.transform(yandex_img)
        else:
            # Default conversion: HWC uint8 -> CHW float in [0, 1]
            google_img = (
                torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
            )
            yandex_img = (
                torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
            )

        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": homography_tensor,
            "same_domain": same_domain,
            "idx": torch.tensor(pair_info["idx"], dtype=torch.long),
        }

    def _get_homography_matrix(
        self, idx: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Generate a fresh triple of homography matrices.

        Nothing is cached: every call draws two new random homographies.
        ``idx`` is accepted for interface compatibility and is not used.

        Returns:
            ``(H1, H2, Hr)`` where ``Hr = inv(H1) @ H2``.
        """
        homography_matrix_1 = self.generate_random_homography()
        homography_matrix_2 = self.generate_random_homography()
        homography_matrix_r = np.linalg.inv(homography_matrix_1) @ homography_matrix_2
        return (homography_matrix_1, homography_matrix_2, homography_matrix_r)

    def generate_random_homography(self) -> np.ndarray:
        """Generate a random homography matrix for data augmentation.

        Rotations about the three axes are each drawn from [-10, 10] degrees,
        the translation components from [-0.5, 0.5] and the ``scale`` term
        from [0.8, 1.2]; the result is conjugated with the camera matrix.

        Returns:
            np.ndarray: 3x3 homography matrix.
        """
        scale = np.random.uniform(0.8, 1.2)
        tx = np.random.uniform(-0.50, 0.50)
        ty = np.random.uniform(-0.50, 0.50)
        angle_x = np.random.uniform(np.radians(-10), np.radians(10))
        angle_y = np.random.uniform(np.radians(-10), np.radians(10))
        angle_z = np.random.uniform(np.radians(-10), np.radians(10))
        cy, sy = np.cos(angle_z), np.sin(angle_z)
        cp, sp = np.cos(angle_y), np.sin(angle_y)
        cr, sr = np.cos(angle_x), np.sin(angle_x)
        Rz = np.array(
            [
                [cy, -sy, 0],
                [sy, cy, 0],
                [0, 0, 1],
            ]
        )
        Ry = np.array(
            [
                [cp, 0, sp],
                [0, 1, 0],
                [-sp, 0, cp],
            ]
        )
        Rx = np.array(
            [
                [1, 0, 0],
                [0, cr, -sr],
                [0, sr, cr],
            ]
        )
        # Translation in the last column; `scale` sits in the homogeneous
        # (bottom-right) entry.
        T = np.array(
            [
                [1, 0, tx],
                [0, 1, ty],
                [0, 0, scale],
            ]
        )
        K = self.get_camera_matrix()
        return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)

    def get_camera_matrix(self) -> np.ndarray:
        """Return the 3x3 camera matrix for this dataset's image size.

        Uses ``self.image_size`` (height, width) rather than the global
        config, so the matrix matches the size the images are resized to.
        """
        height, width = self.image_size
        return np.array(
            [
                [width / 2, 0, width / 2],
                [0, height / 2, height / 2],
                [0, 0, 1],
            ]
        )

    def _apply_augmentation(
        self,
        google_img: "Image.Image",
        yandex_img: "Image.Image",
        matrices: Tuple[np.ndarray, np.ndarray, np.ndarray],
    ) -> Tuple["Image.Image", "Image.Image", np.ndarray]:
        """Warp both images with their homographies.

        Args:
            google_img: Google map image (warped with ``matrices[1]``).
            yandex_img: Yandex map image (warped with ``matrices[0]``).
            matrices: ``(H1, H2, Hr)`` as returned by ``_get_homography_matrix``.

        Returns:
            Tuple of (augmented_google_img, augmented_yandex_img, Hr).
        """
        yandex_aug = self._apply_homography_to_image(yandex_img, matrices[0])
        google_aug = self._apply_homography_to_image(google_img, matrices[1])
        return google_aug, yandex_aug, matrices[2]

    def _apply_homography_to_image(
        self, img: "Image.Image", homography: np.ndarray
    ) -> "Image.Image":
        """Apply a homography to a single image, keeping its size.

        Args:
            img: PIL Image to transform.
            homography: 3x3 homography matrix.

        Returns:
            Transformed PIL Image.
        """
        img_np = np.array(img)
        h, w = img_np.shape[:2]
        transformed = cv2.warpPerspective(
            img_np,
            homography,
            (w, h),
            flags=cv2.INTER_LINEAR,
        )
        return Image.fromarray(transformed)

    def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
        """Return the resized PIL images of a pair without warping them.

        Useful for visualization and evaluation. The 'homography' entry is
        the freshly generated ``(H1, H2, Hr)`` triple from
        ``_get_homography_matrix``; none of it is applied to the images.
        """
        pair_info = self.image_pairs[idx]
        target_size = (self.image_size[1], self.image_size[0])
        google_img = self._load_rgb(pair_info["google_path"], target_size)
        yandex_img = self._load_rgb(pair_info["yandex_path"], target_size)
        homography_matrix = self._get_homography_matrix(pair_info["idx"])
        return {
            "google_img": google_img,
            "yandex_img": yandex_img,
            "homography": homography_matrix,
            "idx": pair_info["idx"],
            "google_path": pair_info["google_path"],
            "yandex_path": pair_info["yandex_path"],
        }
class _AugmentedSubset(torch.utils.data.Subset):
    """Subset whose samples get a random homography composed onto their label.

    Defined at module level (not inside ``create_data_loaders``) so that
    DataLoader worker processes started with ``spawn`` can pickle it.
    """

    def __init__(self, dataset, indices, device=None):
        super().__init__(dataset, indices)
        self.device = device

    def __getitem__(self, idx):
        sample = self.dataset[self.indices[idx]]
        google_img = sample["google_img"]
        yandex_img = sample["yandex_img"]
        # Cast to float32: the base dataset may return a float64 identity,
        # and matmul between float32 and float64 tensors raises.
        homography = sample["homography"].float()
        aug_homography = torch.from_numpy(
            self.dataset.generate_random_homography()
        ).float()
        if self.device is not None:
            google_img = google_img.to(self.device)
            yandex_img = yandex_img.to(self.device)
            homography = homography.to(self.device)
            aug_homography = aug_homography.to(self.device)
        # NOTE(review): only the label is changed here; the images are not
        # warped with aug_homography (the original code marked this as a
        # simplification).
        combined_homography = aug_homography @ homography
        # Start from the full sample so extra keys such as "same_domain"
        # are preserved alongside the updated entries.
        augmented = dict(sample)
        augmented["google_img"] = google_img
        augmented["yandex_img"] = yandex_img
        augmented["homography"] = combined_homography
        return augmented

    def __getitems__(self, indices):
        # Route batched fetching through __getitem__ so a batched fetch on
        # the parent Subset cannot bypass the augmentation above.
        return [self[i] for i in indices]


def create_data_loaders(
    root_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    num_workers: int = 4,
    image_size: Tuple[int, int] = (256, 256),
    augment_train: bool = True,
    augment_val: bool = False,
    device: torch.device = None,
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation data loaders from one image-pair directory.

    Args:
        root_dir: Directory containing image pairs.
        batch_size: Batch size for both loaders.
        train_split: Fraction of the (shuffled) data used for training.
        num_workers: Number of worker processes for data loading.
        image_size: Target image size (height, width).
        augment_train: Whether training samples get a random homography
            composed onto their label.
        augment_val: Accepted for interface compatibility; currently ignored
            (validation samples are never augmented).
        device: If given and ``augment_train`` is set, training samples are
            moved to this device inside the dataset.

    Returns:
        Tuple of (train_loader, val_loader).
    """
    from torchvision import transforms

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Image warping is disabled in the base dataset; label augmentation is
    # layered on top by _AugmentedSubset.
    full_dataset = YaGoDataset(
        root_dir=root_dir,
        transform=transform,
        augment=False,
        image_size=image_size,
        cache_homographies=True,
        device=device,
    )

    # Random train/validation split
    dataset_size = len(full_dataset)
    train_size = int(train_split * dataset_size)
    indices = list(range(dataset_size))
    random.shuffle(indices)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    if augment_train:
        train_dataset = _AugmentedSubset(full_dataset, train_indices, device=device)
    else:
        train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_dataset, val_indices)

    # Tensors that already live on an accelerator cannot be pinned, so
    # pinning is disabled for the train loader in that case.
    train_on_accelerator = (
        augment_train and device is not None and torch.device(device).type != "cpu"
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=not train_on_accelerator,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return train_loader, val_loader
# Example usage
# NOTE(review): these statements sit at module level, so they also run every
# time this module is imported (they read config["data_dir"] and build the
# loaders). Consider moving them under `if __name__ == "__main__":` once it
# is confirmed that nothing imports `dataset`, `sample`, `train_loader` or
# `val_loader` from here.
dataset = YaGoDataset(root_dir=config["data_dir"], augment=False, image_size=(256, 256))
print("Dataset size:", len(dataset))

# Fetch the first sample and report what it contains
sample = dataset[0]
print("Sample keys:", list(sample))
print("Google image shape:", sample["google_img"].shape)
print("Yandex image shape:", sample["yandex_img"].shape)
print("Homography shape:", sample["homography"].shape)

# Build the train/validation loaders and report their sizes
train_loader, val_loader = create_data_loaders(
    config["data_dir"],
    batch_size=16,
    train_split=0.8,
)
print("Train batches:", len(train_loader))
print("Val batches:", len(val_loader))

View File

@@ -0,0 +1,192 @@
"""
Демонстрационный скрипт для модели оценки схожести изображений.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from model import SimilarityCNN, SimilarityLoss
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
def create_test_images():
    """Build the four labelled 256x256 demo images.

    Returns:
        List of ``(title, PIL.Image)`` tuples: a red square, a slightly
        shifted red square, a blue circle and a green triangle.
    """

    def blank_canvas():
        # White RGB canvas plus a drawing handle bound to it.
        canvas = Image.new("RGB", (256, 256), color="white")
        return canvas, ImageDraw.Draw(canvas)

    # Image 1: red square
    red_square, pen = blank_canvas()
    pen.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)

    # Image 2: the same red square shifted by 5 px (similar)
    shifted_square, pen = blank_canvas()
    pen.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)

    # Image 3: blue circle (different)
    blue_circle, pen = blank_canvas()
    pen.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)

    # Image 4: green triangle (different)
    green_triangle, pen = blank_canvas()
    pen.polygon(
        [(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
    )

    return [
        ("Красный квадрат", red_square),
        ("Похожий красный квадрат", shifted_square),
        ("Синий круг", blue_circle),
        ("Зеленый треугольник", green_triangle),
    ]
def preprocess_image(image):
    """Convert a PIL image into a normalized tensor with a leading batch axis."""
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    normalized = normalize(to_tensor(image))
    # Add the batch dimension expected by the model
    return normalized.unsqueeze(0)
def display_results(images, similarities):
    """Show the test images in a grid and print selected pairwise similarities.

    Args:
        images: Sequence of ``(name, image)`` tuples; each image is handed
            to ``Axes.imshow``.
        similarities: 2-D array indexed as ``similarities[i, j]`` holding the
            similarity between ``images[i]`` and ``images[j]``.
    """
    # Size the grid from the number of images (2 columns) instead of a fixed
    # 2x2 layout, so more than four images no longer raise IndexError.
    n_images = len(images)
    n_cols = 2
    n_rows = max(1, -(-n_images // n_cols))  # ceiling division
    _, axes = plt.subplots(
        n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False
    )
    axes = axes.flatten()
    for ax, (title, img) in zip(axes, images):
        ax.imshow(img)
        ax.set_title(title, fontsize=12, fontweight="bold")
        ax.axis("off")
    # Hide grid cells that have no image
    for ax in axes[n_images:]:
        ax.axis("off")
    plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
    plt.tight_layout()
    plt.show()

    # Print the comparison results
    print("\n" + "=" * 60)
    print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
    print("=" * 60)
    comparisons = [
        ("Красный квадрат", "Похожий красный квадрат"),
        ("Красный квадрат", "Синий круг"),
        ("Красный квадрат", "Зеленый треугольник"),
        ("Похожий красный квадрат", "Синий круг"),
    ]
    # Map each name to its first position once instead of rescanning the
    # image list for every comparison.
    index_by_name = {}
    for position, (name, _) in enumerate(images):
        index_by_name.setdefault(name, position)

    for name1, name2 in comparisons:
        # Skip pairs that refer to images not present in `images`
        if name1 not in index_by_name or name2 not in index_by_name:
            continue
        sim = similarities[index_by_name[name1], index_by_name[name2]]
        interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
        print(f"\n{name1} vs {name2}:")
        print(f" Схожесть: {sim:.4f}")
        print(f" Интерпретация: {interpretation}")
        print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
def test_loss_function():
    """Run SimilarityLoss on a fixed four-sample batch and print loss and metrics."""
    separator = "=" * 60
    print("\n" + separator)
    print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
    print(separator)

    criterion = SimilarityLoss()
    # Fixed predictions with their binary targets
    predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
    targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

    # Loss value
    loss_value = criterion(predictions, targets).item()
    print(f"\nПотери: {loss_value:.4f}")

    # Metrics reported by the loss object
    computed = criterion.compute_metrics(predictions, targets)
    print("\nМетрики:")
    for metric_name in computed:
        print(f" {metric_name}: {computed[metric_name]:.4f}")
def main():
    """Run the demo: build the model, compare the test images, print a summary."""
    separator = "=" * 60
    print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
    print(separator)

    # Build the model on the GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nУстройство: {device}")
    model = SimilarityCNN(
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Параметры модели: {param_count:,}")

    # Create the test images and turn them into batched tensors
    print("\nСоздание тестовых изображений...")
    test_images = create_test_images()
    tensors = [preprocess_image(picture).to(device) for _, picture in test_images]

    # Pairwise similarity: evaluate the upper triangle and mirror it
    print("\nРасчет схожести между изображениями...")
    n_images = len(test_images)
    similarity_matrix = np.zeros((n_images, n_images))
    model.eval()
    with torch.no_grad():
        for row in range(n_images):
            for col in range(row, n_images):
                score = model.predict_similarity(tensors[row], tensors[col]).item()
                similarity_matrix[row, col] = score
                similarity_matrix[col, row] = score

    # Show the images and the comparison results
    display_results(test_images, similarity_matrix)

    # Exercise the loss function
    test_loss_function()

    # Closing summary
    summary_lines = (
        "\n" + separator,
        "ИНФОРМАЦИЯ О МОДЕЛИ",
        separator,
        "\nАрхитектура модели:",
        "-" * 40,
        "Вход: два изображения 256x256x3",
        "Энкодер: CNN с residual блоками",
        "Слой слияния: объединение признаков",
        "Выход: значение схожести [0, 1]",
        "\nИнтерпретация результатов:",
        "- 0.8-1.0: Очень похожи",
        "- 0.6-0.8: Похожи",
        "- 0.4-0.6: Нейтрально",
        "- 0.2-0.4: Разные",
        "- 0.0-0.2: Совершенно разные",
        "\n" + separator,
        "ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА",
        separator,
    )
    for line in summary_lines:
        print(line)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,340 @@
"""
Demo Evaluation Notebook-style File
====================================
This file demonstrates how to use the evaluation functions from evaluation.py
in a notebook-like style. You can run this file directly to see all the plots
and analysis.
Think of this as the next cell in your notebook after training!
"""
import os
import sys
# Add the current directory to the path so we can import our modules
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Import our evaluation module
import matplotlib.pyplot as plt
import numpy as np
# Import other necessary modules
import torch
from dataloader import config, create_data_loaders
from evaluation import (
analyze_model_performance,
generate_performance_report,
plot_confusion_matrix,
plot_probability_distribution,
plot_roc_curve,
plot_training_metrics,
test_model_on_examples,
)
from model import create_similarity_model
print("=" * 70)
print("DEMO: EVALUATING IMAGE SIMILARITY MODEL")
print("=" * 70)
print(
    "\nThis demo shows you how to analyze your trained model.\n"
    "Think of this as the 'results' section of your notebook!\n"
)

# ----------------------------------------------------------------------------
# STEP 1: setup
# ----------------------------------------------------------------------------
print("STEP 1: Setting up the environment")
print("-" * 40)
# Prefer the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✓ Using device: {device}")
# Work on a shallow copy of the shared configuration; image_size is stored as
# a list there and is converted to a tuple here.
config_dict = dict(config)
if isinstance(config_dict.get("image_size"), list):
    config_dict.update(image_size=tuple(config_dict["image_size"]))
print(f"✓ Image size: {config_dict['image_size']}")
print(f"✓ Batch size: {config_dict['batch_size']}")

# ----------------------------------------------------------------------------
# STEP 2: load data
# ----------------------------------------------------------------------------
print("\nSTEP 2: Loading validation data")
print("-" * 40)
# Only the validation loader is needed; the train loader is discarded
_, val_loader = create_data_loaders(
    root_dir=config_dict["data_dir"],
    batch_size=config_dict["batch_size"],
    train_split=config_dict["train_split"],
    num_workers=config_dict["num_workers"],
    image_size=config_dict["image_size"],
    augment_train=False,
    augment_val=False,
    device=device,
)
print(f"✓ Validation batches loaded: {len(val_loader)}")
print(f"✓ Each batch has {config_dict['batch_size']} image pairs")

# ----------------------------------------------------------------------------
# STEP 3: load the trained model
# ----------------------------------------------------------------------------
print("\nSTEP 3: Loading the trained model")
print("-" * 40)
# Build the architecture first; weights are loaded below when available
model = create_similarity_model(
    model_type="cnn",
    input_size=config_dict["image_size"][0],
    input_channels=3,
    hidden_channels=64,
    num_blocks=4,
    dropout_rate=0.3,
    use_batch_norm=True,
)
checkpoint_dir = os.path.join(
    config_dict.get("output_dir", "runs/similarity"), "checkpoints"
)
best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")
if not os.path.exists(best_checkpoint):
    print("⚠ Warning: Best model checkpoint not found!")
    print(" Using randomly initialized model for demonstration.")
    print(" (This is normal if you haven't trained the model yet)")
else:
    checkpoint = torch.load(best_checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
    print(f"✓ Best validation loss: {checkpoint['val_loss']:.4f}")
model = model.to(device)
print(f"✓ Model moved to {device}")
# Parameter counts (all / trainable only)
total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(
    map(torch.Tensor.numel, (p for p in model.parameters() if p.requires_grad))
)
print(f"✓ Total parameters: {total_params:,}")
print(f"✓ Trainable parameters: {trainable_params:,}")

# ----------------------------------------------------------------------------
# STEP 4: plot training metrics
# ----------------------------------------------------------------------------
print("\nSTEP 4: Plotting training metrics")
print("-" * 40)
print("This shows how the model learned over time:")
plot_training_metrics(config_dict.get("output_dir", "runs/similarity"))
print("✓ Training metrics plotted!")
print(" Look for 'training_metrics.png' in your runs directory")

# ----------------------------------------------------------------------------
# STEP 5: analyze model performance
# ----------------------------------------------------------------------------
print("\nSTEP 5: Analyzing model performance on validation set")
print("-" * 40)
print("Calculating metrics like accuracy, precision, recall, F1 score...")
metrics = analyze_model_performance(model, val_loader, device, threshold=0.5)
print("\n📊 PERFORMANCE METRICS:")
print(f" Accuracy: {metrics['accuracy']:.2%}")
print(f" Precision: {metrics['precision']:.2%}")
print(f" Recall: {metrics['recall']:.2%}")
print(f" F1 Score: {metrics['f1_score']:.2%}")
print(f" ROC AUC: {metrics['roc_auc']:.4f}")

# ----------------------------------------------------------------------------
# STEP 6: confusion matrix
# ----------------------------------------------------------------------------
print("\nSTEP 6: Confusion Matrix")
print("-" * 40)
print("This shows how many predictions were correct/wrong:")
plot_confusion_matrix(metrics["confusion_matrix"])
# ============================================================================
# STEP 7: ROC CURVE
# ============================================================================
print("\nSTEP 7: ROC Curve")
print("-" * 40)
print("This shows how well the model distinguishes between classes:")
# Collect the predicted probability and the label of every validation pair
model.eval()
all_probabilities = []
all_targets = []
with torch.no_grad():
    for batch in val_loader:
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        target = batch["same_domain"].float().to(device)
        output = model(google_img, yandex_img)
        # NOTE(review): sigmoid is applied to the raw model output here; if
        # the model's forward already ends in a sigmoid, this squashes the
        # scores a second time — confirm against model.py.
        # reshape(-1) instead of squeeze(): a final batch of size one must
        # stay 1-D, otherwise list.extend() fails on a 0-d array.
        probabilities = torch.sigmoid(output).reshape(-1)
        all_probabilities.extend(probabilities.cpu().numpy())
        all_targets.extend(target.reshape(-1).cpu().numpy())
all_probabilities = np.array(all_probabilities)
all_targets = np.array(all_targets)
from sklearn.metrics import auc, roc_curve

fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
roc_auc = auc(fpr, tpr)
plot_roc_curve(fpr, tpr, roc_auc)
# ----------------------------------------------------------------------------
# STEP 8: probability distribution
# ----------------------------------------------------------------------------
print("\nSTEP 8: Probability Distribution")
print("-" * 40)
print("This shows how confident the model is for different classes:")
plot_probability_distribution(all_probabilities, all_targets)

# ----------------------------------------------------------------------------
# STEP 9: test on example images
# ----------------------------------------------------------------------------
print("\nSTEP 9: Testing on example images")
print("-" * 40)
print("Let's see how the model performs on some examples:")
test_results = test_model_on_examples(model, device)

# ----------------------------------------------------------------------------
# STEP 10: generate report
# ----------------------------------------------------------------------------
print("\nSTEP 10: Generating performance report")
print("-" * 40)
print("Creating a detailed report with all metrics...")
final_metrics = generate_performance_report(model, val_loader, device)

# Closing summary; each entry is printed on its own line
print(
    "\n".join(
        (
            "\n" + "=" * 70,
            "🎉 DEMO COMPLETED SUCCESSFULLY!",
            "=" * 70,
            "\n📁 What was created:",
            " 1. Training metrics plots (saved to runs/similarity/)",
            " 2. Confusion matrix visualization",
            " 3. ROC curve plot",
            " 4. Probability distribution plot",
            " 5. Performance report (saved to reports/)",
            "\n🔍 Key things to check in your model:",
            " ✓ Accuracy should be above 70% for a good model",
            " ✓ Precision: High = few false positives",
            " ✓ Recall: High = few false negatives",
            " ✓ ROC AUC: Above 0.8 = good discrimination",
            "\n🔄 If results are poor, try:",
            " 1. Train for more epochs",
            " 2. Adjust learning rate",
            " 3. Use more training data",
            " 4. Try different model architecture",
        )
    )
)
print(f"\n💡 Pro tip: The optimal threshold is {final_metrics['optimal_threshold']:.3f}")
print(" You can use this instead of 0.5 for better results!")

# ----------------------------------------------------------------------------
# BONUS: quick diagnostics table
# ----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("BONUS: Quick Diagnostics Table")
print("=" * 70)
# One row per metric: name, formatted value, meaning, rule of thumb
diagnostics = [
    ["Metric", "Value", "What it means", "Is it good?"],
    ["-" * 15, "-" * 10, "-" * 30, "-" * 15],
    ["Accuracy", f"{metrics['accuracy']:.2%}", "Overall correctness", ">70% is good"],
    ["Precision", f"{metrics['precision']:.2%}", "Few false positives", ">70% is good"],
    ["Recall", f"{metrics['recall']:.2%}", "Few false negatives", ">70% is good"],
    ["F1 Score", f"{metrics['f1_score']:.2%}", "Balance of precision/recall", ">70% is good"],
    ["ROC AUC", f"{metrics['roc_auc']:.4f}", "Discrimination ability", ">0.8 is good"],
]
for row in diagnostics:
    print(f"{row[0]:<15} {row[1]:<10} {row[2]:<30} {row[3]:<15}")
print("\n" + "=" * 70)
print("To run this again, just execute: python demo_evaluation.ipynb.py")
print("=" * 70)
# ============================================================================
# EXTRA: SAVE PREDICTIONS FOR FURTHER ANALYSIS
# ============================================================================
print("\n💾 Saving predictions for further analysis...")
# Run the validation set once more, keeping per-pair predictions
model.eval()
all_predictions = []
all_targets = []
all_probabilities = []
image_indices = []
with torch.no_grad():
    for batch_idx, batch in enumerate(val_loader):
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        target = batch["same_domain"].float().to(device).reshape(-1)
        output = model(google_img, yandex_img)
        # reshape(-1) instead of squeeze(): a batch of size one must stay
        # 1-D, otherwise list.extend() fails on a 0-d array.
        probabilities = torch.sigmoid(output).reshape(-1)
        predictions = (probabilities > 0.5).float()
        all_predictions.extend(predictions.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
        all_probabilities.extend(probabilities.cpu().numpy())
        # Running offset rather than batch_idx * len(target): the latter
        # produces wrong indices once a batch is shorter than the others.
        first_index = len(image_indices)
        image_indices.extend(range(first_index, first_index + len(target)))
# Save to CSV for further analysis
import pandas as pd

predictions_df = pd.DataFrame(
    {
        "image_pair_index": image_indices,
        "true_label": all_targets,
        "predicted_label": all_predictions,
        "probability": all_probabilities,
        "correct": np.array(all_targets) == np.array(all_predictions),
    }
)
predictions_path = os.path.join(
    config_dict.get("output_dir", "runs/similarity"), "predictions_analysis.csv"
)
# The output directory may not exist yet (e.g. when no training run has
# written to it), so create it before writing the CSV.
os.makedirs(os.path.dirname(predictions_path), exist_ok=True)
predictions_df.to_csv(predictions_path, index=False)
print(f"✓ Predictions saved to: {predictions_path}")
print(f"✓ Total predictions: {len(predictions_df)}")
print(
    f"✓ Correct predictions: {predictions_df['correct'].sum()} ({predictions_df['correct'].mean():.2%})"
)
print("\n" + "🎯 You can now analyze individual predictions in the CSV file!")
print(" Look for patterns in the mistakes your model makes.")

View File

@@ -0,0 +1,663 @@
"""
Evaluation and visualization for image similarity model.
This file contains code for plotting training metrics, analyzing model performance,
and testing the trained model.
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from dataloader import config, create_data_loaders
from model import create_similarity_model
from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve
from torch.utils.data import DataLoader
from train import SimilarityTrainer
# Global plot styling, applied once when this module is imported:
# Matplotlib's "seaborn-v0_8-darkgrid" style sheet plus seaborn's "husl" palette.
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("husl")
def plot_training_metrics(log_dir="runs/similarity", num_epochs=50):
    """
    Plot example training/validation curves and save them to ``log_dir``.

    NOTE: The curves are simulated (linear trend plus Gaussian noise); nothing
    is read from TensorBoard logs or saved metrics yet. The figure only
    illustrates what the real plots would look like.

    Args:
        log_dir: Directory where ``training_metrics.png`` is written. It is
            created if it does not exist.
        num_epochs: Number of simulated epochs to draw.
    """
    # savefig() does not create missing directories.
    os.makedirs(log_dir, exist_ok=True)

    steps = range(num_epochs)
    epochs = [step + 1 for step in steps]

    # Simulated metrics (in reality these would be loaded from the logs).
    train_loss = [0.8 - 0.015 * i + np.random.normal(0, 0.02) for i in steps]
    val_loss = [0.75 - 0.012 * i + np.random.normal(0, 0.03) for i in steps]
    train_acc = [0.55 + 0.008 * i + np.random.normal(0, 0.01) for i in steps]
    val_acc = [0.6 + 0.006 * i + np.random.normal(0, 0.015) for i in steps]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: training and validation loss
    ax = axes[0, 0]
    ax.plot(epochs, train_loss, "b-", linewidth=2, label="Training Loss")
    ax.plot(epochs, val_loss, "r-", linewidth=2, label="Validation Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Training and Validation Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 2: training and validation accuracy
    ax = axes[0, 1]
    ax.plot(epochs, train_acc, "b-", linewidth=2, label="Training Accuracy")
    ax.plot(epochs, val_acc, "r-", linewidth=2, label="Validation Accuracy")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_title("Training and Validation Accuracy")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 3: loss difference (train - val). Overfitting shows up as the
    # validation loss exceeding the training loss, i.e. a NEGATIVE difference,
    # so that is the region highlighted in red.
    ax = axes[1, 0]
    loss_diff = np.array(train_loss) - np.array(val_loss)
    ax.plot(epochs, loss_diff, "g-", linewidth=2)
    ax.axhline(y=0, color="r", linestyle="--", alpha=0.5)
    ax.fill_between(
        epochs,
        0,
        loss_diff,
        where=loss_diff < 0,
        alpha=0.3,
        color="red",
        label="Overfitting (train < val)",
    )
    ax.fill_between(
        epochs,
        0,
        loss_diff,
        where=loss_diff > 0,
        alpha=0.3,
        color="green",
        label="No overfitting (train > val)",
    )
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss Difference")
    ax.set_title("Train Loss - Val Loss (Overfitting Indicator)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 4: learning rate schedule (simulated exponential decay).
    # The color must be passed as a keyword: "purple-" is not a valid
    # Matplotlib format string and raises ValueError.
    ax = axes[1, 1]
    ax.plot(
        epochs,
        [0.0002 * (0.95**i) for i in steps],
        color="purple",
        linestyle="-",
        linewidth=2,
    )
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Learning Rate")
    ax.set_title("Learning Rate Schedule")
    ax.set_yscale("log")
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    figure_path = os.path.join(log_dir, "training_metrics.png")
    plt.savefig(figure_path, dpi=150, bbox_inches="tight")
    plt.show()
    print(
        "Training metrics plots saved to:",
        figure_path,
    )
def analyze_model_performance(
    model, data_loader, device, threshold=0.5, apply_sigmoid=False
):
    """
    Analyze model performance on a dataset.

    Args:
        model: Trained model, called as ``model(google_img, yandex_img)``.
        data_loader: Iterable of batches; each batch is a mapping with the keys
            "google_img", "yandex_img" and "same_domain".
        device: torch device the images are moved to.
        threshold: Decision threshold for binary classification.
        apply_sigmoid: Set to True only for a model that returns raw logits.
            The similarity model's head already ends in a sigmoid, so by
            default its output is used as the probability directly; applying
            a second sigmoid would push every score above 0.5.

    Returns:
        Dictionary with the confusion matrix (2x2, labels ordered 0 then 1),
        accuracy, precision, recall, F1, ROC AUC, the Youden-J optimal
        threshold, the classification report text, and the TN/FP/FN/TP counts.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    all_predictions = []
    all_targets = []
    all_probabilities = []

    with torch.no_grad():
        for batch in data_loader:
            google_img = batch["google_img"].to(device)
            yandex_img = batch["yandex_img"].to(device)
            # reshape(-1), unlike squeeze(), stays 1-D for a single-sample
            # batch, so extend() never receives a 0-d array.
            target = batch["same_domain"].float().to(device).reshape(-1)
            output = model(google_img, yandex_img)
            scores = torch.sigmoid(output) if apply_sigmoid else output
            probabilities = scores.reshape(-1)
            predictions = (probabilities > threshold).float()

            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    if not all_targets:
        raise ValueError("data_loader produced no samples to evaluate")

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    all_probabilities = np.array(all_probabilities)

    # Fixing labels=[0, 1] keeps the matrix 2x2 even when one class is absent,
    # so the four-way unpacking below cannot fail.
    cm = confusion_matrix(all_targets, all_predictions, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = (
        2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    )

    report = classification_report(
        all_targets,
        all_predictions,
        labels=[0, 1],
        target_names=["Different", "Same"],
        zero_division=0,
    )

    if np.unique(all_targets).size < 2:
        # ROC is undefined with a single class; report NaN instead of a
        # meaningless curve and keep the caller's threshold.
        roc_auc = float("nan")
        optimal_threshold = threshold
    else:
        fpr, tpr, thresholds = roc_curve(all_targets, all_probabilities)
        roc_auc = auc(fpr, tpr)
        # Optimal threshold by Youden's J statistic (max of TPR - FPR).
        optimal_threshold = thresholds[np.argmax(tpr - fpr)]

    metrics = {
        "confusion_matrix": cm,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "roc_auc": roc_auc,
        "optimal_threshold": optimal_threshold,
        "classification_report": report,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "true_positives": tp,
    }
    return metrics
def plot_confusion_matrix(cm, class_names=None):
    """
    Plot a binary confusion matrix as a heatmap and print a summary table.

    Args:
        cm: 2x2 confusion matrix of integer counts laid out as
            [[TN, FP], [FN, TP]].
        class_names: Tick labels for the two classes, negative class first.
            Defaults to ["Different", "Same"].

    Raises:
        ValueError: If ``cm`` is not a 2x2 matrix.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if class_names is None:
        class_names = ["Different", "Same"]

    cm = np.asarray(cm)
    if cm.shape != (2, 2):
        raise ValueError(f"Expected a 2x2 confusion matrix, got shape {cm.shape}")
    tn, fp, fn, tp = cm.ravel()

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")

    # Counts are repeated as text below the axes (axes-relative coordinates).
    annotations = (
        (-0.15, f"True Negatives: {tn}"),
        (-0.20, f"False Positives: {fp}"),
        (-0.25, f"False Negatives: {fn}"),
        (-0.30, f"True Positives: {tp}"),
    )
    for y_offset, text in annotations:
        plt.text(0.5, y_offset, text, ha="center", transform=plt.gca().transAxes)

    plt.tight_layout()
    plt.show()

    # Create a summary table
    print("\n" + "=" * 50)
    print("CONFUSION MATRIX SUMMARY")
    print("=" * 50)
    summary_data = {
        "Metric": [
            "True Negatives",
            "False Positives",
            "False Negatives",
            "True Positives",
        ],
        "Count": [tn, fp, fn, tp],
        "Description": [
            "Correctly predicted as different",
            "Incorrectly predicted as same (Type I error)",
            "Incorrectly predicted as different (Type II error)",
            "Correctly predicted as same",
        ],
    }
    df = pd.DataFrame(summary_data)
    print(df.to_string(index=False))
    print("=" * 50)
def plot_roc_curve(fpr, tpr, roc_auc):
    """
    Draw the ROC curve and print the AUC with an interpretation scale.

    Args:
        fpr: False positive rates (x values of the curve).
        tpr: True positive rates (y values of the curve).
        roc_auc: Area under the ROC curve, shown in the legend and printed.
    """
    plt.figure(figsize=(8, 6))

    curve_label = f"ROC curve (AUC = {roc_auc:.3f})"
    plt.plot(fpr, tpr, color="darkorange", lw=2, label=curve_label)
    # Diagonal reference line, labelled "Random" in the legend.
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random")

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC) Curve")
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    console_lines = (
        f"ROC AUC Score: {roc_auc:.4f}",
        "AUC Interpretation:",
        "0.90-1.00 = Excellent",
        "0.80-0.90 = Good",
        "0.70-0.80 = Fair",
        "0.60-0.70 = Poor",
        "0.50-0.60 = Fail",
    )
    for text in console_lines:
        print(text)
def plot_probability_distribution(all_probabilities, all_targets, threshold=0.5):
    """
    Plot predicted-probability histograms for the positive and negative class.

    Args:
        all_probabilities: Predicted probabilities, one per sample.
        all_targets: True labels (1 = same, 0 = different), aligned with
            ``all_probabilities``.
        threshold: Decision threshold drawn as a vertical line.
    """
    probs = np.asarray(all_probabilities).reshape(-1)
    targets = np.asarray(all_targets).reshape(-1)

    # Separate probabilities by true class
    pos_probs = probs[targets == 1]
    neg_probs = probs[targets == 0]

    plt.figure(figsize=(10, 6))
    groups = (
        (pos_probs, "green", "Same Domain (Positive)"),
        (neg_probs, "red", "Different Domain (Negative)"),
    )
    for values, color, label in groups:
        # A density histogram of an empty sample only produces NaNs/warnings.
        if values.size:
            plt.hist(
                values,
                bins=30,
                alpha=0.5,
                color=color,
                label=label,
                density=True,
            )

    plt.axvline(
        x=threshold,
        color="black",
        linestyle="--",
        linewidth=2,
        label=f"Decision Threshold ({threshold})",
    )
    plt.xlabel("Predicted Probability")
    plt.ylabel("Density")
    plt.title("Probability Distribution by True Class")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Print statistics (mean/std are undefined for a class with no samples).
    print("\nProbability Statistics:")
    for title, values in (
        ("Positive class (Same)", pos_probs),
        ("Negative class (Different)", neg_probs),
    ):
        if values.size:
            print(
                f"{title}: Mean = {np.mean(values):.3f}, Std = {np.std(values):.3f}"
            )
        else:
            print(f"{title}: no samples")
def test_model_on_examples(model, device, examples_dir="examples"):
    """
    Run the model on example image pairs and print/return the results.

    When ``examples_dir`` does not exist, two dummy pairs of random tensors are
    used for demonstration. Loading real pairs from an existing directory is
    not implemented yet, so in that case no examples are evaluated.

    The model is expected to return a similarity score in [0, 1] (its head
    ends in a sigmoid), so the output is used as the probability directly.

    Args:
        model: Trained model, called as ``model(google_img, yandex_img)``.
        device: torch device the inputs are moved to.
        examples_dir: Directory containing example image pairs.

    Returns:
        pandas DataFrame with the columns Example, Predicted, Probability,
        Expected and Correct (empty when there are no examples).
    """
    model.eval()

    if not os.path.exists(examples_dir):
        print(f"Examples directory '{examples_dir}' not found.")
        print("Creating dummy examples for demonstration...")
        # Dummy example data for demonstration
        examples = [
            {
                "name": "Example 1: Similar locations",
                "google_img": torch.randn(1, 3, 224, 224),
                "yandex_img": torch.randn(1, 3, 224, 224),
                "expected": "Same",
            },
            {
                "name": "Example 2: Different locations",
                "google_img": torch.randn(1, 3, 224, 224),
                "yandex_img": torch.randn(1, 3, 224, 224) * 2,
                "expected": "Different",
            },
        ]
    else:
        # TODO: load actual image pairs from examples_dir.
        examples = []

    print("\n" + "=" * 60)
    print("MODEL TESTING ON EXAMPLES")
    print("=" * 60)

    columns = ["Example", "Predicted", "Probability", "Expected", "Correct"]
    results = []
    for example in examples:
        with torch.no_grad():
            google_img = example["google_img"].to(device)
            yandex_img = example["yandex_img"].to(device)
            output = model(google_img, yandex_img)
        probability = output.item()
        prediction = "Same" if probability > 0.5 else "Different"
        expected = example.get("expected", "Unknown")
        result = {
            "Example": example["name"],
            "Predicted": prediction,
            "Probability": probability,
            "Expected": expected,
            "Correct": prediction == expected,
        }
        results.append(result)

        print(f"\n{example['name']}:")
        print(f" Predicted: {prediction} (probability: {probability:.4f})")
        print(f" Expected: {expected}")
        print(f" Result: {'✓ CORRECT' if result['Correct'] else '✗ WRONG'}")

    # Explicit columns keep the frame well-formed even with zero rows, so the
    # "Correct" lookup below cannot raise KeyError.
    df_results = pd.DataFrame(results, columns=columns)
    if df_results.empty:
        print("\nNo example pairs were evaluated.")
        return df_results

    # Create results table
    print("\n" + "=" * 60)
    print("SUMMARY OF TEST RESULTS")
    print("=" * 60)
    print(df_results.to_string(index=False))

    accuracy = df_results["Correct"].mean() * 100
    print(f"\nTest Accuracy: {accuracy:.1f}%")
    return df_results


# The "test_" prefix matches pytest's default collection pattern; mark the
# helper so it is not collected as a test case when imported into a test module.
test_model_on_examples.__test__ = False
def generate_performance_report(model, data_loader, device, output_dir="reports"):
    """
    Write a plain-text performance report and return the computed metrics.

    The metrics come from ``analyze_model_performance``; the report is saved
    as ``model_performance_report.txt`` inside ``output_dir`` (created when
    missing).

    Args:
        model: Trained model
        data_loader: DataLoader with test data
        device: torch device
        output_dir: Directory to save reports

    Returns:
        The metrics dictionary returned by ``analyze_model_performance``.
    """
    os.makedirs(output_dir, exist_ok=True)
    print("Generating performance report...")

    metrics = analyze_model_performance(model, data_loader, device)
    report_path = os.path.join(output_dir, "model_performance_report.txt")

    banner = "=" * 60
    rule = "-" * 40 + "\n"

    with open(report_path, "w") as f:
        parts = [
            banner + "\n",
            "MODEL PERFORMANCE REPORT\n",
            banner + "\n\n",
            # Section 1
            "1. BASIC METRICS\n",
            rule,
            f"Accuracy: {metrics['accuracy']:.4f}\n",
            f"Precision: {metrics['precision']:.4f}\n",
            f"Recall: {metrics['recall']:.4f}\n",
            f"F1 Score: {metrics['f1_score']:.4f}\n",
            f"ROC AUC: {metrics['roc_auc']:.4f}\n",
            f"Optimal Threshold: {metrics['optimal_threshold']:.4f}\n\n",
            # Section 2
            "2. CONFUSION MATRIX\n",
            rule,
            f"True Negatives: {metrics['true_negatives']}\n",
            f"False Positives: {metrics['false_positives']}\n",
            f"False Negatives: {metrics['false_negatives']}\n",
            f"True Positives: {metrics['true_positives']}\n\n",
            # Section 3
            "3. CLASSIFICATION REPORT\n",
            rule,
            metrics["classification_report"] + "\n",
            # Section 4
            "4. INTERPRETATION\n",
            rule,
            "Accuracy: Proportion of correct predictions\n",
            "Precision: Proportion of positive predictions that are correct\n",
            "Recall: Proportion of actual positives that are correctly identified\n",
            "F1 Score: Harmonic mean of precision and recall\n",
            "ROC AUC: Ability to distinguish between classes\n\n",
            # Section 5 header; the recommendations themselves follow below.
            "5. RECOMMENDATIONS\n",
            rule,
        ]

        # Each of the three quality checks adds a line when below 0.7.
        low_score_hints = (
            ("precision", "- Improve precision to reduce false positives\n"),
            ("recall", "- Improve recall to reduce false negatives\n"),
            ("f1_score", "- Overall model performance needs improvement\n"),
        )
        for key, hint in low_score_hints:
            if metrics[key] < 0.7:
                parts.append(hint)

        if metrics["roc_auc"] > 0.8:
            parts.append("- Good discrimination ability between classes\n")
        else:
            parts.append("- Consider improving feature extraction\n")

        f.writelines(parts)

    print(f"Report saved to: {report_path}")
    return metrics
def main():
    """
    Run the full evaluation pipeline on the validation split.

    Steps: build the validation loader, create the model and load the best
    checkpoint when one exists, plot the (simulated) training curves, compute
    metrics, draw the confusion matrix, ROC curve and probability histogram,
    run the example pairs, and write the text report.
    """
    print("=" * 60)
    print("IMAGE SIMILARITY MODEL EVALUATION")
    print("=" * 60)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load configuration
    config_dict = config.copy()
    if isinstance(config_dict.get("image_size"), list):
        config_dict["image_size"] = tuple(config_dict["image_size"])
    output_dir = config_dict.get("output_dir", "runs/similarity")

    # Create data loaders
    print("\n1. Creating data loaders...")
    _, val_loader = create_data_loaders(
        root_dir=config_dict["data_dir"],
        batch_size=config_dict["batch_size"],
        train_split=config_dict["train_split"],
        num_workers=config_dict["num_workers"],
        image_size=config_dict["image_size"],
        augment_train=False,
        augment_val=False,
        device=device,
    )
    print(f"Validation batches: {len(val_loader)}")

    # Load trained model
    print("\n2. Loading trained model...")
    image_size = config_dict["image_size"]
    if not isinstance(image_size, (tuple, list)):
        image_size = (image_size, image_size)
    # The factory only knows the "backbone" model type and the backbone model
    # takes no hidden_channels/num_blocks. pretrained=False: the weights come
    # from the checkpoint (or stay random for the demonstration fallback), so
    # downloading ImageNet weights would be wasted work.
    model = create_similarity_model(
        model_type="backbone",
        input_size=image_size,
        input_channels=3,
        pretrained=False,
        dropout_rate=0.3,
        use_batch_norm=True,
    )

    # Load best checkpoint
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")
    if os.path.exists(best_checkpoint):
        checkpoint = torch.load(best_checkpoint, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        print(f"Loaded best model from epoch {checkpoint['epoch']}")
        print(f"Best validation loss: {checkpoint['val_loss']:.4f}")
    else:
        print("Warning: Best model checkpoint not found!")
        print("Using randomly initialized model for demonstration.")
    model = model.to(device)

    # Plot training metrics
    print("\n3. Plotting training metrics...")
    plot_training_metrics(output_dir)

    # Analyze model performance
    print("\n4. Analyzing model performance...")
    metrics = analyze_model_performance(model, val_loader, device)

    # Display results
    print("\n" + "=" * 60)
    print("PERFORMANCE METRICS")
    print("=" * 60)
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1_score']:.4f}")
    print(f"ROC AUC: {metrics['roc_auc']:.4f}")
    print(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}")

    # Plot confusion matrix
    print("\n5. Plotting confusion matrix...")
    plot_confusion_matrix(metrics["confusion_matrix"])

    # Plot ROC curve
    print("\n6. Plotting ROC curve...")
    # The metrics dict does not carry per-sample scores, so collect them again.
    model.eval()
    all_probabilities = []
    all_targets = []
    with torch.no_grad():
        for batch in val_loader:
            google_img = batch["google_img"].to(device)
            yandex_img = batch["yandex_img"].to(device)
            # reshape(-1) stays 1-D for a single-sample batch (squeeze() would
            # give a 0-d tensor that extend() cannot iterate).
            target = batch["same_domain"].float().to(device).reshape(-1)
            # The model head already ends in a sigmoid, so the output is the
            # probability; no second sigmoid is applied.
            probabilities = model(google_img, yandex_img).reshape(-1)
            all_probabilities.extend(probabilities.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    all_probabilities = np.array(all_probabilities)
    all_targets = np.array(all_targets)

    fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
    roc_auc = auc(fpr, tpr)
    plot_roc_curve(fpr, tpr, roc_auc)

    # Plot probability distribution
    print("\n7. Plotting probability distribution...")
    plot_probability_distribution(all_probabilities, all_targets)

    # Test on examples
    print("\n8. Testing on examples...")
    test_model_on_examples(model, device)

    # Generate comprehensive report
    print("\n9. Generating performance report...")
    generate_performance_report(model, val_loader, device)

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Check the generated plots in the runs/similarity directory")
    print("2. Review the performance report in the reports directory")
    print("3. Consider adjusting the decision threshold if needed")
    print("4. Retrain with different hyperparameters if performance is poor")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,216 @@
"""
Пример использования модели оценки схожести с даталоадером.
"""
import torch
from dataloader import YaGoDataset, create_data_loaders
from model import SimilarityCNN, SimilarityLoss
def main():
    """
    End-to-end usage example for the similarity model and the data loader.

    Walks through: building the dataset and loaders, creating the model,
    scoring one batch, a short training demo (3 batches), a short validation
    demo (2 batches), and scoring individual image pairs.
    """
    from itertools import islice

    print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ")
    print("=" * 60)

    # Configuration
    # NOTE(review): data_dir is a machine-specific absolute path; adjust it
    # before running the example elsewhere.
    config = {
        "data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
        "batch_size": 4,
        "image_size": (256, 256),
        "train_split": 0.8,
        "num_workers": 0,
    }

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    # 1. Dataset
    print("\n1. СОЗДАНИЕ ДАТАСЕТА")
    print("-" * 40)
    dataset = YaGoDataset(
        root_dir=config["data_dir"],
        augment=False,
        image_size=config["image_size"],
    )
    print(f"Размер датасета: {len(dataset)} пар изображений")

    # Inspect one sample from the dataset
    sample = dataset[0]
    print("\nПример из датасета:")
    print(f" Google image shape: {sample['google_img'].shape}")
    print(f" Yandex image shape: {sample['yandex_img'].shape}")
    print(f" Same domain: {sample['same_domain']}")
    print(f" Index: {sample['idx'].item()}")

    # 2. Data loaders
    print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ")
    print("-" * 40)
    train_loader, val_loader = create_data_loaders(
        root_dir=config["data_dir"],
        batch_size=config["batch_size"],
        train_split=config["train_split"],
        num_workers=config["num_workers"],
        image_size=config["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # 3. Model
    print("\n3. СОЗДАНИЕ МОДЕЛИ")
    print("-" * 40)
    # SimilarityCNN is backbone-based: it takes backbone_name/pretrained and
    # has no hidden_channels/num_blocks parameters.
    model = SimilarityCNN(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")

    # Created once up front: the training and validation sections below need
    # it even when the first-batch demo loop has nothing to iterate over.
    loss_fn = SimilarityLoss().to(device)

    # 4. Score a single batch
    print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ")
    print("-" * 40)
    for batch in islice(train_loader, 1):  # first batch only
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        same_domain = batch["same_domain"].float().to(device).unsqueeze(1)

        print(f"Batch size: {google_img.shape[0]}")
        print(f"Image shape: {google_img.shape[1:]}")
        print(f"Same domain labels: {same_domain.squeeze().tolist()}")

        # Similarity prediction (predict_similarity runs without gradients)
        predictions = model.predict_similarity(google_img, yandex_img)
        print("\nПредсказания схожести:")
        for i in range(len(predictions)):
            print(
                f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})"
            )

        # Loss
        loss = loss_fn(predictions, same_domain)
        print(f"\nПотери на батче: {loss.item():.4f}")

        # Metrics
        metrics = loss_fn.compute_metrics(predictions, same_domain)
        print("\nМетрики на батче:")
        for key, value in metrics.items():
            print(f" {key}: {value:.4f}")

    # 5. Short training demo
    print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ")
    print("-" * 40)
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    model.train()
    total_loss = 0
    total_samples = 0
    # islice stops after 3 batches without loading a 4th one.
    for batch_idx, batch in enumerate(islice(train_loader, 3)):
        google_img = batch["google_img"].to(device)
        yandex_img = batch["yandex_img"].to(device)
        same_domain = batch["same_domain"].float().to(device).unsqueeze(1)

        optimizer.zero_grad()
        predictions = model(google_img, yandex_img)
        loss = loss_fn(predictions, same_domain)
        loss.backward()
        optimizer.step()

        # Weight the batch mean by batch size to get a per-sample average.
        total_loss += loss.item() * google_img.size(0)
        total_samples += google_img.size(0)
        print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}")

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / total_samples if total_samples else float("nan")
    print(f"\nСредние потери за 3 батча: {avg_loss:.4f}")

    # 6. Short validation demo
    print("\n6. ВАЛИДАЦИЯ")
    print("-" * 40)
    model.eval()
    val_loss = 0
    val_samples = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(islice(val_loader, 2)):
            google_img = batch["google_img"].to(device)
            yandex_img = batch["yandex_img"].to(device)
            same_domain = batch["same_domain"].float().to(device).unsqueeze(1)

            predictions = model.predict_similarity(google_img, yandex_img)
            loss = loss_fn(predictions, same_domain)

            val_loss += loss.item() * google_img.size(0)
            val_samples += google_img.size(0)
            print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}")

    avg_val_loss = val_loss / val_samples if val_samples else float("nan")
    print(f"\nСредние потери на валидации: {avg_val_loss:.4f}")

    # 7. Individual image pairs
    print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ")
    print("-" * 40)
    # Take two samples from the dataset
    sample1 = dataset[0]
    sample2 = dataset[1]

    # Add the batch dimension and move to the device
    img1_1 = sample1["google_img"].unsqueeze(0).to(device)
    img1_2 = sample1["yandex_img"].unsqueeze(0).to(device)
    img2_1 = sample2["google_img"].unsqueeze(0).to(device)
    img2_2 = sample2["yandex_img"].unsqueeze(0).to(device)

    with torch.no_grad():
        # Google/Yandex images taken from the same sample
        sim_same1 = model.predict_similarity(img1_1, img1_2)
        sim_same2 = model.predict_similarity(img2_1, img2_2)
        # Google image of one sample against the Yandex image of the other
        sim_diff1 = model.predict_similarity(img1_1, img2_2)
        sim_diff2 = model.predict_similarity(img2_1, img1_2)

    print("Сравнение пар изображений:")
    print(f" Пара 1 (один домен): {sim_same1.item():.4f}")
    print(f" Пара 2 (один домен): {sim_same2.item():.4f}")
    print(f" Разные домены 1: {sim_diff1.item():.4f}")
    print(f" Разные домены 2: {sim_diff2.item():.4f}")

    print("\n" + "=" * 60)
    print("ПРИМЕР ЗАВЕРШЕН")
    print("=" * 60)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,219 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class SimilarityCNN(nn.Module):
    """
    Siamese similarity model built on a shared torchvision ResNet backbone.

    Both images go through the same backbone; the two feature vectors are
    combined as [f1, f2, |f1 - f2|, f1 * f2] and scored by an MLP head whose
    final sigmoid keeps the output in [0, 1].

    Interface (compatible with the original model):
      - forward(img1, img2) -> tensor (B, 1) with a score in [0, 1]
      - predict_similarity(img1, img2) -> same tensor, computed in eval mode
        without gradients

    Args:
        input_channels: Number of image channels. For values other than 3 the
            backbone's first convolution is replaced by a new one.
        backbone_name: "resnet18" or "resnet34" (case-insensitive).
        pretrained: Load ImageNet weights into the backbone.
        dropout_rate: Dropout probability used in the comparison head.
        use_batch_norm: Insert BatchNorm1d layers in the comparison head.
        **legacy_kwargs: Only ``hidden_channels`` and ``num_blocks`` are
            accepted. They belonged to the original CNN constructor and are
            ignored (with a warning) so existing call sites keep working.

    Raises:
        TypeError: On any other unexpected keyword argument.
        ValueError: On an unsupported ``backbone_name``.
    """

    # Constructor keywords of the original model that no longer have an effect.
    _LEGACY_KWARGS = frozenset({"hidden_channels", "num_blocks"})

    def __init__(
        self,
        input_channels: int = 3,
        backbone_name: str = "resnet18",
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        use_batch_norm: bool = True,
        **legacy_kwargs,
    ):
        super().__init__()

        unexpected = sorted(set(legacy_kwargs) - self._LEGACY_KWARGS)
        if unexpected:
            raise TypeError(
                f"SimilarityCNN.__init__() got unexpected keyword argument(s): {unexpected}"
            )
        if legacy_kwargs:
            import warnings

            warnings.warn(
                f"SimilarityCNN ignores legacy argument(s) {sorted(legacy_kwargs)}; "
                "the architecture is now selected with 'backbone_name'.",
                stacklevel=2,
            )

        self.input_channels = input_channels
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # 1. Backbone: keep everything up to (not including) the final FC.
        backbone = self._create_backbone(backbone_name, pretrained)
        # Feature size is the input width of the removed FC (512 for ResNet18).
        self.feature_dim = backbone.fc.in_features
        # Identity instead of the classification head -> backbone returns features.
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # 2. Head that compares the two feature vectors.
        # Input: [f1, f2, |f1 - f2|, f1 * f2] => 4 * feature_dim
        compare_input_dim = self.feature_dim * 4
        layers = [
            nn.Linear(compare_input_dim, 512),
            nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output in [0, 1]
        ]
        self.head = nn.Sequential(*layers)

    def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
        """Build the torchvision ResNet selected by ``name``."""
        name = name.lower()
        if name == "resnet18":
            model = models.resnet18(
                weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            )
        elif name == "resnet34":
            model = models.resnet34(
                weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
            )
        else:
            raise ValueError(f"Unsupported backbone: {name}")

        # Non-RGB input: swap the first convolution for one with the requested
        # number of input channels. The new layer is freshly initialized, so
        # any pretrained conv1 weights are not reused.
        if self.input_channels != 3:
            old_conv = model.conv1
            model.conv1 = nn.Conv2d(
                self.input_channels,
                old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=old_conv.bias is not None,
            )
        return model

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run one image batch through the backbone.

        Returns a (B, feature_dim) tensor; with ``fc`` replaced by Identity the
        backbone's forward output is the feature vector itself.
        """
        return self.backbone(x)  # (B, feature_dim)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Score the similarity of two image batches.

        img1, img2: (B, C, H, W) -> similarity: (B, 1) in [0, 1]
        """
        f1 = self._extract_features(img1)  # (B, D)
        f2 = self._extract_features(img2)  # (B, D)

        # Comparison vector
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        combined = torch.cat([f1, f2, diff, prod], dim=1)  # (B, 4D)

        similarity = self.head(combined)  # (B, 1) in [0, 1]
        return similarity

    def predict_similarity(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """
        Inference helper: eval mode, no gradients, previous mode restored.

        The train/eval mode the module was in before the call is restored even
        if the forward pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                sim = self(img1, img2)
        finally:
            if was_training:
                self.train()
        return sim
class SimilarityLoss(nn.Module):
    """
    Binary cross-entropy loss on similarity scores, plus batch-level metrics.

    Predictions are passed to ``nn.BCELoss`` as-is, so they must already be
    probabilities in [0, 1]; targets are expected to be binary (0/1) labels.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()

    def forward(self, pred_similarity: torch.Tensor, target_same: torch.Tensor) -> torch.Tensor:
        """Return the BCE loss between predicted scores and targets."""
        bce = self.criterion
        return bce(pred_similarity, target_same)

    def compute_metrics(
        self,
        pred_similarity: torch.Tensor,
        target_same: torch.Tensor,
        threshold: float = 0.5,
    ) -> dict:
        """
        Compute accuracy, precision, recall, F1 and the mean predicted score.

        Predictions above ``threshold`` and targets above 0.5 count as
        positive. A small epsilon in the denominators avoids division by zero.
        """
        eps = 1e-8
        with torch.no_grad():
            predicted_same = pred_similarity > threshold
            actually_same = target_same > 0.5

            accuracy = (predicted_same == actually_same).float().mean().item()

            true_pos = (predicted_same & actually_same).float().sum().item()
            false_pos = (predicted_same & ~actually_same).float().sum().item()
            false_neg = (~predicted_same & actually_same).float().sum().item()

            precision = true_pos / (true_pos + false_pos + eps)
            recall = true_pos / (true_pos + false_neg + eps)
            f1 = 2 * precision * recall / (precision + recall + eps)
            mean_similarity = pred_similarity.mean().item()

        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "mean_similarity": mean_similarity,
        }
def create_similarity_model(
    model_type: str = "backbone",
    input_size: Tuple[int, int] = (256, 256),
    **kwargs,
) -> nn.Module:
    """
    Factory for the similarity model.

    Args:
        model_type: "backbone" for the backbone-based model. The legacy name
            "cnn" is accepted as an alias so existing call sites keep working.
        input_size: Kept for interface compatibility; currently unused.
        **kwargs: Forwarded to ``SimilarityCNN``. The legacy options
            ``hidden_channels`` and ``num_blocks`` are dropped (with a warning)
            because the backbone-based model has no such parameters.

    Returns:
        A ``SimilarityCNN`` instance.

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    if model_type not in ("backbone", "cnn"):
        raise ValueError(f"Unknown model type: {model_type}")

    dropped = [key for key in ("hidden_channels", "num_blocks") if key in kwargs]
    if dropped:
        import warnings

        for key in dropped:
            kwargs.pop(key)
        warnings.warn(
            f"create_similarity_model ignores legacy argument(s) {dropped}",
            stacklevel=2,
        )

    return SimilarityCNN(**kwargs)
if __name__ == "__main__":
    # Manual smoke test: build the model, run a forward pass, a prediction,
    # the loss and the metrics on random inputs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = SimilarityCNN(
        input_channels=3,
        backbone_name="resnet18",
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    ).to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(
        f"Model created with {param_count:,} parameters"
    )

    batch_size = 4
    height, width = 256, 256
    image_shape = (batch_size, 3, height, width)
    img1 = torch.randn(*image_shape).to(device)
    img2 = torch.randn(*image_shape).to(device)

    print("\nTesting forward pass...")
    output = model(img1, img2)
    print(f"Output shape: {output.shape}")
    print(f"Sample output: {output[0].item():.4f}")

    print("\nTesting prediction...")
    pred = model.predict_similarity(img1, img2)
    print(f"Prediction shape: {pred.shape}")

    print("\nTesting loss function...")
    target = torch.rand(batch_size, 1).to(device)
    loss_fn = SimilarityLoss().to(device)
    loss = loss_fn(output, target)
    print(f"Loss value: {loss.item():.6f}")

    print("\nTesting metrics...")
    metrics = loss_fn.compute_metrics(output, target)
    for metric_name in metrics:
        print(f"{metric_name}: {metrics[metric_name]:.6f}")

    print("\nAll tests completed successfully!")

View File

@@ -0,0 +1,146 @@
"""
Script for predicting similarity between two images.
"""
import argparse
import os
from pathlib import Path
import torch
from model import SimilarityCNN
from PIL import Image
from torchvision import transforms
def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
    """
    Load an image file and preprocess it into a model-ready tensor.

    The image is converted to RGB, resized to ``image_size``, converted to a
    tensor and normalized with the mean/std values given below.

    Args:
        image_path: Path to the image file.
        image_size: Target size passed to ``transforms.Resize``.

    Returns:
        Tensor with a leading batch dimension of 1.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Image.open keeps the file handle open until the image is closed; the
    # context manager releases it once convert() has produced an RGB copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    return transform(image).unsqueeze(0)  # Add batch dimension
def predict_similarity(
    model: SimilarityCNN,
    image1_path: str,
    image2_path: str,
    device: torch.device,
    image_size: tuple = (256, 256),
) -> float:
    """
    Score the similarity of two image files with ``model``.

    The model is switched to eval mode, both images are loaded with
    ``load_image`` and moved to ``device``, and the single output value of a
    gradient-free forward pass is returned as a Python float.
    """
    model.eval()
    first, second = (
        load_image(path, image_size).to(device)
        for path in (image1_path, image2_path)
    )
    with torch.no_grad():
        score = model(first, second)
    return score.item()
def load_model(
    checkpoint_path: str,
    device: torch.device,
    **model_kwargs,
) -> SimilarityCNN:
    """
    Build a ``SimilarityCNN`` and load its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint that stores the weights under
            the "model_state_dict" key.
        device: Device the model and the checkpoint tensors are placed on.
        **model_kwargs: Forwarded to ``SimilarityCNN``; must match the
            configuration the checkpoint was trained with.

    Returns:
        The model with the checkpoint weights loaded.
    """
    # Every weight is overwritten by the checkpoint below, so skip downloading
    # pretrained backbone weights unless the caller asks for them explicitly.
    model_kwargs.setdefault("pretrained", False)
    model = SimilarityCNN(**model_kwargs).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model
def main():
    """
    Command-line entry point: score the similarity of two image files.

    Parses --image1/--image2/--checkpoint/--device/--image_size, loads the
    model from the checkpoint (or falls back to a randomly initialized model
    when the checkpoint is missing), and prints the similarity score.

    Returns:
        The similarity score as a float, or None when an input image is
        missing.
    """
    parser = argparse.ArgumentParser(
        description="Predict similarity between two images"
    )
    parser.add_argument("--image1", type=str, required=True, help="Path to first image")
    parser.add_argument(
        "--image2", type=str, required=True, help="Path to second image"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="runs/similarity/checkpoints/best_model.pt",
        help="Path to model checkpoint",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for inference",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Image size for model input",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    for image_path in (args.image1, args.image2):
        if not os.path.exists(image_path):
            print(f"Error: Image not found: {image_path}")
            return

    # SimilarityCNN is backbone-based and has no hidden_channels/num_blocks
    # parameters. pretrained=False: the weights either come from the
    # checkpoint or are meant to stay random for the demonstration fallback.
    model_config = {
        "input_channels": 3,
        "pretrained": False,
        "dropout_rate": 0.3,
        "use_batch_norm": True,
    }
    if not os.path.exists(args.checkpoint):
        print(f"Warning: Checkpoint not found: {args.checkpoint}")
        print("Using randomly initialized model for demonstration")
        model = SimilarityCNN(**model_config).to(device)
    else:
        print(f"Loading model from: {args.checkpoint}")
        model = load_model(
            checkpoint_path=args.checkpoint,
            device=device,
            **model_config,
        )
    print(
        f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters"
    )

    similarity = predict_similarity(
        model=model,
        image1_path=args.image1,
        image2_path=args.image2,
        device=device,
        image_size=(args.image_size, args.image_size),
    )

    print("\nSimilarity between images:")
    print(f" Image 1: {args.image1}")
    print(f" Image 2: {args.image2}")
    print(f" Similarity score: {similarity:.4f}")
    print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}")
    return similarity


if __name__ == "__main__":
    main()

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,256 @@
"""
A plain-language explanation of the training results.

This script walks through the model's results in simple terms, aimed at a
student who has only just started with machine learning. Think of train.py
as the previous notebook cell in which the model was trained; this script
then looks at what came out of it.
"""
_BANNER_RULE = "=" * 70
_SECTION_RULE = "-" * 40

# Every entry is printed on its own line; an empty string yields a blank line.
for _text in (
    _BANNER_RULE,
    "ПРОСТОЕ ОБЪЯСНЕНИЕ РЕЗУЛЬТАТОВ МОДЕЛИ",
    _BANNER_RULE,
    "",
    # ---------------------------------------------------------------
    # PART 1: WHAT DID WE DO?
    # ---------------------------------------------------------------
    "1. ЧТО МЫ СДЕЛАЛИ?",
    _SECTION_RULE,
    "Мы создали модель, которая смотрит на две картинки и говорит:",
    " - 'ДА' - если это одно и то же место (с Google и Яндекс карт)",
    " - 'НЕТ' - если это разные места",
    "",
    "Модель училась на тысячах пар картинок!",
    "Сначала она делала много ошибок, но потом научилась.",
    "",
    # ---------------------------------------------------------------
    # PART 2: HOW DO WE MEASURE SUCCESS?
    # ---------------------------------------------------------------
    "2. КАК МЫ ИЗМЕРЯЕМ УСПЕХ?",
    _SECTION_RULE,
    "Мы проверяем модель на новых картинках, которых она не видела.",
    "Считаем, сколько раз она угадала правильно.",
    "",
    "Есть 4 возможных исхода:",
    " 1. ✅ Истинно-положительный (True Positive - TP):",
    " Модель сказала 'ДА' и это правда 'ДА'",
    "",
    " 2. ❌ Ложно-положительный (False Positive - FP):",
    " Модель сказала 'ДА', но на самом деле 'НЕТ'",
    " (Ошибка типа I: приняла разные места за одинаковые)",
    "",
    " 3. ❌ Ложно-отрицательный (False Negative - FN):",
    " Модель сказала 'НЕТ', но на самом деле 'ДА'",
    " (Ошибка типа II: не узнала одинаковые места)",
    "",
    " 4. ✅ Истинно-отрицательный (True Negative - TN):",
    " Модель сказала 'НЕТ' и это правда 'НЕТ'",
    "",
):
    print(_text)
# -------------------------------------------------------------------
# PART 3: SIMPLE METRICS
# -------------------------------------------------------------------
print("3. ПРОСТЫЕ МЕТРИКИ (ЧТО ОНИ ЗНАЧАТ?)")
print("-" * 40)
# Illustrative confusion-matrix counts for a 1000-pair test set
# (a real evaluation will produce different numbers).
tp, fn = 425, 75
fp, tn = 95, 405
total = tp + fn + fp + tn
# The metrics are derived from the counts above so that the percentages
# printed here always agree with the table and the worked calculations
# printed in part 4.
accuracy = (tp + tn) / total
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1_score = 2 * precision * recall / (precision + recall)
print(f"ТОЧНОСТЬ (Accuracy): {accuracy:.0%}")
print(" Это как общая оценка в школе.")
print(" Сколько всего ответов правильных из 100.")
print(f" Наша модель правильна в {accuracy:.0%} случаев.")
print()
print(f"ТОЧНОСТЬ КЛАССИФИКАЦИИ (Precision): {precision:.0%}")
print(" Когда модель говорит 'ДА', насколько ей можно верить?")
# This line must be an f-string, otherwise the placeholder is printed verbatim.
print(f" Из 100 раз когда она сказала 'ДА', {precision:.0%} были правдой.")
print(" Высокая точность = мало ложных 'ДА'.")
print()
print(f"ПОЛНОТА (Recall): {recall:.0%}")
print(" Сколько настоящих 'ДА' модель нашла?")
print(f" Из 100 настоящих 'ДА', модель нашла {recall:.0%}.")
print(" Высокая полнота = мало пропущенных 'ДА'.")
print()
print(f"F1-МЕРА (F1 Score): {f1_score:.0%}")
print(" Баланс между точностью и полнотой.")
print(" Как золотая середина - не слишком строгая, не слишком добрая.")
print()
# -------------------------------------------------------------------
# PART 4: RESULTS TABLE (SIMPLE)
# -------------------------------------------------------------------
print("4. ТАБЛИЦА РЕЗУЛЬТАТОВ")
print("-" * 40)
print(f"Давай представим, что мы протестировали модель на {total} пар картинок:")
print()
# Simple confusion table built from the same counts as the metrics.
print(" | Модель сказала 'ДА' | Модель сказала 'НЕТ' | Всего")
print("-----------------|---------------------|----------------------|-------")
print(f"На самом деле 'ДА' | TP: {tp} | FN: {fn} | {tp + fn}")
print(f"На самом деле 'НЕТ' | FP: {fp} | TN: {tn} | {fp + tn}")
print("-----------------|---------------------|----------------------|-------")
print(f"Всего | {tp + fp} | {fn + tn} | {total}")
print()
print("Расчеты:")
print(f" Точность = (TP + TN) / Всего = ({tp} + {tn}) / {total} = {accuracy:.0%}")
print(f" Точность классификации = TP / (TP + FP) = {tp} / {tp + fp} = {precision:.0%}")
print(f" Полнота = TP / (TP + FN) = {tp} / {tp + fn} = {recall:.0%}")
print()
# -------------------------------------------------------------------
# PART 5: HOW TO INTERPRET THE RESULTS?
# -------------------------------------------------------------------
print("5. ЧТО ЭТО ЗНАЧИТ ДЛЯ НАШЕЙ ЗАДАЧИ?")
print("-" * 40)
# 0.75 is the cut-off this script uses for calling a metric "good".
if precision > 0.75:
    print("✅ ХОРОШО: Когда модель говорит 'это одно место',")
    # This line must be an f-string so the precision value is interpolated
    # instead of the raw "{precision:.0%}" placeholder being printed.
    print(f" ей можно доверять ({precision:.0%} случаев она права).")
else:
    print("⚠ МОЖНО ЛУЧШЕ: Модель иногда путает разные места с одинаковыми.")
if recall > 0.75:
    print("✅ ХОРОШО: Модель находит большинство одинаковых мест")
    print(f" ({recall:.0%} настоящих 'одинаковых' мест она находит).")
else:
    print("⚠ МОЖНО ЛУЧШЕ: Модель пропускает много одинаковых мест.")
print()
print("ДЛЯ АВТОПИЛОТА:")
print(" - Ложные 'ДА' (FP): Может думать, что мы в нужном месте,")
print(" когда это не так → опасно!")
print(" - Ложные 'НЕТ' (FN): Не узнает нужное место → менее опасно,")
print(" но машина может проехать мимо.")
print()
_OUTRO_BANNER = "=" * 70
_OUTRO_RULE = "-" * 40

# Every entry is printed on its own line; an empty string yields a blank line.
for _outro_text in (
    # ---------------------------------------------------------------
    # PART 6: PLOTS (WHAT DO WE SEE?)
    # ---------------------------------------------------------------
    "6. КАКИЕ ГРАФИКИ МЫ ПОЛУЧАЕМ?",
    _OUTRO_RULE,
    "После обучения мы строим 4 основных графика:",
    "",
    "1. 📉 ГРАФИК ОШИБОК (Loss):",
    " - Синяя линия: ошибки на обучающих данных",
    " - Красная линия: ошибки на проверочных данных",
    " - ХОРОШО: обе линии идут вниз и близки друг к другу",
    " - ПЛОХО: линии далеко друг от друга (переобучение)",
    "",
    "2. 📈 ГРАФИК ТОЧНОСТИ (Accuracy):",
    " - Показывает, как растет точность со временем",
    " - Должен расти и стабилизироваться",
    "",
    "3. 🎯 МАТРИЦА ОШИБОК (Confusion Matrix):",
    " - Квадратная таблица 2x2",
    " - Показывает все 4 типа ответов (TP, FP, FN, TN)",
    " - Идеально: все числа на диагонали, нули вне диагонали",
    "",
    "4. 📊 ROC-КРИВАЯ:",
    " - Показывает, насколько хорошо модель отличает 'ДА' от 'НЕТ'",
    " - Чем больше площадь под кривой, тем лучше",
    " - Идеально: площадь = 1.0 (100%)",
    "",
    # ---------------------------------------------------------------
    # PART 7: WHAT TO DO NEXT?
    # ---------------------------------------------------------------
    "7. ЧТО ДЕЛАТЬ, ЕСЛИ РЕЗУЛЬТАТЫ ПЛОХИЕ?",
    _OUTRO_RULE,
    "Если точность меньше 70%:",
    "1. 🎯 ПРОБЛЕМА: Модель плохо учится",
    " РЕШЕНИЕ:",
    " - Учить дольше (увеличить количество эпох)",
    " - Изменить скорость обучения (learning rate)",
    " - Добавить больше данных для обучения",
    "",
    "2. 🎯 ПРОБЛЕМА: Модель запоминает, а не учится (переобучение)",
    " РЕШЕНИЕ:",
    " - Добавить регуляризацию (dropout)",
    " - Использовать augmentation (искажать картинки)",
    " - Упростить модель (меньше слоев)",
    "",
    "3. 🎯 ПРОБЛЕМА: Много ложных 'ДА' (FP)",
    " РЕШЕНИЕ:",
    " - Повысить порог принятия решения (например, 0.7 вместо 0.5)",
    " - Добавить больше примеров 'разных' мест",
    "",
    "4. 🎯 ПРОБЛЕМА: Много ложных 'НЕТ' (FN)",
    " РЕШЕНИЕ:",
    " - Понизить порог принятия решения (например, 0.3 вместо 0.5)",
    " - Добавить больше примеров 'одинаковых' мест",
    "",
    # ---------------------------------------------------------------
    # PART 8: PRACTICAL EXAMPLE
    # ---------------------------------------------------------------
    "8. ПРАКТИЧЕСКИЙ ПРИМЕР: КАК ИСПОЛЬЗОВАТЬ МОДЕЛЬ",
    _OUTRO_RULE,
    "После обучения модель можно использовать так:",
    "",
    "```python",
    "# 1. Загружаем обученную модель",
    "model = load_trained_model('best_model.pt')",
    "",
    "# 2. Берем две картинки",
    "google_img = load_image('google_map.png')",
    "yandex_img = load_image('yandex_map.png')",
    "",
    "# 3. Спрашиваем у модели",
    "similarity_score = model.predict(google_img, yandex_img)",
    "",
    "# 4. Интерпретируем результат",
    "if similarity_score > 0.5:",
    " print('✅ Это похоже на одно и то же место!')",
    "else:",
    " print('❌ Это разные места')",
    "```",
    "",
    "Порог 0.5 можно менять:",
    " - Порог 0.7: более строгая модель (меньше ложных 'ДА')",
    " - Порог 0.3: более добрая модель (меньше ложных 'НЕТ')",
    "",
    # ---------------------------------------------------------------
    # PART 9: CONCLUSION
    # ---------------------------------------------------------------
    "9. ЧТО МЫ УЗНАЛИ?",
    _OUTRO_RULE,
    "✅ Модель учится сравнивать картинки",
    "✅ Мы можем измерить, насколько она хороша",
    "✅ Есть разные метрики для разных целей",
    "✅ Графики помогают понять процесс обучения",
    "✅ Можно улучшить модель, если результаты плохие",
    "",
    _OUTRO_BANNER,
    "🎉 ВОТ И ВСЁ! ТЕПЕРЬ ТЫ ЗНАЕШЬ, КАК ОЦЕНИВАТЬ МОДЕЛЬ!",
    _OUTRO_BANNER,
    "",
    "Следующие шаги:",
    "1. Запусти evaluation.py чтобы увидеть реальные графики",
    "2. Посмотри на матрицу ошибок - какие ошибки чаще?",
    "3. Попробуй изменить порог принятия решений",
    "4. Если нужно - переобучи модель с другими параметрами",
):
    print(_outro_text)

View File

@@ -0,0 +1,917 @@
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# =============================================================================
# TRAINING LOOP WITH VISUALIZATION
# =============================================================================
class SimilarityTrainer:
    """Train, validate, log and checkpoint a pairwise image-similarity model.

    Each batch is expected to be a mapping with the keys ``'google_img'``,
    ``'yandex_img'`` and ``'same_domain'``; the two images are passed to the
    model and its output is compared with the label through ``SimilarityLoss``.

    Optional ``config`` keys and their defaults: ``learning_rate`` (2e-4),
    ``beta1`` (0.5), ``beta2`` (0.999), ``grad_clip`` (disabled),
    ``log_interval`` (10), ``output_dir`` ('runs/similarity'),
    ``save_interval`` (5), ``early_stopping_patience`` (20).
    """

    def __init__(
        self,
        model: nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move the model to ``device`` and build the loss, optimizer and history."""
        self.model = model.to(device)
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.config = config
        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get('learning_rate', 2e-4),
            betas=(config.get('beta1', 0.5), config.get('beta2', 0.999)),
        )
        # Created lazily by train(); stays None when train_epoch() is used on its own.
        self.writer = None
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        # Per-epoch metric history, one list entry per completed epoch.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_precision': [],
            'val_recall': [],
            'val_f1': [],
            'learning_rate': [],
        }

    @staticmethod
    def _average_metrics(all_metrics: list) -> dict:
        """Return the key-wise unweighted mean of a list of metric dicts ({} if empty)."""
        if not all_metrics:
            return {}
        return {
            key: sum(m[key] for m in all_metrics) / len(all_metrics)
            for key in all_metrics[0].keys()
        }

    def train_epoch(self, epoch: int) -> dict:
        """Run one training pass over ``trainloader``.

        Args:
            epoch: Epoch number, used for the progress-bar label and the
                TensorBoard global step.

        Returns:
            Dict with the sample-weighted mean ``'loss'`` plus the mean of the
            metrics sampled every ``log_interval`` batches.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        all_metrics = []
        log_interval = self.config.get('log_interval', 10)
        grad_clip = self.config.get('grad_clip', None)
        pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}')
        for batch_idx, batch in enumerate(pbar):
            google_img = batch['google_img'].to(self.device)
            yandex_img = batch['yandex_img'].to(self.device)
            # Add a trailing dimension to the label before it is compared with the model output.
            target = batch['same_domain'].float().to(self.device).unsqueeze(1)
            self.optimizer.zero_grad()
            # Forward pass
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            # Backward pass
            loss.backward()
            # Gradient clipping (skipped when 'grad_clip' is unset or falsy)
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
            self.optimizer.step()
            batch_size = google_img.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            # Metrics are only sampled every `log_interval` batches.
            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                all_metrics.append(metrics)
                pbar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'acc': f"{metrics['accuracy']:.4f}",
                })
                if self.writer:
                    global_step = epoch * len(self.trainloader) + batch_idx
                    self.writer.add_scalar('train/loss', loss.item(), global_step)
                    self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step)
        if total_samples == 0:
            raise ValueError('trainloader produced no samples; cannot compute the training loss')
        avg_loss = total_loss / total_samples
        return {'loss': avg_loss, **self._average_metrics(all_metrics)}

    def validate(self) -> dict:
        """Evaluate the model on ``valloader`` without gradient tracking.

        Returns:
            Dict with the sample-weighted mean ``'loss'``, the per-batch mean of
            every metric from ``compute_metrics``, and the concatenated CPU
            tensors ``'predictions'`` and ``'targets'``.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0
        all_metrics = []
        # Raw outputs are kept for the ROC curve and the confusion matrix.
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for batch in tqdm(self.valloader, desc='Validation'):
                google_img = batch['google_img'].to(self.device)
                yandex_img = batch['yandex_img'].to(self.device)
                target = batch['same_domain'].float().to(self.device).unsqueeze(1)
                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)
                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size
                all_metrics.append(self.criterion.compute_metrics(output, target))
                all_predictions.append(output.cpu())
                all_targets.append(target.cpu())
        if total_samples == 0:
            raise ValueError('valloader produced no samples; cannot validate')
        avg_loss = total_loss / total_samples
        # NOTE(review): metrics are averaged per batch (unweighted) while the loss
        # is weighted per sample - confirm this is the intended aggregation.
        avg_metrics = self._average_metrics(all_metrics)
        return {
            'loss': avg_loss,
            **avg_metrics,
            'predictions': torch.cat(all_predictions, dim=0),
            'targets': torch.cat(all_targets, dim=0),
        }

    def train(self, num_epochs: int):
        """Run the full training loop with logging, checkpointing and early stopping.

        Args:
            num_epochs: Maximum number of epochs to run.

        Returns:
            The accumulated ``history`` dict.
        """
        log_dir = self.config.get('output_dir', 'runs/similarity')
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        print(f'\n{"="*70}')
        print(f'Starting training for {num_epochs} epochs')
        print(f'Logging to {log_dir}')
        print(f'{"="*70}\n')
        patience = self.config.get('early_stopping_patience', 20)
        save_interval = self.config.get('save_interval', 5)
        start_time = time.time()
        try:
            for epoch in range(1, num_epochs + 1):
                epoch_start = time.time()
                print(f'\n--- Epoch {epoch}/{num_epochs} ---')
                # Train
                train_metrics = self.train_epoch(epoch)
                # Validate
                val_metrics = self.validate()
                # Store history
                self.history['train_loss'].append(train_metrics['loss'])
                self.history['val_loss'].append(val_metrics['loss'])
                for name in ('accuracy', 'precision', 'recall', 'f1'):
                    self.history[f'val_{name}'].append(val_metrics[name])
                self.history['learning_rate'].append(
                    self.optimizer.param_groups[0]['lr']
                )
                # Print metrics
                print(f'\nTrain Loss: {train_metrics["loss"]:.4f}')
                print(f'Val Loss: {val_metrics["loss"]:.4f}')
                print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')
                print(f'Val Precision: {val_metrics["precision"]:.4f}')
                print(f'Val Recall: {val_metrics["recall"]:.4f}')
                print(f'Val F1: {val_metrics["f1"]:.4f}')
                epoch_time = time.time() - epoch_start
                print(f'Epoch time: {epoch_time:.2f}s')
                # TensorBoard logging
                self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
                self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch)
                for name in ('accuracy', 'precision', 'recall', 'f1'):
                    self.writer.add_scalar(f'epoch/val_{name}', val_metrics[name], epoch)
                # Save checkpoint
                if val_metrics['loss'] < self.best_val_loss:
                    self.best_val_loss = val_metrics['loss']
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics['loss'], is_best=True)
                    print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}')
                else:
                    self.epochs_without_improvement += 1
                    if epoch % save_interval == 0:
                        self.save_checkpoint(epoch, val_metrics['loss'], is_best=False)
                # Early stopping
                if self.epochs_without_improvement >= patience:
                    print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement')
                    break
        finally:
            # Close the writer even when an epoch raises, so buffered events
            # are written out and the file handle is not leaked.
            self.writer.close()
        total_time = time.time() - start_time
        print(f'\n{"="*70}')
        print(f'Training completed in {total_time/60:.2f} minutes')
        print(f'Best validation loss: {self.best_val_loss:.4f}')
        print(f'{"="*70}\n')
        return self.history

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write ``checkpoint_epoch{epoch}.pt`` under ``<output_dir>/checkpoints``.

        When ``is_best`` is true the same payload is also written to
        ``best_model.pt`` in that directory.
        """
        checkpoint_dir = os.path.join(
            self.config.get('output_dir', 'runs/similarity'),
            'checkpoints',
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
            'history': self.history,
        }
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
        torch.save(checkpoint, checkpoint_path)
        if is_best:
            best_path = os.path.join(checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str) -> tuple:
        """Restore model, optimizer, history and best-loss tracking from a checkpoint.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'history' in checkpoint:
            self.history = checkpoint['history']
        # Restore the best validation loss seen so far; otherwise a resumed run
        # would treat its first epoch as "best" and overwrite best_model.pt.
        recorded_losses = checkpoint.get('history', {}).get('val_loss', [])
        self.best_val_loss = min([checkpoint['val_loss'], *recorded_losses])
        return checkpoint['epoch'], checkpoint['val_loss']
# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================
def plot_training_history(history: dict, save_path: str = None):
    """Draw a 2x3 grid of per-epoch training curves.

    Panels: train/validation loss, validation accuracy, F1, precision, recall
    and the learning rate (log scale).

    Args:
        history: Dict of equally long per-epoch lists keyed by 'train_loss',
            'val_loss', 'val_accuracy', 'val_f1', 'val_precision',
            'val_recall' and 'learning_rate'.
        save_path: If given, the figure is also written to this path.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training History - Siamese Network для корреляции снимков',
                 fontsize=16, fontweight='bold')
    epochs = range(1, len(history['train_loss']) + 1)

    # Loss panel: the only one with two curves and a legend.
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss Curves')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Validation metrics share one layout and a fixed [0, 1] y-range.
    bounded_panels = (
        (axes[0, 1], 'val_accuracy', 'g-', 'Accuracy', 'Validation Accuracy'),
        (axes[0, 2], 'val_f1', 'm-', 'F1 Score', 'Validation F1 Score'),
        (axes[1, 0], 'val_precision', 'c-', 'Precision', 'Validation Precision'),
        (axes[1, 1], 'val_recall', 'y-', 'Recall', 'Validation Recall'),
    )
    for panel, key, line_style, y_label, title in bounded_panels:
        panel.plot(epochs, history[key], line_style, linewidth=2)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.set_title(title)
        panel.grid(True, alpha=0.3)
        panel.set_ylim([0, 1])

    # Learning-rate panel uses a logarithmic y-axis.
    lr_ax = axes[1, 2]
    lr_ax.plot(epochs, history['learning_rate'], 'k-', linewidth=2)
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.grid(True, alpha=0.3)
    lr_ax.set_yscale('log')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Training history plot saved to {save_path}')
    plt.show()
def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor, save_path: str = None):
    """Plot the ROC curve of the predicted scores and return its AUC.

    Args:
        predictions: Tensor of predicted scores; flattened before use.
        targets: Tensor of ground-truth labels; flattened before use.
        save_path: If given, the figure is also written to this path.

    Returns:
        The area under the ROC curve.
    """
    from sklearn.metrics import roc_curve, auc

    scores = predictions.numpy().flatten()
    truth = targets.numpy().flatten()
    false_pos_rate, true_pos_rate, _ = roc_curve(truth, scores)
    roc_auc = auc(false_pos_rate, true_pos_rate)

    plt.figure(figsize=(10, 8))
    plt.plot(
        false_pos_rate,
        true_pos_rate,
        color='darkorange',
        lw=2,
        label=f'ROC curve (AUC = {roc_auc:.3f})',
    )
    # Diagonal reference line, labelled 'Random'.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=12)
    plt.grid(True, alpha=0.3)
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'ROC curve saved to {save_path}')
    plt.show()
    return roc_auc
def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor,
                          threshold: float = 0.5, save_path: str = None):
    """Plot the 2x2 confusion matrix of thresholded predictions.

    Args:
        predictions: Tensor of predicted scores; values >= ``threshold`` count
            as the positive class.
        targets: Tensor of ground-truth labels; values >= 0.5 count as positive.
        threshold: Decision threshold applied to ``predictions``.
        save_path: If given, the figure is also written to this path.
    """
    from sklearn.metrics import confusion_matrix

    predictions_binary = (predictions.numpy().flatten() >= threshold).astype(int)
    targets_binary = (targets.numpy().flatten() >= 0.5).astype(int)
    classes = ['Different Domains', 'Same Domain']
    # Pin both labels so the matrix stays 2x2 (matching the two tick labels)
    # even when only one class occurs in the data.
    cm = confusion_matrix(targets_binary, predictions_binary, labels=[0, 1])

    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
    plt.yticks(tick_marks, classes, fontsize=12)

    # Annotate every cell; switch to white text on dark cells for contrast.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=16, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Confusion matrix saved to {save_path}')
    plt.show()
def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor,
                                 save_path: str = None):
    """Plot and summarise predicted similarity scores split by ground-truth label.

    Draws overlaid histograms and a box plot for the pairs whose target is
    >= 0.5 and for the rest, then prints mean / std / min / max per group.

    Args:
        predictions: Tensor of predicted scores; flattened before use.
        targets: Tensor of ground-truth labels; flattened before use.
        save_path: If given, the figure is also written to this path.
    """
    predictions_np = predictions.numpy().flatten()
    targets_np = targets.numpy().flatten()
    same_domain = predictions_np[targets_np >= 0.5]
    diff_domain = predictions_np[targets_np < 0.5]

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.hist(same_domain, bins=50, alpha=0.7, color='green', edgecolor='black', label='Same Domain')
    plt.hist(diff_domain, bins=50, alpha=0.7, color='red', edgecolor='black', label='Different Domains')
    plt.xlabel('Predicted Similarity Score', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.subplot(1, 2, 2)
    plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same'])
    plt.ylabel('Similarity Score', fontsize=12)
    plt.xlabel('Domain Match', fontsize=12)
    plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Similarity distribution plot saved to {save_path}')
    plt.show()

    def _print_group_stats(header: str, scores) -> None:
        """Print summary statistics for one group, tolerating an empty group."""
        print(header)
        if scores.size == 0:
            # min()/max() raise on an empty array, so report the gap instead.
            print(' (no samples)')
            return
        print(f' Mean: {scores.mean():.4f}')
        print(f' Std: {scores.std():.4f}')
        print(f' Min: {scores.min():.4f}')
        print(f' Max: {scores.max():.4f}')

    # Print statistics
    print('\n--- Similarity Score Statistics ---')
    _print_group_stats('Same Domain:', same_domain)
    _print_group_stats('\nDifferent Domains:', diff_domain)
def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device,
                                 num_samples: int = 8, save_path: str = None):
    """Show randomly chosen image pairs next to the model's prediction.

    Each row shows the 'google_img', the 'yandex_img' and a text box with the
    predicted score, the predicted label (score >= 0.5) and the true label.

    Args:
        model: Model called as ``model(google_img, yandex_img)``.
        dataset: Indexable dataset whose items are mappings with the keys
            'google_img', 'yandex_img' and 'same_domain'.
        device: Device the images are moved to before the forward pass.
        num_samples: Number of pairs to draw; capped at ``len(dataset)``.
        save_path: If given, the figure is also written to this path.

    Raises:
        ValueError: If there is nothing to draw (empty dataset or
            ``num_samples`` < 1).
    """
    model.eval()
    # Sampling without replacement cannot return more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError('visualize_sample_predictions needs at least one sample to draw')
    # Get random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
    fig.suptitle('Sample Predictions - Siamese Network для корреляции карт',
                 fontsize=16, fontweight='bold')
    # Denormalization constants (assuming ImageNet normalization).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    with torch.no_grad():
        for idx, sample_idx in enumerate(indices):
            sample = dataset[sample_idx]
            google_img = sample['google_img'].unsqueeze(0).to(device)
            yandex_img = sample['yandex_img'].unsqueeze(0).to(device)
            true_label = sample['same_domain'].item()
            # Predict
            pred_similarity = model(google_img, yandex_img).item()
            pred_label = int(pred_similarity >= 0.5)
            # Move channels last for imshow, then undo the normalization.
            google_np = google_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            yandex_np = yandex_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            google_np = np.clip(std * google_np + mean, 0, 1)
            yandex_np = np.clip(std * yandex_np + mean, 0, 1)
            # Plot Google image
            axes[idx, 0].imshow(google_np)
            axes[idx, 0].set_title('Google Map', fontsize=12, fontweight='bold')
            axes[idx, 0].axis('off')
            # Plot Yandex image
            axes[idx, 1].imshow(yandex_np)
            axes[idx, 1].set_title('Yandex Map', fontsize=12, fontweight='bold')
            axes[idx, 1].axis('off')
            # Plot prediction info
            axes[idx, 2].axis('off')
            # Determine color based on correctness
            is_correct = (pred_label == true_label)
            color = 'green' if is_correct else 'red'
            result = '✓ Correct' if is_correct else '✗ Incorrect'
            info_text = (
                "\n"
                f"Prediction: {pred_similarity:.4f}\n"
                f"Predicted Label: {'Same' if pred_label == 1 else 'Different'}\n"
                f"True Label: {'Same' if true_label == 1 else 'Different'}\n"
                f"{result}\n"
            )
            axes[idx, 2].text(0.5, 0.5, info_text,
                              ha='center', va='center',
                              fontsize=12,
                              bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
                              transform=axes[idx, 2].transAxes)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Sample predictions saved to {save_path}')
    plt.show()
def visualize_feature_space(model: nn.Module, dataloader: DataLoader,
                            device: torch.device, max_samples: int = 500,
                            save_path: str = None):
    """Project extracted image features to 2-D with t-SNE and plot them.

    Features of the 'google_img' and 'yandex_img' of each pair are embedded
    together and drawn in two panels, coloured by the pair's 'same_domain' label.

    Args:
        model: Model exposing ``extract_features(images)``.
        dataloader: Yields mappings with the keys 'google_img', 'yandex_img'
            and 'same_domain'.
        device: Device the images are moved to before feature extraction.
        max_samples: Stop reading batches once at least this many pairs have
            been collected.
        save_path: If given, the figure is also written to this path.

    Raises:
        ValueError: If no pairs could be collected.
    """
    from sklearn.manifold import TSNE

    model.eval()
    all_features_google = []
    all_features_yandex = []
    all_labels = []
    collected = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Extracting features'):
            # Count the pairs actually seen instead of relying on
            # dataloader.batch_size, which is None with a custom batch sampler.
            if collected >= max_samples:
                break
            google_img = batch['google_img'].to(device)
            yandex_img = batch['yandex_img'].to(device)
            labels = batch['same_domain'].cpu().numpy().reshape(-1)
            # Extract features
            all_features_google.append(model.extract_features(google_img).cpu().numpy())
            all_features_yandex.append(model.extract_features(yandex_img).cpu().numpy())
            all_labels.append(labels)
            collected += len(labels)
    if collected == 0:
        raise ValueError('visualize_feature_space collected no samples from the dataloader')
    all_features_google = np.concatenate(all_features_google, axis=0)
    all_features_yandex = np.concatenate(all_features_yandex, axis=0)
    labels = np.concatenate(all_labels, axis=0)
    n_samples = len(labels)
    # Embed both sources jointly so the two panels share one coordinate system.
    all_features = np.concatenate([all_features_google, all_features_yandex], axis=0)
    print(f'\nApplying t-SNE to {all_features.shape[0]} samples...')
    # t-SNE requires perplexity < number of points; cap it for small inputs.
    perplexity = min(30, all_features.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    features_2d = tsne.fit_transform(all_features)
    # Split back into Google and Yandex
    panels = (
        (features_2d[:n_samples], 'Google Maps Features (t-SNE)'),
        (features_2d[n_samples:], 'Yandex Maps Features (t-SNE)'),
    )
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    for ax, (points, title) in zip(axes, panels):
        for label in [0, 1]:
            mask = labels == label
            ax.scatter(
                points[mask, 0],
                points[mask, 1],
                c='green' if label == 1 else 'red',
                label='Same Domain' if label == 1 else 'Different Domains',
                alpha=0.6,
                s=50,
            )
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('t-SNE Component 1', fontsize=12)
        ax.set_ylabel('t-SNE Component 2', fontsize=12)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Feature space visualization saved to {save_path}')
    plt.show()
def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader,
                                 device: torch.device, num_samples: int = 20,
                                 save_path: str = None):
    """Plot a heatmap of model scores for every Google/Yandex image combination.

    The first pair of each batch is collected (up to ``num_samples`` pairs);
    cell (i, j) holds the model output for Google image i and Yandex image j.
    Mean and std of the diagonal and off-diagonal cells are printed afterwards.

    Args:
        model: Model called as ``model(google_img, yandex_img)``.
        dataloader: Yields mappings with the keys 'google_img', 'yandex_img'
            and 'same_domain'.
        device: Device the images are moved to before the forward pass.
        num_samples: Maximum number of pairs to use.
        save_path: If given, the figure is also written to this path.

    Raises:
        ValueError: If no pairs could be collected.
    """
    model.eval()
    # Collect samples
    google_images = []
    yandex_images = []
    labels = []
    with torch.no_grad():
        for batch in dataloader:
            if len(google_images) >= num_samples:
                break
            # Only the first pair of each batch is used, so move just that one.
            google_images.append(batch['google_img'][:1].to(device))
            yandex_images.append(batch['yandex_img'][:1].to(device))
            labels.append(batch['same_domain'][:1].item())
    # The loader may hold fewer batches than requested; size everything by
    # what was actually collected so no empty slice is fed to the model.
    n_pairs = len(google_images)
    if n_pairs == 0:
        raise ValueError('generate_correlation_heatmap collected no samples from the dataloader')
    google_images = torch.cat(google_images, dim=0)
    yandex_images = torch.cat(yandex_images, dim=0)
    # Compute similarity matrix
    similarity_matrix = np.zeros((n_pairs, n_pairs))
    with torch.no_grad():
        for i in tqdm(range(n_pairs), desc='Computing correlations'):
            for j in range(n_pairs):
                google_i = google_images[i:i+1]
                yandex_j = yandex_images[j:j+1]
                similarity_matrix[i, j] = model(google_i, yandex_j).item()
    # Plot heatmap
    plt.figure(figsize=(14, 12))
    im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04)
    plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Yandex Map Index', fontsize=12)
    plt.ylabel('Google Map Index', fontsize=12)
    # Add grid
    plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
    # Add text annotations for diagonal (same pairs)
    # NOTE(review): the annotation text is an empty string, so nothing is
    # drawn - confirm whether a marker glyph was intended here.
    for i in range(min(n_pairs, 10)):  # Annotate first 10 for readability
        if labels[i] == 1:  # True match
            plt.text(i, i, '', ha='center', va='center',
                     color='white', fontsize=12, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Correlation heatmap saved to {save_path}')
    plt.show()
    # Print statistics
    diagonal = np.diag(similarity_matrix)
    off_diagonal = similarity_matrix[~np.eye(n_pairs, dtype=bool)]
    print(f'\n--- Correlation Statistics ---')
    print(f'Diagonal (matched pairs):')
    print(f' Mean: {diagonal.mean():.4f}')
    print(f' Std: {diagonal.std():.4f}')
    print(f'\nOff-diagonal (mismatched pairs):')
    print(f' Mean: {off_diagonal.mean():.4f}')
    print(f' Std: {off_diagonal.std():.4f}')
# =============================================================================
# MAIN TRAINING SCRIPT
# =============================================================================
def main():
    """Train the Siamese similarity model, then render every evaluation plot.

    Reads the module-level ``config``, builds the data loaders and the
    ResNet-18 backed model, runs the trainer, and finally writes the
    diagnostic figures into ``<output_dir>/visualizations``.
    """
    # Horizontal rules reused by every console banner below.
    top_rule = f'\n{"="*70}'
    rule = f'{"="*70}'

    # Work on a private copy of the shared config; a list-valued
    # image_size is normalised to a tuple.
    settings = config.copy()
    if isinstance(settings.get('image_size'), list):
        settings['image_size'] = tuple(settings['image_size'])

    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if has_cuda else 'cpu')

    print(top_rule)
    print(f'Siamese Network Training for Map Correlation')
    print(f'Обучение сиамской сети для корреляции снимков')
    print(rule)
    print(f'Using device: {device}')
    if has_cuda:
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')

    # ---- data -----------------------------------------------------------
    print(top_rule)
    print('Creating data loaders...')
    print(rule)
    loader_kwargs = dict(
        root_dir=settings['data_dir'],
        batch_size=settings['batch_size'],
        train_split=settings['train_split'],
        num_workers=settings['num_workers'],
        image_size=settings['image_size'],
        augment_train=True,
        augment_val=False,
        device=device,
    )
    train_loader, val_loader = create_data_loaders(**loader_kwargs)
    print(f'Train batches: {len(train_loader)}')
    print(f'Val batches: {len(val_loader)}')
    print(f'Train samples: {len(train_loader.dataset)}')
    print(f'Val samples: {len(val_loader.dataset)}')

    # ---- model ----------------------------------------------------------
    print(top_rule)
    print('Creating model...')
    print(rule)
    image_size = settings['image_size']
    if isinstance(image_size, (tuple, list)):
        # The factory takes a single size value: pass the first entry.
        input_size = image_size[0]
    else:
        input_size = image_size
    model = create_similarity_model(
        model_type='backbone',
        input_size=input_size,
        input_channels=3,
        backbone_name='resnet18',
        pretrained=True,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in param_info)
    trainable_params = sum(count for count, needs_grad in param_info if needs_grad)
    print(f'Total parameters: {total_params:,}')
    print(f'Trainable parameters: {trainable_params:,}')

    # ---- training -------------------------------------------------------
    trainer = SimilarityTrainer(
        model=model,
        trainloader=train_loader,
        valloader=val_loader,
        device=device,
        config=settings,
    )
    print(top_rule)
    print('Starting training...')
    print(rule)
    history = trainer.train(settings['epochs'])

    # ---- visualization and results --------------------------------------
    print(top_rule)
    print('Generating visualizations...')
    print(rule)
    output_dir = settings.get('output_dir', 'runs/similarity')
    vis_dir = os.path.join(output_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    def figure_path(filename):
        """Return the path of *filename* inside the visualization folder."""
        return os.path.join(vis_dir, filename)

    # 1. Training history
    print('\n1. Plotting training history...')
    plot_training_history(history, save_path=figure_path('training_history.png'))

    # 2. Collect model outputs and labels over the validation loader
    print('\n2. Computing validation predictions...')
    trainer.model.eval()
    prediction_chunks = []
    target_chunks = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            google_batch = batch['google_img'].to(device)
            yandex_batch = batch['yandex_img'].to(device)
            # Labels are only collected, never fed to the model, so they
            # are not moved to the device.
            target_batch = batch['same_domain'].float().unsqueeze(1)
            scores = trainer.model(google_batch, yandex_batch)
            prediction_chunks.append(scores.cpu())
            target_chunks.append(target_batch.cpu())
    val_predictions = torch.cat(prediction_chunks, dim=0)
    val_targets = torch.cat(target_chunks, dim=0)

    # 3. ROC curve
    print('\n3. Plotting ROC curve...')
    roc_auc = plot_roc_curve(
        val_predictions,
        val_targets,
        save_path=figure_path('roc_curve.png'),
    )
    print(f'ROC AUC Score: {roc_auc:.4f}')

    # 4. Confusion matrix
    print('\n4. Plotting confusion matrix...')
    plot_confusion_matrix(
        val_predictions,
        val_targets,
        threshold=0.5,
        save_path=figure_path('confusion_matrix.png'),
    )

    # 5. Similarity distribution
    print('\n5. Plotting similarity distribution...')
    plot_similarity_distribution(
        val_predictions,
        val_targets,
        save_path=figure_path('similarity_distribution.png'),
    )

    # 6. Sample predictions
    print('\n6. Visualizing sample predictions...')
    visualize_sample_predictions(
        trainer.model,
        val_loader.dataset,
        device,
        num_samples=8,
        save_path=figure_path('sample_predictions.png'),
    )

    # 7. Feature space visualization
    print('\n7. Visualizing feature space (t-SNE)...')
    visualize_feature_space(
        trainer.model,
        val_loader,
        device,
        max_samples=500,
        save_path=figure_path('feature_space_tsne.png'),
    )

    # 8. Correlation heatmap
    print('\n8. Generating correlation heatmap...')
    generate_correlation_heatmap(
        trainer.model,
        val_loader,
        device,
        num_samples=20,
        save_path=figure_path('correlation_heatmap.png'),
    )

    # ---- final results summary -------------------------------------------
    print(top_rule)
    print('FINAL RESULTS SUMMARY')
    print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ')
    print(rule)
    print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}')
    print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}')
    print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}')
    print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}')
    print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}')
    print(f'ROC AUC Score: {roc_auc:.4f}')
    print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}')
    print(f'Visualizations saved to: {vis_dir}')
    print(top_rule)
    print('Training and visualization completed successfully!')
    print('Обучение и визуализация завершены успешно!')
    print(f'{"="*70}\n')
# Run the training and visualization pipeline only when this file is executed
# as a script, not when it is imported.
if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,248 @@
"""
Training script for image similarity estimation.
"""
import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import config, create_data_loaders
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class SimilarityTrainer:
    """Training/validation loop for a two-input similarity model.

    Each batch is expected to be a mapping with the keys ``"google_img"``,
    ``"yandex_img"`` and ``"same_domain"``; the model is called as
    ``model(google_img, yandex_img)`` and scored with ``SimilarityLoss``.

    Recognised ``config`` keys (all optional): ``learning_rate``, ``beta1``,
    ``beta2``, ``log_interval``, ``output_dir``, ``early_stopping_patience``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        config: dict,
    ):
        """Move *model* to *device* and set up the loss and the Adam optimizer."""
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config
        self.criterion = SimilarityLoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 2e-4),
            betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
        )
        # TensorBoard writer; only created inside train().
        self.writer = None
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0

    def _unpack_batch(self, batch):
        """Return ``(google_img, yandex_img, target)`` moved to ``self.device``.

        The label is cast to float and given a trailing dimension so it lines
        up with the model output passed to the criterion.
        """
        google_img = batch["google_img"].to(self.device)
        yandex_img = batch["yandex_img"].to(self.device)
        target = batch["same_domain"].float().to(self.device).unsqueeze(1)
        return google_img, yandex_img, target

    @staticmethod
    def _average_metrics(batch_metrics, batch_sizes) -> dict:
        """Return the sample-weighted mean of per-batch metric dicts.

        Weighting by batch size keeps a short final batch from counting as
        much as a full one, matching how the loss is averaged. Returns an
        empty dict when there is nothing to average.
        """
        total = sum(batch_sizes)
        if not batch_metrics or total == 0:
            return {}
        return {
            key: sum(m[key] * n for m, n in zip(batch_metrics, batch_sizes)) / total
            for key in batch_metrics[0]
        }

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``{"loss": <sample-weighted mean training loss>}``.

        Raises:
            ValueError: if the training loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        log_interval = self.config.get("log_interval", 10)

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            google_img, yandex_img, target = self._unpack_batch(batch)

            self.optimizer.zero_grad()
            output = self.model(google_img, yandex_img)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            batch_size = google_img.size(0)
            loss_value = loss.item()
            total_loss += loss_value * batch_size
            total_samples += batch_size

            if batch_idx % log_interval == 0:
                metrics = self.criterion.compute_metrics(output, target)
                pbar.set_postfix({"loss": loss_value, "acc": metrics["accuracy"]})
                if self.writer is not None:
                    global_step = epoch * len(self.train_loader) + batch_idx
                    self.writer.add_scalar("train/loss", loss_value, global_step)
                    self.writer.add_scalar(
                        "train/accuracy", metrics["accuracy"], global_step
                    )

        if total_samples == 0:
            raise ValueError("train_loader produced no samples")
        return {"loss": total_loss / total_samples}

    def validate(self) -> dict:
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            A dict with ``"loss"`` plus every metric produced by
            ``criterion.compute_metrics``, each averaged over samples.

        Raises:
            ValueError: if the validation loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        batch_metrics = []
        batch_sizes = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                google_img, yandex_img, target = self._unpack_batch(batch)
                output = self.model(google_img, yandex_img)
                loss = self.criterion(output, target)

                batch_size = google_img.size(0)
                total_loss += loss.item() * batch_size
                batch_sizes.append(batch_size)
                batch_metrics.append(self.criterion.compute_metrics(output, target))

        total_samples = sum(batch_sizes)
        if total_samples == 0:
            raise ValueError("val_loader produced no samples")

        averaged = self._average_metrics(batch_metrics, batch_sizes)
        return {"loss": total_loss / total_samples, **averaged}

    def train(self, num_epochs: int) -> dict:
        """Train for up to *num_epochs* epochs with early stopping.

        A checkpoint is written after every epoch, and additionally as
        ``best_model.pt`` whenever the validation loss improves. Scalars are
        logged to TensorBoard under ``config["output_dir"]``.

        Returns:
            Per-epoch history: ``"train_loss"``, ``"val_loss"`` and one
            ``"val_<name>"`` list for each validation metric.
        """
        log_dir = self.config.get("output_dir", "runs/similarity")
        os.makedirs(log_dir, exist_ok=True)
        patience = self.config.get("early_stopping_patience", 20)
        history = {"train_loss": [], "val_loss": []}

        self.writer = SummaryWriter(log_dir)
        print(f"Starting training for {num_epochs} epochs")
        print(f"Logging to: {log_dir}")

        # try/finally so the event file is flushed and closed even when an
        # epoch raises or the run is interrupted from the keyboard.
        try:
            for epoch in range(1, num_epochs + 1):
                print(f"\nEpoch {epoch}/{num_epochs}")
                train_metrics = self.train_epoch(epoch)
                val_metrics = self.validate()

                history["train_loss"].append(train_metrics["loss"])
                for key, value in val_metrics.items():
                    history.setdefault(f"val_{key}", []).append(value)

                print(f"Train Loss: {train_metrics['loss']:.4f}")
                print(f"Val Loss: {val_metrics['loss']:.4f}")
                print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
                print(f"Val F1: {val_metrics['f1']:.4f}")

                self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
                self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
                self.writer.add_scalar(
                    "epoch/val_accuracy", val_metrics["accuracy"], epoch
                )

                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.epochs_without_improvement = 0
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
                    print(
                        f"New best model saved with val loss: {val_metrics['loss']:.4f}"
                    )
                else:
                    self.epochs_without_improvement += 1
                    self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)

                if self.epochs_without_improvement >= patience:
                    print(
                        f"Early stopping triggered after {patience} epochs without improvement"
                    )
                    break
        finally:
            self.writer.close()

        return history

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write model/optimizer state to ``<output_dir>/checkpoints``.

        Always writes ``checkpoint_epoch_<epoch>.pt``; when *is_best* is true
        the same payload is also written as ``best_model.pt``.
        """
        checkpoint_dir = os.path.join(
            self.config.get("output_dir", "runs/similarity"), "checkpoints"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": val_loss,
            "config": self.config,
        }
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore model and optimizer state from *checkpoint_path*.

        Returns:
            ``(epoch, val_loss)`` as stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"], checkpoint["val_loss"]
def main():
    """Train the CNN similarity model using the settings from dataloader.py."""
    settings = config.copy()
    # A list-valued image_size is normalised to a tuple.
    if isinstance(settings.get("image_size"), list):
        settings["image_size"] = tuple(settings["image_size"])

    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")
    print(f"Using device: {device}")

    print("Creating data loaders...")
    loader_kwargs = dict(
        root_dir=settings["data_dir"],
        batch_size=settings["batch_size"],
        train_split=settings["train_split"],
        num_workers=settings["num_workers"],
        image_size=settings["image_size"],
        augment_train=True,
        augment_val=False,
        device=device,
    )
    train_loader, val_loader = create_data_loaders(**loader_kwargs)
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    print("Creating model...")
    image_size = settings["image_size"]
    if isinstance(image_size, (tuple, list)):
        # The factory takes a single size value: pass the first entry.
        input_size = image_size[0]
    else:
        input_size = image_size
    model = create_similarity_model(
        model_type="cnn",
        input_size=input_size,
        input_channels=3,
        hidden_channels=64,
        num_blocks=4,
        dropout_rate=0.3,
        use_batch_norm=True,
    )
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    trainer = SimilarityTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=settings,
    )

    print("Starting training...")
    trainer.train(settings["epochs"])
    print("Training completed!")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()