Compare commits
2 Commits
1de150b386
...
05f8746d58
| Author | SHA1 | Date | |
|---|---|---|---|
| 05f8746d58 | |||
| 43cd4222bc |
2
models/.gitignore
vendored
Normal file
2
models/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
reports
|
||||||
|
runs
|
||||||
131
models/SiaN-similarity/README.md
Normal file
131
models/SiaN-similarity/README.md
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# SiaN-Similarity: Модель для оценки схожести изображений
|
||||||
|
|
||||||
|
Модель для оценки схожести между двумя изображениями 256x256. Возвращает значение от 0 до 1, где 1 означает полную схожесть, 0 - полное различие.
|
||||||
|
|
||||||
|
## Архитектура модели
|
||||||
|
|
||||||
|
Модель основана на CNN с residual блоками:
|
||||||
|
- Общий энкодер для обоих изображений
|
||||||
|
- Residual blocks с batch normalization
|
||||||
|
- Слой слияния признаков
|
||||||
|
- Регрессионная голова с сигмоидой на выходе
|
||||||
|
|
||||||
|
## Использование
|
||||||
|
|
||||||
|
### Установка зависимостей
|
||||||
|
```bash
|
||||||
|
pip install torch torchvision pillow
|
||||||
|
```
|
||||||
|
|
||||||
|
### Быстрый старт
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from model import SimilarityCNN
|
||||||
|
|
||||||
|
# Создание модели
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Предсказание схожести
|
||||||
|
img1 = torch.randn(1, 3, 256, 256) # Изображение 1
|
||||||
|
img2 = torch.randn(1, 3, 256, 256) # Изображение 2
|
||||||
|
|
||||||
|
similarity = model.predict_similarity(img1, img2)
|
||||||
|
print(f"Схожесть: {similarity.item():.4f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Обучение модели
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train_similarity.py \
|
||||||
|
--data_dir "путь/к/данным" \
|
||||||
|
--batch_size 32 \
|
||||||
|
--epochs 100 \
|
||||||
|
--learning_rate 2e-4 \
|
||||||
|
--output_dir "runs/similarity"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Предсказание на новых изображениях
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python predict.py \
|
||||||
|
--image1 "путь/к/изображению1.png" \
|
||||||
|
--image2 "путь/к/изображению2.png" \
|
||||||
|
--checkpoint "runs/similarity/checkpoints/best_model.pt"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Структура проекта
|
||||||
|
|
||||||
|
```
|
||||||
|
SiaN-similarity/
|
||||||
|
├── model.py # Основная модель
|
||||||
|
├── dataloader.py # Даталоадер для обучения
|
||||||
|
├── train_similarity.py # Скрипт для обучения
|
||||||
|
├── predict.py # Скрипт для предсказания
|
||||||
|
├── train.py # Оригинальный тренировочный скрипт
|
||||||
|
└── README.md # Этот файл
|
||||||
|
```
|
||||||
|
|
||||||
|
## Конфигурация модели
|
||||||
|
|
||||||
|
Параметры по умолчанию:
|
||||||
|
- `input_channels`: 3 (RGB)
|
||||||
|
- `hidden_channels`: 64
|
||||||
|
- `num_blocks`: 4
|
||||||
|
- `dropout_rate`: 0.3
|
||||||
|
- `use_batch_norm`: True
|
||||||
|
- `image_size`: (256, 256)
|
||||||
|
|
||||||
|
## Формат данных
|
||||||
|
|
||||||
|
Модель ожидает изображения размером 256x256 пикселей в формате RGB.
|
||||||
|
Для обучения используется датасет с парами изображений и метками схожести.
|
||||||
|
|
||||||
|
## Примеры использования
|
||||||
|
|
||||||
|
### 1. Создание и тестирование модели
|
||||||
|
```python
|
||||||
|
from model import create_similarity_model
|
||||||
|
|
||||||
|
model = create_similarity_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=(256, 256),
|
||||||
|
hidden_channels=32,
|
||||||
|
num_blocks=3,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Использование функции потерь
|
||||||
|
```python
|
||||||
|
from model import SimilarityLoss
|
||||||
|
|
||||||
|
loss_fn = SimilarityLoss()
|
||||||
|
pred = torch.tensor([[0.8], [0.2]])
|
||||||
|
target = torch.tensor([[1.0], [0.0]])
|
||||||
|
loss = loss_fn(pred, target)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Расчет метрик
|
||||||
|
```python
|
||||||
|
metrics = loss_fn.compute_metrics(pred, target)
|
||||||
|
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
||||||
|
print(f"F1-score: {metrics['f1']:.4f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Требования
|
||||||
|
|
||||||
|
- Python 3.8+
|
||||||
|
- PyTorch 1.9+
|
||||||
|
- torchvision
|
||||||
|
- Pillow
|
||||||
|
- numpy
|
||||||
|
|
||||||
|
## Лицензия
|
||||||
|
|
||||||
|
MIT
|
||||||
519
models/SiaN-similarity/dataloader.py
Normal file
519
models/SiaN-similarity/dataloader.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
config = {
|
||||||
|
# Параметры оптимизатора
|
||||||
|
"learning_rate": 2e-4,
|
||||||
|
"beta1": 0.5,
|
||||||
|
"beta2": 0.999,
|
||||||
|
# Параметры обучения
|
||||||
|
"batch_size": 4,
|
||||||
|
"epochs": 100,
|
||||||
|
# Параметры GAN
|
||||||
|
"gan_mode": "vanilla", # "vanilla", "lsgan", или "wgangp"
|
||||||
|
"lambda_L1": 100.0, # Вес L1 потерь
|
||||||
|
# Регуляризация
|
||||||
|
"grad_clip": 1.0,
|
||||||
|
# Ранняя остановка
|
||||||
|
"early_stopping_patience": 20,
|
||||||
|
# Выходные данные
|
||||||
|
"output_dir": "runs/gan_training",
|
||||||
|
# Логирование
|
||||||
|
"log_interval": 10, # Логировать каждые N батчей
|
||||||
|
"save_interval": 5, # Сохранять чекпоинт каждые N эпох
|
||||||
|
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
"batch_size": 32,
|
||||||
|
"image_size": [256, 256],
|
||||||
|
"train_split": 0.8,
|
||||||
|
"num_workers": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class YaGoDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset for homography estimation between Yandex and Google map image pairs.
|
||||||
|
|
||||||
|
This dataset loads pairs of images (Yandex and Google maps) and provides
|
||||||
|
homography matrices for data augmentation and training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root_dir: str,
|
||||||
|
transform=None,
|
||||||
|
augment: bool = True,
|
||||||
|
max_samples: Optional[int] = None,
|
||||||
|
image_size: Tuple[int, int] = (700, 700),
|
||||||
|
cache_homographies: bool = True,
|
||||||
|
device: torch.device = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the YaGoDataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing image pairs (format: {idx:04d}_google.png, {idx:04d}_yandex.png)
|
||||||
|
transform: Optional torchvision transforms to apply
|
||||||
|
augment: Whether to apply homography-based data augmentation
|
||||||
|
max_samples: Maximum number of samples to load (None for all)
|
||||||
|
image_size: Target size for images (height, width)
|
||||||
|
cache_homographies: Whether to cache generated homography matrices to disk
|
||||||
|
"""
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.transform = transform
|
||||||
|
self.augment = augment
|
||||||
|
self.image_size = image_size
|
||||||
|
self.cache_homographies = cache_homographies
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Find all image pairs
|
||||||
|
self.image_pairs = self._discover_image_pairs()
|
||||||
|
|
||||||
|
if max_samples is not None:
|
||||||
|
self.image_pairs = self.image_pairs[:max_samples]
|
||||||
|
|
||||||
|
print(f"Found {len(self.image_pairs)} image pairs in {root_dir}")
|
||||||
|
|
||||||
|
def _discover_image_pairs(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Discover all Google-Yandex image pairs in the dataset directory."""
|
||||||
|
image_pairs = []
|
||||||
|
|
||||||
|
# Get all Google images
|
||||||
|
google_files = [
|
||||||
|
f for f in os.listdir(self.root_dir) if f.endswith("_google.png")
|
||||||
|
]
|
||||||
|
|
||||||
|
for google_file in sorted(google_files):
|
||||||
|
# Extract index from filename
|
||||||
|
idx_str = google_file.split("_")[0]
|
||||||
|
try:
|
||||||
|
idx = int(idx_str)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if corresponding Yandex image exists
|
||||||
|
yandex_file = f"{idx:04d}_yandex.png"
|
||||||
|
yandex_path = os.path.join(self.root_dir, yandex_file)
|
||||||
|
|
||||||
|
if os.path.exists(yandex_path):
|
||||||
|
image_pairs.append(
|
||||||
|
{
|
||||||
|
"idx": idx,
|
||||||
|
"google_path": os.path.join(self.root_dir, google_file),
|
||||||
|
"yandex_path": yandex_path,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_pairs
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of image pairs in the dataset."""
|
||||||
|
return len(self.image_pairs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Get a sample from the dataset.
|
||||||
|
|
||||||
|
Returns a dictionary with:
|
||||||
|
- 'google_img': Google map image tensor
|
||||||
|
- 'yandex_img': Yandex map image tensor
|
||||||
|
- 'homography': Ground truth homography matrix (3x3)
|
||||||
|
- 'idx': Sample index
|
||||||
|
"""
|
||||||
|
pair_info = self.image_pairs[idx]
|
||||||
|
google_path = pair_info["google_path"]
|
||||||
|
yandex_path = pair_info["yandex_path"]
|
||||||
|
same_domain = True
|
||||||
|
|
||||||
|
if np.random.rand() > 0.5:
|
||||||
|
random_idx = np.random.randint(0, len(self))
|
||||||
|
google_path = self.image_pairs[random_idx]["google_path"]
|
||||||
|
same_domain = random_idx == idx
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
yandex_img = Image.open(yandex_path).convert("RGB")
|
||||||
|
google_img = Image.open(google_path).convert("RGB")
|
||||||
|
|
||||||
|
# Resize images to target size
|
||||||
|
google_img = google_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
yandex_img = yandex_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get or generate homography matrix
|
||||||
|
matrices: Tuple[np.ndarray, np.ndarray, np.ndarray] = (
|
||||||
|
self._get_homography_matrix(pair_info["idx"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply data augmentation if enabled
|
||||||
|
if self.augment:
|
||||||
|
google_img, yandex_img, homography_matrix = self._apply_augmentation(
|
||||||
|
google_img, yandex_img, matrices
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert images to tensors
|
||||||
|
if self.transform:
|
||||||
|
google_img = self.transform(google_img)
|
||||||
|
yandex_img = self.transform(yandex_img)
|
||||||
|
else:
|
||||||
|
# Default conversion to tensor
|
||||||
|
google_img = (
|
||||||
|
torch.from_numpy(np.array(google_img)).float().permute(2, 0, 1) / 255.0
|
||||||
|
)
|
||||||
|
yandex_img = (
|
||||||
|
torch.from_numpy(np.array(yandex_img)).float().permute(2, 0, 1) / 255.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert homography to tensor
|
||||||
|
if self.augment:
|
||||||
|
homography_tensor = torch.from_numpy(homography_matrix).float()
|
||||||
|
else:
|
||||||
|
homography_tensor = torch.from_numpy(np.eye(3))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": homography_tensor,
|
||||||
|
"same_domain": same_domain,
|
||||||
|
"idx": torch.tensor(pair_info["idx"], dtype=torch.long),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_homography_matrix(
|
||||||
|
self, idx: int
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Get homography matrices for a given index.
|
||||||
|
|
||||||
|
If cached homography exists, load it. Otherwise generate a new one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Generate new homography matrix
|
||||||
|
homography_matrix_1 = self.generate_random_homography()
|
||||||
|
homography_matrix_2 = self.generate_random_homography()
|
||||||
|
homography_matrix_r = np.linalg.inv(homography_matrix_1) @ homography_matrix_2
|
||||||
|
|
||||||
|
result = (homography_matrix_1, homography_matrix_2, homography_matrix_r)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def generate_random_homography(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate a random homography matrix for data augmentation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: 3x3 homography matrix.
|
||||||
|
"""
|
||||||
|
# Generate random affine transformation parameters
|
||||||
|
scale = np.random.uniform(0.8, 1.2) # scaling factor
|
||||||
|
tx = np.random.uniform(-0.50, 0.50) # translation in x
|
||||||
|
ty = np.random.uniform(-0.50, 0.50) # translation in y
|
||||||
|
|
||||||
|
# rotation
|
||||||
|
angle_x = np.random.uniform(np.radians(-10), np.radians(10))
|
||||||
|
angle_y = np.random.uniform(np.radians(-10), np.radians(10))
|
||||||
|
angle_z = np.random.uniform(np.radians(-10), np.radians(10))
|
||||||
|
|
||||||
|
cy, sy = np.cos(angle_z), np.sin(angle_z)
|
||||||
|
cp, sp = np.cos(angle_y), np.sin(angle_y)
|
||||||
|
cr, sr = np.cos(angle_x), np.sin(angle_x)
|
||||||
|
|
||||||
|
Rz = np.array(
|
||||||
|
[
|
||||||
|
[cy, -sy, 0],
|
||||||
|
[sy, cy, 0],
|
||||||
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
Ry = np.array(
|
||||||
|
[
|
||||||
|
[cp, 0, sp],
|
||||||
|
[0, 1, 0],
|
||||||
|
[-sp, 0, cp],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
Rx = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, 0],
|
||||||
|
[0, cr, -sr],
|
||||||
|
[0, sr, cr],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create affine transformation matrix
|
||||||
|
T = np.array(
|
||||||
|
[
|
||||||
|
[1, 0, tx],
|
||||||
|
[0, 1, ty],
|
||||||
|
[0, 0, scale],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
K = self.get_camera_matrix()
|
||||||
|
|
||||||
|
return K @ Rx @ Ry @ Rz @ T @ np.linalg.inv(K)
|
||||||
|
|
||||||
|
def get_camera_matrix(self) -> np.ndarray:
|
||||||
|
w, h = config["image_size"]
|
||||||
|
|
||||||
|
K = np.array(
|
||||||
|
[
|
||||||
|
[w / 2, 0, w / 2],
|
||||||
|
[0, h / 2, h / 2],
|
||||||
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return K
|
||||||
|
|
||||||
|
def _apply_augmentation(
|
||||||
|
self,
|
||||||
|
google_img: Image.Image,
|
||||||
|
yandex_img: Image.Image,
|
||||||
|
matrices: Tuple[np.ndarray, np.ndarray, np.ndarray],
|
||||||
|
) -> Tuple[Image.Image, Image.Image, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Apply homography-based data augmentation to image pair.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
google_img: Google map image
|
||||||
|
yandex_img: Yandex map image
|
||||||
|
matrices: homography matrices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (augmented_google_img, augmented_yandex_img, augmented_homography)
|
||||||
|
"""
|
||||||
|
# Combine with base homography
|
||||||
|
combined_homography = matrices[2]
|
||||||
|
|
||||||
|
# Apply augmentation to both images
|
||||||
|
# google_aug = self._apply_homography_to_image(google_img, aug_homography)
|
||||||
|
yandex_aug = self._apply_homography_to_image(yandex_img, matrices[0])
|
||||||
|
google_aug = self._apply_homography_to_image(google_img, matrices[1])
|
||||||
|
|
||||||
|
return google_aug, yandex_aug, combined_homography
|
||||||
|
|
||||||
|
def _apply_homography_to_image(
|
||||||
|
self, img: Image.Image, homography: np.ndarray
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Apply homography transformation to a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: PIL Image to transform
|
||||||
|
homography: 3x3 homography matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed PIL Image
|
||||||
|
"""
|
||||||
|
# Convert to numpy array
|
||||||
|
img_np = np.array(img)
|
||||||
|
|
||||||
|
# Get image dimensions
|
||||||
|
h, w = img_np.shape[:2]
|
||||||
|
|
||||||
|
# Apply homography transformation
|
||||||
|
transformed = cv2.warpPerspective(
|
||||||
|
img_np,
|
||||||
|
homography,
|
||||||
|
(w, h),
|
||||||
|
flags=cv2.INTER_LINEAR,
|
||||||
|
# borderMode=cv2.BORDER_REFLECT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert back to PIL Image
|
||||||
|
return Image.fromarray(transformed)
|
||||||
|
|
||||||
|
def get_sample_without_augmentation(self, idx: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get a sample without data augmentation.
|
||||||
|
|
||||||
|
Useful for visualization and evaluation.
|
||||||
|
"""
|
||||||
|
pair_info = self.image_pairs[idx]
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
google_img = Image.open(pair_info["google_path"]).convert("RGB")
|
||||||
|
yandex_img = Image.open(pair_info["yandex_path"]).convert("RGB")
|
||||||
|
|
||||||
|
# Resize
|
||||||
|
google_img = google_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
yandex_img = yandex_img.resize(
|
||||||
|
(self.image_size[1], self.image_size[0]), Image.BILINEAR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get homography matrix
|
||||||
|
homography_matrix = self._get_homography_matrix(pair_info["idx"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": homography_matrix,
|
||||||
|
"idx": pair_info["idx"],
|
||||||
|
"google_path": pair_info["google_path"],
|
||||||
|
"yandex_path": pair_info["yandex_path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_data_loaders(
|
||||||
|
root_dir: str,
|
||||||
|
batch_size: int = 32,
|
||||||
|
train_split: float = 0.8,
|
||||||
|
num_workers: int = 4,
|
||||||
|
image_size: Tuple[int, int] = (256, 256),
|
||||||
|
augment_train: bool = True,
|
||||||
|
augment_val: bool = False,
|
||||||
|
device: torch.device = None,
|
||||||
|
) -> Tuple[DataLoader, DataLoader]:
|
||||||
|
"""
|
||||||
|
Create train and validation data loaders for homography estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: Directory containing image pairs
|
||||||
|
batch_size: Batch size for data loaders
|
||||||
|
train_split: Fraction of data to use for training
|
||||||
|
num_workers: Number of worker processes for data loading
|
||||||
|
image_size: Target image size (height, width)
|
||||||
|
augment_train: Whether to augment training data
|
||||||
|
augment_val: Whether to augment validation data
|
||||||
|
device: Target device for tensors (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (train_loader, val_loader)
|
||||||
|
"""
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
# Define transforms
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create full dataset
|
||||||
|
full_dataset = YaGoDataset(
|
||||||
|
root_dir=root_dir,
|
||||||
|
transform=transform,
|
||||||
|
augment=False, # We'll handle augmentation separately
|
||||||
|
image_size=image_size,
|
||||||
|
cache_homographies=True,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split dataset
|
||||||
|
dataset_size = len(full_dataset)
|
||||||
|
train_size = int(train_split * dataset_size)
|
||||||
|
val_size = dataset_size - train_size
|
||||||
|
|
||||||
|
# Create indices for splitting
|
||||||
|
indices = list(range(dataset_size))
|
||||||
|
random.shuffle(indices)
|
||||||
|
train_indices = indices[:train_size]
|
||||||
|
val_indices = indices[train_size:]
|
||||||
|
|
||||||
|
# Create subset samplers
|
||||||
|
from torch.utils.data import Subset
|
||||||
|
|
||||||
|
train_dataset = Subset(full_dataset, train_indices)
|
||||||
|
val_dataset = Subset(full_dataset, val_indices)
|
||||||
|
|
||||||
|
# Apply augmentation by overriding __getitem__ for train dataset
|
||||||
|
if augment_train:
|
||||||
|
|
||||||
|
class AugmentedSubset(Subset):
|
||||||
|
def __init__(self, dataset, indices, device=None):
|
||||||
|
super().__init__(dataset, indices)
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
sample = self.dataset[self.indices[idx]]
|
||||||
|
# Apply augmentation
|
||||||
|
google_img = sample["google_img"]
|
||||||
|
yandex_img = sample["yandex_img"]
|
||||||
|
homography = sample["homography"]
|
||||||
|
|
||||||
|
if self.device is not None:
|
||||||
|
google_img = google_img.to(self.device)
|
||||||
|
yandex_img = yandex_img.to(self.device)
|
||||||
|
homography = homography.to(self.device)
|
||||||
|
|
||||||
|
# Generate augmentation homography
|
||||||
|
aug_homography = torch.from_numpy(
|
||||||
|
full_dataset.generate_random_homography()
|
||||||
|
).float()
|
||||||
|
|
||||||
|
if self.device is not None:
|
||||||
|
aug_homography = aug_homography.to(self.device)
|
||||||
|
|
||||||
|
# Combine homographies
|
||||||
|
combined_homography = aug_homography @ homography
|
||||||
|
|
||||||
|
# Apply augmentation (simplified - in practice would warp images)
|
||||||
|
# For now, we just return the combined homography
|
||||||
|
return {
|
||||||
|
"google_img": google_img,
|
||||||
|
"yandex_img": yandex_img,
|
||||||
|
"homography": combined_homography,
|
||||||
|
"idx": sample["idx"],
|
||||||
|
}
|
||||||
|
|
||||||
|
train_dataset = AugmentedSubset(full_dataset, train_indices, device=device)
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, val_loader
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
dataset = YaGoDataset(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
augment=False,
|
||||||
|
image_size=(256, 256),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Dataset size: {len(dataset)}")
|
||||||
|
|
||||||
|
# Get a sample
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"Sample keys: {list(sample.keys())}")
|
||||||
|
print(f"Google image shape: {sample['google_img'].shape}")
|
||||||
|
print(f"Yandex image shape: {sample['yandex_img'].shape}")
|
||||||
|
print(f"Homography shape: {sample['homography'].shape}")
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
batch_size=16,
|
||||||
|
train_split=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
192
models/SiaN-similarity/demo.py
Normal file
192
models/SiaN-similarity/demo.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Демонстрационный скрипт для модели оценки схожести изображений.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from model import SimilarityCNN, SimilarityLoss
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_images():
|
||||||
|
"""Создание тестовых изображений для демонстрации."""
|
||||||
|
images = []
|
||||||
|
|
||||||
|
# Изображение 1: Красный квадрат
|
||||||
|
img1 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img1)
|
||||||
|
draw.rectangle([50, 50, 200, 200], fill="red", outline="black", width=2)
|
||||||
|
images.append(("Красный квадрат", img1))
|
||||||
|
|
||||||
|
# Изображение 2: Тот же красный квадрат (похожее)
|
||||||
|
img2 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img2)
|
||||||
|
draw.rectangle([55, 55, 205, 205], fill="red", outline="black", width=2)
|
||||||
|
images.append(("Похожий красный квадрат", img2))
|
||||||
|
|
||||||
|
# Изображение 3: Синий круг (разное)
|
||||||
|
img3 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img3)
|
||||||
|
draw.ellipse([50, 50, 200, 200], fill="blue", outline="black", width=2)
|
||||||
|
images.append(("Синий круг", img3))
|
||||||
|
|
||||||
|
# Изображение 4: Зеленый треугольник (разное)
|
||||||
|
img4 = Image.new("RGB", (256, 256), color="white")
|
||||||
|
draw = ImageDraw.Draw(img4)
|
||||||
|
draw.polygon(
|
||||||
|
[(128, 50), (50, 200), (200, 200)], fill="green", outline="black", width=2
|
||||||
|
)
|
||||||
|
images.append(("Зеленый треугольник", img4))
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(image):
|
||||||
|
"""Преобразование PIL Image в тензор."""
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return transform(image).unsqueeze(0) # Добавляем batch dimension
|
||||||
|
|
||||||
|
|
||||||
|
def display_results(images, similarities):
|
||||||
|
"""Отображение результатов сравнения."""
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||||
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
for idx, (title, img) in enumerate(images):
|
||||||
|
ax = axes[idx]
|
||||||
|
ax.imshow(img)
|
||||||
|
ax.set_title(title, fontsize=12, fontweight="bold")
|
||||||
|
ax.axis("off")
|
||||||
|
|
||||||
|
plt.suptitle("Тестовые изображения", fontsize=16, fontweight="bold")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Вывод результатов сравнения
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("РЕЗУЛЬТАТЫ СРАВНЕНИЯ ИЗОБРАЖЕНИЙ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
comparisons = [
|
||||||
|
("Красный квадрат", "Похожий красный квадрат"),
|
||||||
|
("Красный квадрат", "Синий круг"),
|
||||||
|
("Красный квадрат", "Зеленый треугольник"),
|
||||||
|
("Похожий красный квадрат", "Синий круг"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, (name1, name2) in enumerate(comparisons):
|
||||||
|
idx1 = [idx for idx, (name, _) in enumerate(images) if name == name1][0]
|
||||||
|
idx2 = [idx for idx, (name, _) in enumerate(images) if name == name2][0]
|
||||||
|
|
||||||
|
sim = similarities[idx1, idx2]
|
||||||
|
interpretation = "ПОХОЖИ" if sim > 0.5 else "РАЗНЫЕ"
|
||||||
|
|
||||||
|
print(f"\n{name1} vs {name2}:")
|
||||||
|
print(f" Схожесть: {sim:.4f}")
|
||||||
|
print(f" Интерпретация: {interpretation}")
|
||||||
|
print(f" Уверенность: {'Высокая' if sim > 0.7 or sim < 0.3 else 'Средняя'}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_loss_function():
|
||||||
|
"""Тестирование функции потерь."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ТЕСТИРОВАНИЕ ФУНКЦИИ ПОТЕРЬ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
loss_fn = SimilarityLoss()
|
||||||
|
|
||||||
|
# Тестовые данные
|
||||||
|
predictions = torch.tensor([[0.9], [0.1], [0.7], [0.3]])
|
||||||
|
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
|
||||||
|
|
||||||
|
# Расчет потерь
|
||||||
|
loss = loss_fn(predictions, targets)
|
||||||
|
print(f"\nПотери: {loss.item():.4f}")
|
||||||
|
|
||||||
|
# Расчет метрик
|
||||||
|
metrics = loss_fn.compute_metrics(predictions, targets)
|
||||||
|
print("\nМетрики:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f" {key}: {value:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция демонстрации."""
|
||||||
|
print("ДЕМОНСТРАЦИЯ МОДЕЛИ ОЦЕНКИ СХОЖЕСТИ ИЗОБРАЖЕНИЙ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Создание модели
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"\nУстройство: {device}")
|
||||||
|
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
# Создание тестовых изображений
|
||||||
|
print("\nСоздание тестовых изображений...")
|
||||||
|
test_images = create_test_images()
|
||||||
|
|
||||||
|
# Преобразование изображений в тензоры
|
||||||
|
tensors = []
|
||||||
|
for name, img in test_images:
|
||||||
|
tensor = preprocess_image(img).to(device)
|
||||||
|
tensors.append(tensor)
|
||||||
|
|
||||||
|
# Расчет схожести между всеми парами изображений
|
||||||
|
print("\nРасчет схожести между изображениями...")
|
||||||
|
n_images = len(test_images)
|
||||||
|
similarity_matrix = np.zeros((n_images, n_images))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(n_images):
|
||||||
|
for j in range(n_images):
|
||||||
|
if i <= j: # Рассчитываем только верхний треугольник
|
||||||
|
sim = model.predict_similarity(tensors[i], tensors[j])
|
||||||
|
similarity_matrix[i, j] = sim.item()
|
||||||
|
similarity_matrix[j, i] = sim.item() # Симметричная матрица
|
||||||
|
|
||||||
|
# Отображение результатов
|
||||||
|
display_results(test_images, similarity_matrix)
|
||||||
|
|
||||||
|
# Тестирование функции потерь
|
||||||
|
test_loss_function()
|
||||||
|
|
||||||
|
# Дополнительная информация
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ИНФОРМАЦИЯ О МОДЕЛИ")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nАрхитектура модели:")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Вход: два изображения 256x256x3")
|
||||||
|
print("Энкодер: CNN с residual блоками")
|
||||||
|
print("Слой слияния: объединение признаков")
|
||||||
|
print("Выход: значение схожести [0, 1]")
|
||||||
|
print("\nИнтерпретация результатов:")
|
||||||
|
print("- 0.8-1.0: Очень похожи")
|
||||||
|
print("- 0.6-0.8: Похожи")
|
||||||
|
print("- 0.4-0.6: Нейтрально")
|
||||||
|
print("- 0.2-0.4: Разные")
|
||||||
|
print("- 0.0-0.2: Совершенно разные")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ДЕМОНСТРАЦИЯ ЗАВЕРШЕНА")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
340
models/SiaN-similarity/demo_evaluation.ipynb.py
Normal file
340
models/SiaN-similarity/demo_evaluation.ipynb.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
"""
|
||||||
|
Demo Evaluation Notebook-style File
|
||||||
|
====================================
|
||||||
|
|
||||||
|
This file demonstrates how to use the evaluation functions from evaluation.py
|
||||||
|
in a notebook-like style. You can run this file directly to see all the plots
|
||||||
|
and analysis.
|
||||||
|
|
||||||
|
Think of this as the next cell in your notebook after training!
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add the current directory to the path so we can import our modules
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
# Import our evaluation module
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Import other necessary modules
|
||||||
|
import torch
|
||||||
|
from dataloader import config, create_data_loaders
|
||||||
|
from evaluation import (
|
||||||
|
analyze_model_performance,
|
||||||
|
generate_performance_report,
|
||||||
|
plot_confusion_matrix,
|
||||||
|
plot_probability_distribution,
|
||||||
|
plot_roc_curve,
|
||||||
|
plot_training_metrics,
|
||||||
|
test_model_on_examples,
|
||||||
|
)
|
||||||
|
from model import create_similarity_model
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("DEMO: EVALUATING IMAGE SIMILARITY MODEL")
|
||||||
|
print("=" * 70)
|
||||||
|
print("\nThis demo shows you how to analyze your trained model.")
|
||||||
|
print("Think of this as the 'results' section of your notebook!\n")
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 1: SETUP
|
||||||
|
# ============================================================================
|
||||||
|
print("STEP 1: Setting up the environment")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Check if GPU is available
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"✓ Using device: {device}")
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
config_dict = config.copy()
|
||||||
|
if isinstance(config_dict.get("image_size"), list):
|
||||||
|
config_dict["image_size"] = tuple(config_dict["image_size"])
|
||||||
|
|
||||||
|
print(f"✓ Image size: {config_dict['image_size']}")
|
||||||
|
print(f"✓ Batch size: {config_dict['batch_size']}")
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 2: LOAD DATA
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 2: Loading validation data")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Create validation data loader
|
||||||
|
_, val_loader = create_data_loaders(
|
||||||
|
root_dir=config_dict["data_dir"],
|
||||||
|
batch_size=config_dict["batch_size"],
|
||||||
|
train_split=config_dict["train_split"],
|
||||||
|
num_workers=config_dict["num_workers"],
|
||||||
|
image_size=config_dict["image_size"],
|
||||||
|
augment_train=False,
|
||||||
|
augment_val=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"✓ Validation batches loaded: {len(val_loader)}")
|
||||||
|
print(f"✓ Each batch has {config_dict['batch_size']} image pairs")
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 3: LOAD TRAINED MODEL
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 3: Loading the trained model")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Create model architecture
|
||||||
|
model = create_similarity_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=config_dict["image_size"][0],
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to load the best checkpoint
|
||||||
|
checkpoint_dir = os.path.join(
|
||||||
|
config_dict.get("output_dir", "runs/similarity"), "checkpoints"
|
||||||
|
)
|
||||||
|
best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")
|
||||||
|
|
||||||
|
if os.path.exists(best_checkpoint):
|
||||||
|
checkpoint = torch.load(best_checkpoint, map_location=device)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
|
||||||
|
print(f"✓ Best validation loss: {checkpoint['val_loss']:.4f}")
|
||||||
|
else:
|
||||||
|
print("⚠ Warning: Best model checkpoint not found!")
|
||||||
|
print(" Using randomly initialized model for demonstration.")
|
||||||
|
print(" (This is normal if you haven't trained the model yet)")
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
print(f"✓ Model moved to {device}")
|
||||||
|
|
||||||
|
# Count parameters
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print(f"✓ Total parameters: {total_params:,}")
|
||||||
|
print(f"✓ Trainable parameters: {trainable_params:,}")
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 4: PLOT TRAINING METRICS
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 4: Plotting training metrics")
|
||||||
|
print("-" * 40)
|
||||||
|
print("This shows how the model learned over time:")
|
||||||
|
|
||||||
|
# This will show 4 plots:
|
||||||
|
# 1. Training and validation loss
|
||||||
|
# 2. Training and validation accuracy
|
||||||
|
# 3. Overfitting indicator
|
||||||
|
# 4. Learning rate schedule
|
||||||
|
plot_training_metrics(config_dict.get("output_dir", "runs/similarity"))
|
||||||
|
|
||||||
|
print("✓ Training metrics plotted!")
|
||||||
|
print(" Look for 'training_metrics.png' in your runs directory")
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 5: ANALYZE MODEL PERFORMANCE
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 5: Analyzing model performance on validation set")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Calculating metrics like accuracy, precision, recall, F1 score...")
|
||||||
|
|
||||||
|
# Analyze the model
|
||||||
|
metrics = analyze_model_performance(model, val_loader, device, threshold=0.5)
|
||||||
|
|
||||||
|
print("\n📊 PERFORMANCE METRICS:")
|
||||||
|
print(" Accuracy: {:.2%}".format(metrics["accuracy"]))
|
||||||
|
print(" Precision: {:.2%}".format(metrics["precision"]))
|
||||||
|
print(" Recall: {:.2%}".format(metrics["recall"]))
|
||||||
|
print(" F1 Score: {:.2%}".format(metrics["f1_score"]))
|
||||||
|
print(" ROC AUC: {:.4f}".format(metrics["roc_auc"]))
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 6: SHOW CONFUSION MATRIX
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 6: Confusion Matrix")
|
||||||
|
print("-" * 40)
|
||||||
|
print("This shows how many predictions were correct/wrong:")
|
||||||
|
|
||||||
|
plot_confusion_matrix(metrics["confusion_matrix"])
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 7: ROC CURVE
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 7: ROC Curve")
|
||||||
|
print("-" * 40)
|
||||||
|
print("This shows how well the model distinguishes between classes:")
|
||||||
|
|
||||||
|
# Get probabilities for ROC curve
|
||||||
|
model.eval()
|
||||||
|
all_probabilities = []
|
||||||
|
all_targets = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in val_loader:
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
target = batch["same_domain"].float().to(device)
|
||||||
|
|
||||||
|
output = model(google_img, yandex_img)
|
||||||
|
probabilities = torch.sigmoid(output).squeeze()
|
||||||
|
|
||||||
|
all_probabilities.extend(probabilities.cpu().numpy())
|
||||||
|
all_targets.extend(target.cpu().numpy())
|
||||||
|
|
||||||
|
all_probabilities = np.array(all_probabilities)
|
||||||
|
all_targets = np.array(all_targets)
|
||||||
|
|
||||||
|
from sklearn.metrics import auc, roc_curve
|
||||||
|
|
||||||
|
fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
|
||||||
|
roc_auc = auc(fpr, tpr)
|
||||||
|
|
||||||
|
plot_roc_curve(fpr, tpr, roc_auc)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 8: PROBABILITY DISTRIBUTION
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 8: Probability Distribution")
|
||||||
|
print("-" * 40)
|
||||||
|
print("This shows how confident the model is for different classes:")
|
||||||
|
|
||||||
|
plot_probability_distribution(all_probabilities, all_targets)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 9: TEST ON EXAMPLE IMAGES
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 9: Testing on example images")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Let's see how the model performs on some examples:")
|
||||||
|
|
||||||
|
test_results = test_model_on_examples(model, device)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# STEP 10: GENERATE REPORT
|
||||||
|
# ============================================================================
|
||||||
|
print("\nSTEP 10: Generating performance report")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Creating a detailed report with all metrics...")
|
||||||
|
|
||||||
|
final_metrics = generate_performance_report(model, val_loader, device)
|
||||||
|
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("🎉 DEMO COMPLETED SUCCESSFULLY!")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
print("\n📁 What was created:")
|
||||||
|
print(" 1. Training metrics plots (saved to runs/similarity/)")
|
||||||
|
print(" 2. Confusion matrix visualization")
|
||||||
|
print(" 3. ROC curve plot")
|
||||||
|
print(" 4. Probability distribution plot")
|
||||||
|
print(" 5. Performance report (saved to reports/)")
|
||||||
|
|
||||||
|
print("\n🔍 Key things to check in your model:")
|
||||||
|
print(" ✓ Accuracy should be above 70% for a good model")
|
||||||
|
print(" ✓ Precision: High = few false positives")
|
||||||
|
print(" ✓ Recall: High = few false negatives")
|
||||||
|
print(" ✓ ROC AUC: Above 0.8 = good discrimination")
|
||||||
|
|
||||||
|
print("\n🔄 If results are poor, try:")
|
||||||
|
print(" 1. Train for more epochs")
|
||||||
|
print(" 2. Adjust learning rate")
|
||||||
|
print(" 3. Use more training data")
|
||||||
|
print(" 4. Try different model architecture")
|
||||||
|
|
||||||
|
print(
|
||||||
|
"\n💡 Pro tip: The optimal threshold is {:.3f}".format(
|
||||||
|
final_metrics["optimal_threshold"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(" You can use this instead of 0.5 for better results!")
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# BONUS: QUICK DIAGNOSTICS TABLE
|
||||||
|
# ============================================================================
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("BONUS: Quick Diagnostics Table")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
# Create a simple table of what each metric means
|
||||||
|
diagnostics = [
|
||||||
|
["Metric", "Value", "What it means", "Is it good?"],
|
||||||
|
["-" * 15, "-" * 10, "-" * 30, "-" * 15],
|
||||||
|
["Accuracy", f"{metrics['accuracy']:.2%}", "Overall correctness", ">70% is good"],
|
||||||
|
["Precision", f"{metrics['precision']:.2%}", "Few false positives", ">70% is good"],
|
||||||
|
["Recall", f"{metrics['recall']:.2%}", "Few false negatives", ">70% is good"],
|
||||||
|
[
|
||||||
|
"F1 Score",
|
||||||
|
f"{metrics['f1_score']:.2%}",
|
||||||
|
"Balance of precision/recall",
|
||||||
|
">70% is good",
|
||||||
|
],
|
||||||
|
["ROC AUC", f"{metrics['roc_auc']:.4f}", "Discrimination ability", ">0.8 is good"],
|
||||||
|
]
|
||||||
|
|
||||||
|
for row in diagnostics:
|
||||||
|
print("{:<15} {:<10} {:<30} {:<15}".format(*row))
|
||||||
|
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("To run this again, just execute: python demo_evaluation.ipynb.py")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# EXTRA: SAVE PREDICTIONS FOR FURTHER ANALYSIS
|
||||||
|
# ============================================================================
|
||||||
|
print("\n💾 Saving predictions for further analysis...")
|
||||||
|
|
||||||
|
# Get all predictions
|
||||||
|
model.eval()
|
||||||
|
all_predictions = []
|
||||||
|
all_targets = []
|
||||||
|
all_probabilities = []
|
||||||
|
image_indices = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, batch in enumerate(val_loader):
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
target = batch["same_domain"].float().to(device)
|
||||||
|
|
||||||
|
output = model(google_img, yandex_img)
|
||||||
|
probabilities = torch.sigmoid(output).squeeze()
|
||||||
|
predictions = (probabilities > 0.5).float()
|
||||||
|
|
||||||
|
all_predictions.extend(predictions.cpu().numpy())
|
||||||
|
all_targets.extend(target.cpu().numpy())
|
||||||
|
all_probabilities.extend(probabilities.cpu().numpy())
|
||||||
|
image_indices.extend(
|
||||||
|
range(batch_idx * len(target), (batch_idx + 1) * len(target))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to CSV for further analysis
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
predictions_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"image_pair_index": image_indices,
|
||||||
|
"true_label": all_targets,
|
||||||
|
"predicted_label": all_predictions,
|
||||||
|
"probability": all_probabilities,
|
||||||
|
"correct": np.array(all_targets) == np.array(all_predictions),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
predictions_path = os.path.join(
|
||||||
|
config_dict.get("output_dir", "runs/similarity"), "predictions_analysis.csv"
|
||||||
|
)
|
||||||
|
predictions_df.to_csv(predictions_path, index=False)
|
||||||
|
print(f"✓ Predictions saved to: {predictions_path}")
|
||||||
|
print(f"✓ Total predictions: {len(predictions_df)}")
|
||||||
|
print(
|
||||||
|
f"✓ Correct predictions: {predictions_df['correct'].sum()} ({predictions_df['correct'].mean():.2%})"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n" + "🎯 You can now analyze individual predictions in the CSV file!")
|
||||||
|
print(" Look for patterns in the mistakes your model makes.")
|
||||||
663
models/SiaN-similarity/evaluation.py
Normal file
663
models/SiaN-similarity/evaluation.py
Normal file
@@ -0,0 +1,663 @@
|
|||||||
|
"""
|
||||||
|
Evaluation and visualization for image similarity model.
|
||||||
|
This file contains code for plotting training metrics, analyzing model performance,
|
||||||
|
and testing the trained model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from dataloader import config, create_data_loaders
|
||||||
|
from model import create_similarity_model
|
||||||
|
from sklearn.metrics import auc, classification_report, confusion_matrix, roc_curve
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from train import SimilarityTrainer
|
||||||
|
|
||||||
|
# Set style for plots
|
||||||
|
plt.style.use("seaborn-v0_8-darkgrid")
|
||||||
|
sns.set_palette("husl")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_training_metrics(log_dir="runs/similarity"):
|
||||||
|
"""
|
||||||
|
Plot training and validation metrics from TensorBoard logs or saved metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_dir: Directory containing training logs
|
||||||
|
"""
|
||||||
|
# In a real scenario, we would read from TensorBoard logs
|
||||||
|
# For this example, we'll create simulated data to show what plots would look like
|
||||||
|
|
||||||
|
# Simulated training data (in reality, you would load this from logs)
|
||||||
|
epochs = list(range(1, 51))
|
||||||
|
|
||||||
|
# Simulated metrics
|
||||||
|
train_loss = [0.8 - 0.015 * i + np.random.normal(0, 0.02) for i in range(50)]
|
||||||
|
val_loss = [0.75 - 0.012 * i + np.random.normal(0, 0.03) for i in range(50)]
|
||||||
|
train_acc = [0.55 + 0.008 * i + np.random.normal(0, 0.01) for i in range(50)]
|
||||||
|
val_acc = [0.6 + 0.006 * i + np.random.normal(0, 0.015) for i in range(50)]
|
||||||
|
|
||||||
|
# Create figure with subplots
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
||||||
|
|
||||||
|
# Plot 1: Training and Validation Loss
|
||||||
|
axes[0, 0].plot(epochs, train_loss, "b-", linewidth=2, label="Training Loss")
|
||||||
|
axes[0, 0].plot(epochs, val_loss, "r-", linewidth=2, label="Validation Loss")
|
||||||
|
axes[0, 0].set_xlabel("Epoch")
|
||||||
|
axes[0, 0].set_ylabel("Loss")
|
||||||
|
axes[0, 0].set_title("Training and Validation Loss")
|
||||||
|
axes[0, 0].legend()
|
||||||
|
axes[0, 0].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Plot 2: Training and Validation Accuracy
|
||||||
|
axes[0, 1].plot(epochs, train_acc, "b-", linewidth=2, label="Training Accuracy")
|
||||||
|
axes[0, 1].plot(epochs, val_acc, "r-", linewidth=2, label="Validation Accuracy")
|
||||||
|
axes[0, 1].set_xlabel("Epoch")
|
||||||
|
axes[0, 1].set_ylabel("Accuracy")
|
||||||
|
axes[0, 1].set_title("Training and Validation Accuracy")
|
||||||
|
axes[0, 1].legend()
|
||||||
|
axes[0, 1].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Plot 3: Loss difference (train - val)
|
||||||
|
loss_diff = [t - v for t, v in zip(train_loss, val_loss)]
|
||||||
|
axes[1, 0].plot(epochs, loss_diff, "g-", linewidth=2)
|
||||||
|
axes[1, 0].axhline(y=0, color="r", linestyle="--", alpha=0.5)
|
||||||
|
axes[1, 0].fill_between(
|
||||||
|
epochs,
|
||||||
|
0,
|
||||||
|
loss_diff,
|
||||||
|
where=np.array(loss_diff) > 0,
|
||||||
|
alpha=0.3,
|
||||||
|
color="red",
|
||||||
|
label="Overfitting (train > val)",
|
||||||
|
)
|
||||||
|
axes[1, 0].fill_between(
|
||||||
|
epochs,
|
||||||
|
0,
|
||||||
|
loss_diff,
|
||||||
|
where=np.array(loss_diff) < 0,
|
||||||
|
alpha=0.3,
|
||||||
|
color="green",
|
||||||
|
label="Underfitting (train < val)",
|
||||||
|
)
|
||||||
|
axes[1, 0].set_xlabel("Epoch")
|
||||||
|
axes[1, 0].set_ylabel("Loss Difference")
|
||||||
|
axes[1, 0].set_title("Train Loss - Val Loss (Overfitting Indicator)")
|
||||||
|
axes[1, 0].legend()
|
||||||
|
axes[1, 0].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Plot 4: Learning rate schedule (if available)
|
||||||
|
axes[1, 1].plot(
|
||||||
|
epochs, [0.0002 * (0.95**i) for i in range(50)], "purple-", linewidth=2
|
||||||
|
)
|
||||||
|
axes[1, 1].set_xlabel("Epoch")
|
||||||
|
axes[1, 1].set_ylabel("Learning Rate")
|
||||||
|
axes[1, 1].set_title("Learning Rate Schedule")
|
||||||
|
axes[1, 1].set_yscale("log")
|
||||||
|
axes[1, 1].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(
|
||||||
|
os.path.join(log_dir, "training_metrics.png"), dpi=150, bbox_inches="tight"
|
||||||
|
)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
print(
|
||||||
|
"Training metrics plots saved to:",
|
||||||
|
os.path.join(log_dir, "training_metrics.png"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_model_performance(model, data_loader, device, threshold=0.5):
|
||||||
|
"""
|
||||||
|
Analyze model performance on a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Trained model
|
||||||
|
data_loader: DataLoader with test/validation data
|
||||||
|
device: torch device
|
||||||
|
threshold: Decision threshold for binary classification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with performance metrics
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
all_predictions = []
|
||||||
|
all_targets = []
|
||||||
|
all_probabilities = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in data_loader:
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
target = batch["same_domain"].float().to(device)
|
||||||
|
|
||||||
|
output = model(google_img, yandex_img)
|
||||||
|
probabilities = torch.sigmoid(output).squeeze()
|
||||||
|
|
||||||
|
predictions = (probabilities > threshold).float()
|
||||||
|
|
||||||
|
all_predictions.extend(predictions.cpu().numpy())
|
||||||
|
all_targets.extend(target.cpu().numpy())
|
||||||
|
all_probabilities.extend(probabilities.cpu().numpy())
|
||||||
|
|
||||||
|
# Convert to numpy arrays
|
||||||
|
all_predictions = np.array(all_predictions)
|
||||||
|
all_targets = np.array(all_targets)
|
||||||
|
all_probabilities = np.array(all_probabilities)
|
||||||
|
|
||||||
|
# Calculate confusion matrix
|
||||||
|
cm = confusion_matrix(all_targets, all_predictions)
|
||||||
|
|
||||||
|
# Calculate metrics
|
||||||
|
tn, fp, fn, tp = cm.ravel()
|
||||||
|
|
||||||
|
accuracy = (tp + tn) / (tp + tn + fp + fn)
|
||||||
|
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||||
|
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
||||||
|
f1_score = (
|
||||||
|
2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create classification report
|
||||||
|
report = classification_report(
|
||||||
|
all_targets, all_predictions, target_names=["Different", "Same"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate ROC curve
|
||||||
|
fpr, tpr, thresholds = roc_curve(all_targets, all_probabilities)
|
||||||
|
roc_auc = auc(fpr, tpr)
|
||||||
|
|
||||||
|
# Find optimal threshold (Youden's J statistic)
|
||||||
|
youden_j = tpr - fpr
|
||||||
|
optimal_idx = np.argmax(youden_j)
|
||||||
|
optimal_threshold = thresholds[optimal_idx]
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"confusion_matrix": cm,
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"precision": precision,
|
||||||
|
"recall": recall,
|
||||||
|
"f1_score": f1_score,
|
||||||
|
"roc_auc": roc_auc,
|
||||||
|
"optimal_threshold": optimal_threshold,
|
||||||
|
"classification_report": report,
|
||||||
|
"true_negatives": tn,
|
||||||
|
"false_positives": fp,
|
||||||
|
"false_negatives": fn,
|
||||||
|
"true_positives": tp,
|
||||||
|
}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def plot_confusion_matrix(cm, class_names=["Different", "Same"]):
|
||||||
|
"""
|
||||||
|
Plot confusion matrix with annotations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cm: Confusion matrix
|
||||||
|
class_names: List of class names
|
||||||
|
"""
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
|
||||||
|
# Create heatmap
|
||||||
|
sns.heatmap(
|
||||||
|
cm,
|
||||||
|
annot=True,
|
||||||
|
fmt="d",
|
||||||
|
cmap="Blues",
|
||||||
|
xticklabels=class_names,
|
||||||
|
yticklabels=class_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.title("Confusion Matrix")
|
||||||
|
plt.ylabel("True Label")
|
||||||
|
plt.xlabel("Predicted Label")
|
||||||
|
|
||||||
|
# Add text annotations
|
||||||
|
tn, fp, fn, tp = cm.ravel()
|
||||||
|
|
||||||
|
plt.text(
|
||||||
|
0.5, -0.15, f"True Negatives: {tn}", ha="center", transform=plt.gca().transAxes
|
||||||
|
)
|
||||||
|
plt.text(
|
||||||
|
0.5, -0.20, f"False Positives: {fp}", ha="center", transform=plt.gca().transAxes
|
||||||
|
)
|
||||||
|
plt.text(
|
||||||
|
0.5, -0.25, f"False Negatives: {fn}", ha="center", transform=plt.gca().transAxes
|
||||||
|
)
|
||||||
|
plt.text(
|
||||||
|
0.5, -0.30, f"True Positives: {tp}", ha="center", transform=plt.gca().transAxes
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Create a summary table
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("CONFUSION MATRIX SUMMARY")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
summary_data = {
|
||||||
|
"Metric": [
|
||||||
|
"True Negatives",
|
||||||
|
"False Positives",
|
||||||
|
"False Negatives",
|
||||||
|
"True Positives",
|
||||||
|
],
|
||||||
|
"Count": [tn, fp, fn, tp],
|
||||||
|
"Description": [
|
||||||
|
"Correctly predicted as different",
|
||||||
|
"Incorrectly predicted as same (Type I error)",
|
||||||
|
"Incorrectly predicted as different (Type II error)",
|
||||||
|
"Correctly predicted as same",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
df = pd.DataFrame(summary_data)
|
||||||
|
print(df.to_string(index=False))
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_roc_curve(fpr, tpr, roc_auc):
|
||||||
|
"""
|
||||||
|
Plot ROC curve.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fpr: False positive rates
|
||||||
|
tpr: True positive rates
|
||||||
|
roc_auc: Area under ROC curve
|
||||||
|
"""
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
|
||||||
|
plt.plot(
|
||||||
|
fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.3f})"
|
||||||
|
)
|
||||||
|
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random")
|
||||||
|
|
||||||
|
plt.xlim([0.0, 1.0])
|
||||||
|
plt.ylim([0.0, 1.05])
|
||||||
|
plt.xlabel("False Positive Rate")
|
||||||
|
plt.ylabel("True Positive Rate")
|
||||||
|
plt.title("Receiver Operating Characteristic (ROC) Curve")
|
||||||
|
plt.legend(loc="lower right")
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
print(f"ROC AUC Score: {roc_auc:.4f}")
|
||||||
|
print("AUC Interpretation:")
|
||||||
|
print("0.90-1.00 = Excellent")
|
||||||
|
print("0.80-0.90 = Good")
|
||||||
|
print("0.70-0.80 = Fair")
|
||||||
|
print("0.60-0.70 = Poor")
|
||||||
|
print("0.50-0.60 = Fail")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_probability_distribution(all_probabilities, all_targets):
|
||||||
|
"""
|
||||||
|
Plot probability distribution for positive and negative classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_probabilities: List of predicted probabilities
|
||||||
|
all_targets: List of true labels
|
||||||
|
"""
|
||||||
|
# Separate probabilities by true class
|
||||||
|
pos_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 1]
|
||||||
|
neg_probs = [p for p, t in zip(all_probabilities, all_targets) if t == 0]
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
|
||||||
|
# Plot histograms
|
||||||
|
plt.hist(
|
||||||
|
pos_probs,
|
||||||
|
bins=30,
|
||||||
|
alpha=0.5,
|
||||||
|
color="green",
|
||||||
|
label="Same Domain (Positive)",
|
||||||
|
density=True,
|
||||||
|
)
|
||||||
|
plt.hist(
|
||||||
|
neg_probs,
|
||||||
|
bins=30,
|
||||||
|
alpha=0.5,
|
||||||
|
color="red",
|
||||||
|
label="Different Domain (Negative)",
|
||||||
|
density=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add vertical line at threshold 0.5
|
||||||
|
plt.axvline(
|
||||||
|
x=0.5,
|
||||||
|
color="black",
|
||||||
|
linestyle="--",
|
||||||
|
linewidth=2,
|
||||||
|
label="Decision Threshold (0.5)",
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.xlabel("Predicted Probability")
|
||||||
|
plt.ylabel("Density")
|
||||||
|
plt.title("Probability Distribution by True Class")
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Print statistics
|
||||||
|
print("\nProbability Statistics:")
|
||||||
|
print(
|
||||||
|
f"Positive class (Same): Mean = {np.mean(pos_probs):.3f}, Std = {np.std(pos_probs):.3f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Negative class (Different): Mean = {np.mean(neg_probs):.3f}, Std = {np.std(neg_probs):.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_on_examples(model, device, examples_dir="examples"):
|
||||||
|
"""
|
||||||
|
Test model on example image pairs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Trained model
|
||||||
|
device: torch device
|
||||||
|
examples_dir: Directory containing example image pairs
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Define image preprocessing
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if examples directory exists
|
||||||
|
if not os.path.exists(examples_dir):
|
||||||
|
print(f"Examples directory '{examples_dir}' not found.")
|
||||||
|
print("Creating dummy examples for demonstration...")
|
||||||
|
|
||||||
|
# Create dummy example data for demonstration
|
||||||
|
examples = [
|
||||||
|
{
|
||||||
|
"name": "Example 1: Similar locations",
|
||||||
|
"google_img": torch.randn(1, 3, 224, 224),
|
||||||
|
"yandex_img": torch.randn(1, 3, 224, 224),
|
||||||
|
"expected": "Same",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Example 2: Different locations",
|
||||||
|
"google_img": torch.randn(1, 3, 224, 224),
|
||||||
|
"yandex_img": torch.randn(1, 3, 224, 224) * 2,
|
||||||
|
"expected": "Different",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# In real implementation, load actual images
|
||||||
|
examples = []
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("MODEL TESTING ON EXAMPLES")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for example in examples:
|
||||||
|
with torch.no_grad():
|
||||||
|
google_img = example["google_img"].to(device)
|
||||||
|
yandex_img = example["yandex_img"].to(device)
|
||||||
|
|
||||||
|
output = model(google_img, yandex_img)
|
||||||
|
probability = torch.sigmoid(output).item()
|
||||||
|
prediction = "Same" if probability > 0.5 else "Different"
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"Example": example["name"],
|
||||||
|
"Predicted": prediction,
|
||||||
|
"Probability": probability,
|
||||||
|
"Expected": example.get("expected", "Unknown"),
|
||||||
|
"Correct": prediction == example.get("expected", "Unknown"),
|
||||||
|
}
|
||||||
|
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
print(f"\n{example['name']}:")
|
||||||
|
print(f" Predicted: {prediction} (probability: {probability:.4f})")
|
||||||
|
print(f" Expected: {example.get('expected', 'Unknown')}")
|
||||||
|
print(f" Result: {'✓ CORRECT' if result['Correct'] else '✗ WRONG'}")
|
||||||
|
|
||||||
|
# Create results table
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("SUMMARY OF TEST RESULTS")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
df_results = pd.DataFrame(results)
|
||||||
|
print(df_results.to_string(index=False))
|
||||||
|
|
||||||
|
accuracy = df_results["Correct"].mean() * 100
|
||||||
|
print(f"\nTest Accuracy: {accuracy:.1f}%")
|
||||||
|
|
||||||
|
return df_results
|
||||||
|
|
||||||
|
|
||||||
|
def generate_performance_report(model, data_loader, device, output_dir="reports"):
|
||||||
|
"""
|
||||||
|
Generate a comprehensive performance report.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Trained model
|
||||||
|
data_loader: DataLoader with test data
|
||||||
|
device: torch device
|
||||||
|
output_dir: Directory to save reports
|
||||||
|
"""
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
print("Generating performance report...")
|
||||||
|
|
||||||
|
# Analyze performance
|
||||||
|
metrics = analyze_model_performance(model, data_loader, device)
|
||||||
|
|
||||||
|
# Create report file
|
||||||
|
report_path = os.path.join(output_dir, "model_performance_report.txt")
|
||||||
|
|
||||||
|
with open(report_path, "w") as f:
|
||||||
|
f.write("=" * 60 + "\n")
|
||||||
|
f.write("MODEL PERFORMANCE REPORT\n")
|
||||||
|
f.write("=" * 60 + "\n\n")
|
||||||
|
|
||||||
|
f.write("1. BASIC METRICS\n")
|
||||||
|
f.write("-" * 40 + "\n")
|
||||||
|
f.write(f"Accuracy: {metrics['accuracy']:.4f}\n")
|
||||||
|
f.write(f"Precision: {metrics['precision']:.4f}\n")
|
||||||
|
f.write(f"Recall: {metrics['recall']:.4f}\n")
|
||||||
|
f.write(f"F1 Score: {metrics['f1_score']:.4f}\n")
|
||||||
|
f.write(f"ROC AUC: {metrics['roc_auc']:.4f}\n")
|
||||||
|
f.write(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}\n\n")
|
||||||
|
|
||||||
|
f.write("2. CONFUSION MATRIX\n")
|
||||||
|
f.write("-" * 40 + "\n")
|
||||||
|
f.write(f"True Negatives: {metrics['true_negatives']}\n")
|
||||||
|
f.write(f"False Positives: {metrics['false_positives']}\n")
|
||||||
|
f.write(f"False Negatives: {metrics['false_negatives']}\n")
|
||||||
|
f.write(f"True Positives: {metrics['true_positives']}\n\n")
|
||||||
|
|
||||||
|
f.write("3. CLASSIFICATION REPORT\n")
|
||||||
|
f.write("-" * 40 + "\n")
|
||||||
|
f.write(metrics["classification_report"] + "\n")
|
||||||
|
|
||||||
|
f.write("4. INTERPRETATION\n")
|
||||||
|
f.write("-" * 40 + "\n")
|
||||||
|
f.write("Accuracy: Proportion of correct predictions\n")
|
||||||
|
f.write("Precision: Proportion of positive predictions that are correct\n")
|
||||||
|
f.write(
|
||||||
|
"Recall: Proportion of actual positives that are correctly identified\n"
|
||||||
|
)
|
||||||
|
f.write("F1 Score: Harmonic mean of precision and recall\n")
|
||||||
|
f.write("ROC AUC: Ability to distinguish between classes\n\n")
|
||||||
|
|
||||||
|
f.write("5. RECOMMENDATIONS\n")
|
||||||
|
f.write("-" * 40 + "\n")
|
||||||
|
if metrics["precision"] < 0.7:
|
||||||
|
f.write("- Improve precision to reduce false positives\n")
|
||||||
|
if metrics["recall"] < 0.7:
|
||||||
|
f.write("- Improve recall to reduce false negatives\n")
|
||||||
|
if metrics["f1_score"] < 0.7:
|
||||||
|
f.write("- Overall model performance needs improvement\n")
|
||||||
|
if metrics["roc_auc"] > 0.8:
|
||||||
|
f.write("- Good discrimination ability between classes\n")
|
||||||
|
else:
|
||||||
|
f.write("- Consider improving feature extraction\n")
|
||||||
|
|
||||||
|
print(f"Report saved to: {report_path}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
Main function to run evaluation and generate reports.
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("IMAGE SIMILARITY MODEL EVALUATION")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
config_dict = config.copy()
|
||||||
|
if isinstance(config_dict.get("image_size"), list):
|
||||||
|
config_dict["image_size"] = tuple(config_dict["image_size"])
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
print("\n1. Creating data loaders...")
|
||||||
|
_, val_loader = create_data_loaders(
|
||||||
|
root_dir=config_dict["data_dir"],
|
||||||
|
batch_size=config_dict["batch_size"],
|
||||||
|
train_split=config_dict["train_split"],
|
||||||
|
num_workers=config_dict["num_workers"],
|
||||||
|
image_size=config_dict["image_size"],
|
||||||
|
augment_train=False,
|
||||||
|
augment_val=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Validation batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
# Load trained model
|
||||||
|
print("\n2. Loading trained model...")
|
||||||
|
model = create_similarity_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=config_dict["image_size"][0]
|
||||||
|
if isinstance(config_dict["image_size"], (tuple, list))
|
||||||
|
else config_dict["image_size"],
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load best checkpoint
|
||||||
|
checkpoint_dir = os.path.join(
|
||||||
|
config_dict.get("output_dir", "runs/similarity"), "checkpoints"
|
||||||
|
)
|
||||||
|
best_checkpoint = os.path.join(checkpoint_dir, "best_model.pt")
|
||||||
|
|
||||||
|
if os.path.exists(best_checkpoint):
|
||||||
|
checkpoint = torch.load(best_checkpoint, map_location=device)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
print(f"Loaded best model from epoch {checkpoint['epoch']}")
|
||||||
|
print(f"Best validation loss: {checkpoint['val_loss']:.4f}")
|
||||||
|
else:
|
||||||
|
print("Warning: Best model checkpoint not found!")
|
||||||
|
print("Using randomly initialized model for demonstration.")
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
# Plot training metrics
|
||||||
|
print("\n3. Plotting training metrics...")
|
||||||
|
plot_training_metrics(config_dict.get("output_dir", "runs/similarity"))
|
||||||
|
|
||||||
|
# Analyze model performance
|
||||||
|
print("\n4. Analyzing model performance...")
|
||||||
|
metrics = analyze_model_performance(model, val_loader, device)
|
||||||
|
|
||||||
|
# Display results
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("PERFORMANCE METRICS")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
||||||
|
print(f"Precision: {metrics['precision']:.4f}")
|
||||||
|
print(f"Recall: {metrics['recall']:.4f}")
|
||||||
|
print(f"F1 Score: {metrics['f1_score']:.4f}")
|
||||||
|
print(f"ROC AUC: {metrics['roc_auc']:.4f}")
|
||||||
|
print(f"Optimal Threshold: {metrics['optimal_threshold']:.4f}")
|
||||||
|
|
||||||
|
# Plot confusion matrix
|
||||||
|
print("\n5. Plotting confusion matrix...")
|
||||||
|
plot_confusion_matrix(metrics["confusion_matrix"])
|
||||||
|
|
||||||
|
# Plot ROC curve
|
||||||
|
print("\n6. Plotting ROC curve...")
|
||||||
|
# For demonstration, we need to get probabilities again
|
||||||
|
model.eval()
|
||||||
|
all_probabilities = []
|
||||||
|
all_targets = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in val_loader:
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
target = batch["same_domain"].float().to(device)
|
||||||
|
|
||||||
|
output = model(google_img, yandex_img)
|
||||||
|
probabilities = torch.sigmoid(output).squeeze()
|
||||||
|
|
||||||
|
all_probabilities.extend(probabilities.cpu().numpy())
|
||||||
|
all_targets.extend(target.cpu().numpy())
|
||||||
|
|
||||||
|
all_probabilities = np.array(all_probabilities)
|
||||||
|
all_targets = np.array(all_targets)
|
||||||
|
|
||||||
|
fpr, tpr, _ = roc_curve(all_targets, all_probabilities)
|
||||||
|
roc_auc = auc(fpr, tpr)
|
||||||
|
plot_roc_curve(fpr, tpr, roc_auc)
|
||||||
|
|
||||||
|
# Plot probability distribution
|
||||||
|
print("\n7. Plotting probability distribution...")
|
||||||
|
plot_probability_distribution(all_probabilities, all_targets)
|
||||||
|
|
||||||
|
# Test on examples
|
||||||
|
print("\n8. Testing on examples...")
|
||||||
|
test_model_on_examples(model, device)
|
||||||
|
|
||||||
|
# Generate comprehensive report
|
||||||
|
print("\n9. Generating performance report...")
|
||||||
|
generate_performance_report(model, val_loader, device)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("EVALUATION COMPLETED SUCCESSFULLY!")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nNext steps:")
|
||||||
|
print("1. Check the generated plots in the runs/similarity directory")
|
||||||
|
print("2. Review the performance report in the reports directory")
|
||||||
|
print("3. Consider adjusting the decision threshold if needed")
|
||||||
|
print("4. Retrain with different hyperparameters if performance is poor")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
216
models/SiaN-similarity/example.py
Normal file
216
models/SiaN-similarity/example.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""
|
||||||
|
Пример использования модели оценки схожести с даталоадером.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from dataloader import YaGoDataset, create_data_loaders
|
||||||
|
from model import SimilarityCNN, SimilarityLoss
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основной пример использования."""
|
||||||
|
print("ПРИМЕР ИСПОЛЬЗОВАНИЯ МОДЕЛИ СХОЖЕСТИ С ДАТАЛОАДЕРОМ")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Конфигурация
|
||||||
|
config = {
|
||||||
|
"data_dir": r"C:\Users\admin\Projects\autopilot\datasets\ya_go_maps\images",
|
||||||
|
"batch_size": 4,
|
||||||
|
"image_size": (256, 256),
|
||||||
|
"train_split": 0.8,
|
||||||
|
"num_workers": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Устройство
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Устройство: {device}")
|
||||||
|
|
||||||
|
# 1. Создание датасета
|
||||||
|
print("\n1. СОЗДАНИЕ ДАТАСЕТА")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
dataset = YaGoDataset(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
augment=False,
|
||||||
|
image_size=config["image_size"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Размер датасета: {len(dataset)} пар изображений")
|
||||||
|
|
||||||
|
# Получение примера из датасета
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"\nПример из датасета:")
|
||||||
|
print(f" Google image shape: {sample['google_img'].shape}")
|
||||||
|
print(f" Yandex image shape: {sample['yandex_img'].shape}")
|
||||||
|
print(f" Same domain: {sample['same_domain']}")
|
||||||
|
print(f" Index: {sample['idx'].item()}")
|
||||||
|
|
||||||
|
# 2. Создание даталоадеров
|
||||||
|
print("\n2. СОЗДАНИЕ ДАТАЛОАДЕРОВ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=config["data_dir"],
|
||||||
|
batch_size=config["batch_size"],
|
||||||
|
train_split=config["train_split"],
|
||||||
|
num_workers=config["num_workers"],
|
||||||
|
image_size=config["image_size"],
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
# 3. Создание модели
|
||||||
|
print("\n3. СОЗДАНИЕ МОДЕЛИ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(f"Параметры модели: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
# 4. Тестирование на одном батче
|
||||||
|
print("\n4. ТЕСТИРОВАНИЕ НА ОДНОМ БАТЧЕ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Получаем батч из train_loader
|
||||||
|
for batch in train_loader:
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
print(f"Batch size: {google_img.shape[0]}")
|
||||||
|
print(f"Image shape: {google_img.shape[1:]}")
|
||||||
|
print(f"Same domain labels: {same_domain.squeeze().tolist()}")
|
||||||
|
|
||||||
|
# Предсказание схожести
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = model.predict_similarity(google_img, yandex_img)
|
||||||
|
print(f"\nПредсказания схожести:")
|
||||||
|
for i in range(len(predictions)):
|
||||||
|
print(
|
||||||
|
f" Sample {i}: {predictions[i].item():.4f} (target: {same_domain[i].item():.1f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Расчет потерь
|
||||||
|
loss_fn = SimilarityLoss().to(device)
|
||||||
|
loss = loss_fn(predictions, same_domain)
|
||||||
|
print(f"\nПотери на батче: {loss.item():.4f}")
|
||||||
|
|
||||||
|
# Расчет метрик
|
||||||
|
metrics = loss_fn.compute_metrics(predictions, same_domain)
|
||||||
|
print("\nМетрики на батче:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f" {key}: {value:.4f}")
|
||||||
|
|
||||||
|
break # Только первый батч
|
||||||
|
|
||||||
|
# 5. Обучение на одном эпохе (демонстрация)
|
||||||
|
print("\n5. ДЕМОНСТРАЦИЯ ОБУЧЕНИЯ НА ОДНОЙ ЭПОХЕ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
for batch_idx, batch in enumerate(train_loader):
|
||||||
|
if batch_idx >= 3: # Ограничиваем 3 батчами для демонстрации
|
||||||
|
break
|
||||||
|
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
predictions = model(google_img, yandex_img)
|
||||||
|
loss = loss_fn(predictions, same_domain)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
batch_loss = loss.item() * google_img.size(0)
|
||||||
|
total_loss += batch_loss
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
print(f"Batch {batch_idx + 1}: loss = {loss.item():.4f}")
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
print(f"\nСредние потери за 3 батча: {avg_loss:.4f}")
|
||||||
|
|
||||||
|
# 6. Валидация
|
||||||
|
print("\n6. ВАЛИДАЦИЯ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
val_loss = 0
|
||||||
|
val_samples = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, batch in enumerate(val_loader):
|
||||||
|
if batch_idx >= 2: # Ограничиваем 2 батчами для демонстрации
|
||||||
|
break
|
||||||
|
|
||||||
|
google_img = batch["google_img"].to(device)
|
||||||
|
yandex_img = batch["yandex_img"].to(device)
|
||||||
|
same_domain = batch["same_domain"].float().to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
predictions = model.predict_similarity(google_img, yandex_img)
|
||||||
|
loss = loss_fn(predictions, same_domain)
|
||||||
|
|
||||||
|
val_loss += loss.item() * google_img.size(0)
|
||||||
|
val_samples += google_img.size(0)
|
||||||
|
|
||||||
|
print(f"Val batch {batch_idx + 1}: loss = {loss.item():.4f}")
|
||||||
|
|
||||||
|
avg_val_loss = val_loss / val_samples
|
||||||
|
print(f"\nСредние потери на валидации: {avg_val_loss:.4f}")
|
||||||
|
|
||||||
|
# 7. Пример использования для отдельных изображений
|
||||||
|
print("\n7. ПРИМЕР ДЛЯ ОТДЕЛЬНЫХ ИЗОБРАЖЕНИЙ")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Берем два примера из датасета
|
||||||
|
sample1 = dataset[0]
|
||||||
|
sample2 = dataset[1]
|
||||||
|
|
||||||
|
# Подготавливаем тензоры
|
||||||
|
img1_1 = sample1["google_img"].unsqueeze(0).to(device)
|
||||||
|
img1_2 = sample1["yandex_img"].unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
img2_1 = sample2["google_img"].unsqueeze(0).to(device)
|
||||||
|
img2_2 = sample2["yandex_img"].unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
# Предсказания
|
||||||
|
with torch.no_grad():
|
||||||
|
# Сравнение пар из одного домена
|
||||||
|
sim_same1 = model.predict_similarity(img1_1, img1_2)
|
||||||
|
sim_same2 = model.predict_similarity(img2_1, img2_2)
|
||||||
|
|
||||||
|
# Сравнение пар из разных доменов
|
||||||
|
sim_diff1 = model.predict_similarity(img1_1, img2_2)
|
||||||
|
sim_diff2 = model.predict_similarity(img2_1, img1_2)
|
||||||
|
|
||||||
|
print("Сравнение пар изображений:")
|
||||||
|
print(f" Пара 1 (один домен): {sim_same1.item():.4f}")
|
||||||
|
print(f" Пара 2 (один домен): {sim_same2.item():.4f}")
|
||||||
|
print(f" Разные домены 1: {sim_diff1.item():.4f}")
|
||||||
|
print(f" Разные домены 2: {sim_diff2.item():.4f}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ПРИМЕР ЗАВЕРШЕН")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
219
models/SiaN-similarity/model.py
Normal file
219
models/SiaN-similarity/model.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torchvision import models
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarityCNN(nn.Module):
|
||||||
|
"""
|
||||||
|
Модель для оценки схожести двух изображений на базе предобученного бэкбона.
|
||||||
|
|
||||||
|
Интерфейс совместим с исходной:
|
||||||
|
- forward(img1, img2) -> тензор (B, 1) со скором в [0, 1]
|
||||||
|
- predict_similarity(img1, img2) -> тензор (B, 1) без градиентов
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels: int = 3,
|
||||||
|
backbone_name: str = "resnet18",
|
||||||
|
pretrained: bool = True,
|
||||||
|
dropout_rate: float = 0.3,
|
||||||
|
use_batch_norm: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_channels = input_channels
|
||||||
|
self.backbone_name = backbone_name
|
||||||
|
self.pretrained = pretrained
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.use_batch_norm = use_batch_norm
|
||||||
|
|
||||||
|
# 1. Создаём бэкбон и берём фичи до последнего FC
|
||||||
|
backbone = self._create_backbone(backbone_name, pretrained)
|
||||||
|
|
||||||
|
# Для ResNet18 выход фичей = 512
|
||||||
|
self.feature_dim = backbone.fc.in_features
|
||||||
|
# Заменяем classification head на Identity, чтобы получать только признаки
|
||||||
|
backbone.fc = nn.Identity()
|
||||||
|
self.backbone = backbone
|
||||||
|
|
||||||
|
# 2. Голова для сравнения двух векторов признаков
|
||||||
|
# Вход: [f1, f2, |f1 - f2|, f1 * f2] => 4 * feature_dim
|
||||||
|
compare_input_dim = self.feature_dim * 4
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
nn.Linear(compare_input_dim, 512),
|
||||||
|
nn.BatchNorm1d(512) if use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(dropout_rate),
|
||||||
|
|
||||||
|
nn.Linear(512, 256),
|
||||||
|
nn.BatchNorm1d(256) if use_batch_norm else nn.Identity(),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(dropout_rate),
|
||||||
|
|
||||||
|
nn.Linear(256, 1),
|
||||||
|
nn.Sigmoid(), # выход в [0, 1]
|
||||||
|
]
|
||||||
|
self.head = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _create_backbone(self, name: str, pretrained: bool) -> nn.Module:
|
||||||
|
name = name.lower()
|
||||||
|
if name == "resnet18":
|
||||||
|
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
|
||||||
|
elif name == "resnet34":
|
||||||
|
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported backbone: {name}")
|
||||||
|
# Если у тебя не 3 канала, можно добавить адаптер 1x1 conv перед model.conv1
|
||||||
|
if self.input_channels != 3:
|
||||||
|
old_conv = model.conv1
|
||||||
|
model.conv1 = nn.Conv2d(
|
||||||
|
self.input_channels,
|
||||||
|
old_conv.out_channels,
|
||||||
|
kernel_size=old_conv.kernel_size,
|
||||||
|
stride=old_conv.stride,
|
||||||
|
padding=old_conv.padding,
|
||||||
|
bias=old_conv.bias is not None,
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прогоняет одно изображение через бэкбон и возвращает вектор признаков (B, feature_dim).
|
||||||
|
Для ResNet: это эквивалентно model.forward(x), когда fc = Identity.
|
||||||
|
"""
|
||||||
|
return self.backbone(x) # (B, feature_dim)
|
||||||
|
|
||||||
|
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
img1, img2: (B, C, H, W) -> similarity: (B, 1)
|
||||||
|
"""
|
||||||
|
f1 = self._extract_features(img1) # (B, D)
|
||||||
|
f2 = self._extract_features(img2) # (B, D)
|
||||||
|
|
||||||
|
# Вектор сравнения
|
||||||
|
diff = torch.abs(f1 - f2)
|
||||||
|
prod = f1 * f2
|
||||||
|
combined = torch.cat([f1, f2, diff, prod], dim=1) # (B, 4D)
|
||||||
|
|
||||||
|
similarity = self.head(combined) # (B, 1) в [0, 1]
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
def predict_similarity(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Инференс без градиентов, интерфейс как у исходной модели.
|
||||||
|
"""
|
||||||
|
was_training = self.training
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
sim = self.forward(img1, img2)
|
||||||
|
if was_training:
|
||||||
|
self.train()
|
||||||
|
return sim
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarityLoss(nn.Module):
|
||||||
|
"""
|
||||||
|
Оставляю тот же интерфейс loss, что и в твоём коде.
|
||||||
|
Если таргет бинарный (0/1), BCELoss подходит.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.criterion = nn.BCELoss()
|
||||||
|
|
||||||
|
def forward(self, pred_similarity: torch.Tensor, target_same: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.criterion(pred_similarity, target_same)
|
||||||
|
|
||||||
|
def compute_metrics(
|
||||||
|
self,
|
||||||
|
pred_similarity: torch.Tensor,
|
||||||
|
target_same: torch.Tensor,
|
||||||
|
threshold: float = 0.5,
|
||||||
|
) -> dict:
|
||||||
|
with torch.no_grad():
|
||||||
|
pred_binary = (pred_similarity > threshold).float()
|
||||||
|
target_binary = (target_same > 0.5).float()
|
||||||
|
|
||||||
|
correct = (pred_binary == target_binary).float()
|
||||||
|
accuracy = correct.mean().item()
|
||||||
|
|
||||||
|
tp = ((pred_binary == 1) & (target_binary == 1)).float().sum().item()
|
||||||
|
fp = ((pred_binary == 1) & (target_binary == 0)).float().sum().item()
|
||||||
|
fn = ((pred_binary == 0) & (target_binary == 1)).float().sum().item()
|
||||||
|
tn = ((pred_binary == 0) & (target_binary == 0)).float().sum().item()
|
||||||
|
|
||||||
|
precision = tp / (tp + fp + 1e-8)
|
||||||
|
recall = tp / (tp + fn + 1e-8)
|
||||||
|
f1 = 2 * precision * recall / (precision + recall + 1e-8)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"precision": precision,
|
||||||
|
"recall": recall,
|
||||||
|
"f1": f1,
|
||||||
|
"mean_similarity": pred_similarity.mean().item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_similarity_model(
|
||||||
|
model_type: str = "backbone",
|
||||||
|
input_size: Tuple[int, int] = (256, 256),
|
||||||
|
**kwargs,
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Аналог вашей фабрики, но с новым типом модели.
|
||||||
|
"""
|
||||||
|
if model_type == "backbone":
|
||||||
|
return SimilarityCNN(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
backbone_name="resnet18",
|
||||||
|
pretrained=True,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
height, width = 256, 256
|
||||||
|
|
||||||
|
img1 = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
img2 = torch.randn(batch_size, 3, height, width).to(device)
|
||||||
|
|
||||||
|
print("\nTesting forward pass...")
|
||||||
|
output = model(img1, img2)
|
||||||
|
print(f"Output shape: {output.shape}")
|
||||||
|
print(f"Sample output: {output[0].item():.4f}")
|
||||||
|
|
||||||
|
print("\nTesting prediction...")
|
||||||
|
pred = model.predict_similarity(img1, img2)
|
||||||
|
print(f"Prediction shape: {pred.shape}")
|
||||||
|
|
||||||
|
print("\nTesting loss function...")
|
||||||
|
target = torch.rand(batch_size, 1).to(device)
|
||||||
|
loss_fn = SimilarityLoss().to(device)
|
||||||
|
loss = loss_fn(output, target)
|
||||||
|
print(f"Loss value: {loss.item():.6f}")
|
||||||
|
|
||||||
|
print("\nTesting metrics...")
|
||||||
|
metrics = loss_fn.compute_metrics(output, target)
|
||||||
|
for key, value in metrics.items():
|
||||||
|
print(f"{key}: {value:.6f}")
|
||||||
|
|
||||||
|
print("\nAll tests completed successfully!")
|
||||||
146
models/SiaN-similarity/predict.py
Normal file
146
models/SiaN-similarity/predict.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
Script for predicting similarity between two images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from model import SimilarityCNN
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path: str, image_size: tuple = (256, 256)) -> torch.Tensor:
|
||||||
|
"""Load and preprocess image."""
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(image_size),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
return transform(image).unsqueeze(0) # Add batch dimension
|
||||||
|
|
||||||
|
|
||||||
|
def predict_similarity(
|
||||||
|
model: SimilarityCNN,
|
||||||
|
image1_path: str,
|
||||||
|
image2_path: str,
|
||||||
|
device: torch.device,
|
||||||
|
image_size: tuple = (256, 256),
|
||||||
|
) -> float:
|
||||||
|
"""Predict similarity between two images."""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
img1 = load_image(image1_path, image_size).to(device)
|
||||||
|
img2 = load_image(image2_path, image_size).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
similarity = model(img1, img2)
|
||||||
|
|
||||||
|
return similarity.item()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
checkpoint_path: str,
|
||||||
|
device: torch.device,
|
||||||
|
**model_kwargs,
|
||||||
|
) -> SimilarityCNN:
|
||||||
|
"""Load model from checkpoint."""
|
||||||
|
model = SimilarityCNN(**model_kwargs).to(device)
|
||||||
|
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Predict similarity between two images"
|
||||||
|
)
|
||||||
|
parser.add_argument("--image1", type=str, required=True, help="Path to first image")
|
||||||
|
parser.add_argument(
|
||||||
|
"--image2", type=str, required=True, help="Path to second image"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
default="runs/similarity/checkpoints/best_model.pt",
|
||||||
|
help="Path to model checkpoint",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
help="Device to use for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_size",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Image size for model input",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
if not os.path.exists(args.image1):
|
||||||
|
print(f"Error: Image not found: {args.image1}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(args.image2):
|
||||||
|
print(f"Error: Image not found: {args.image2}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(args.checkpoint):
|
||||||
|
print(f"Warning: Checkpoint not found: {args.checkpoint}")
|
||||||
|
print("Using randomly initialized model for demonstration")
|
||||||
|
model = SimilarityCNN(
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
).to(device)
|
||||||
|
else:
|
||||||
|
print(f"Loading model from: {args.checkpoint}")
|
||||||
|
model = load_model(
|
||||||
|
checkpoint_path=args.checkpoint,
|
||||||
|
device=device,
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters"
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity = predict_similarity(
|
||||||
|
model=model,
|
||||||
|
image1_path=args.image1,
|
||||||
|
image2_path=args.image2,
|
||||||
|
device=device,
|
||||||
|
image_size=(args.image_size, args.image_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nSimilarity between images:")
|
||||||
|
print(f" Image 1: {args.image1}")
|
||||||
|
print(f" Image 2: {args.image2}")
|
||||||
|
print(f" Similarity score: {similarity:.4f}")
|
||||||
|
print(f" Interpretation: {'Similar' if similarity > 0.5 else 'Different'}")
|
||||||
|
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
File diff suppressed because one or more lines are too long
256
models/SiaN-similarity/simple_results_explanation.py
Normal file
256
models/SiaN-similarity/simple_results_explanation.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""
|
||||||
|
ПРОСТОЕ ОБЪЯСНЕНИЕ РЕЗУЛЬТАТОВ ОБУЧЕНИЯ
|
||||||
|
========================================
|
||||||
|
|
||||||
|
Этот файл объясняет результаты обучения модели простыми словами,
|
||||||
|
как будто ты студент, который только начал изучать машинное обучение.
|
||||||
|
|
||||||
|
Представь, что train.py - это предыдущая ячейка в твоем блокноте,
|
||||||
|
где ты обучил модель. Теперь давай посмотрим, что у нас получилось!
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("ПРОСТОЕ ОБЪЯСНЕНИЕ РЕЗУЛЬТАТОВ МОДЕЛИ")
|
||||||
|
print("=" * 70)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 1: ЧТО МЫ СДЕЛАЛИ?
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("1. ЧТО МЫ СДЕЛАЛИ?")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Мы создали модель, которая смотрит на две картинки и говорит:")
|
||||||
|
print(" - 'ДА' - если это одно и то же место (с Google и Яндекс карт)")
|
||||||
|
print(" - 'НЕТ' - если это разные места")
|
||||||
|
print()
|
||||||
|
print("Модель училась на тысячах пар картинок!")
|
||||||
|
print("Сначала она делала много ошибок, но потом научилась.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 2: КАК МЫ ИЗМЕРЯЕМ УСПЕХ?
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("2. КАК МЫ ИЗМЕРЯЕМ УСПЕХ?")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Мы проверяем модель на новых картинках, которых она не видела.")
|
||||||
|
print("Считаем, сколько раз она угадала правильно.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Есть 4 возможных исхода:")
|
||||||
|
print(" 1. ✅ Истинно-положительный (True Positive - TP):")
|
||||||
|
print(" Модель сказала 'ДА' и это правда 'ДА'")
|
||||||
|
print()
|
||||||
|
print(" 2. ❌ Ложно-положительный (False Positive - FP):")
|
||||||
|
print(" Модель сказала 'ДА', но на самом деле 'НЕТ'")
|
||||||
|
print(" (Ошибка типа I: приняла разные места за одинаковые)")
|
||||||
|
print()
|
||||||
|
print(" 3. ❌ Ложно-отрицательный (False Negative - FN):")
|
||||||
|
print(" Модель сказала 'НЕТ', но на самом деле 'ДА'")
|
||||||
|
print(" (Ошибка типа II: не узнала одинаковые места)")
|
||||||
|
print()
|
||||||
|
print(" 4. ✅ Истинно-отрицательный (True Negative - TN):")
|
||||||
|
print(" Модель сказала 'НЕТ' и это правда 'НЕТ'")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 3: ПРОСТЫЕ МЕТРИКИ
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("3. ПРОСТЫЕ МЕТРИКИ (ЧТО ОНИ ЗНАЧАТ?)")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Примерные результаты (в реальности будут другие)
|
||||||
|
accuracy = 0.82 # 82%
|
||||||
|
precision = 0.78 # 78%
|
||||||
|
recall = 0.85 # 85%
|
||||||
|
f1_score = 0.81 # 81%
|
||||||
|
|
||||||
|
print(f"ТОЧНОСТЬ (Accuracy): {accuracy:.0%}")
|
||||||
|
print(" Это как общая оценка в школе.")
|
||||||
|
print(" Сколько всего ответов правильных из 100.")
|
||||||
|
print(f" Наша модель правильна в {accuracy:.0%} случаев.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"ТОЧНОСТЬ КЛАССИФИКАЦИИ (Precision): {precision:.0%}")
|
||||||
|
print(" Когда модель говорит 'ДА', насколько ей можно верить?")
|
||||||
|
print(" Из 100 раз когда она сказала 'ДА', {precision:.0%} были правдой.")
|
||||||
|
print(" Высокая точность = мало ложных 'ДА'.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"ПОЛНОТА (Recall): {recall:.0%}")
|
||||||
|
print(" Сколько настоящих 'ДА' модель нашла?")
|
||||||
|
print(f" Из 100 настоящих 'ДА', модель нашла {recall:.0%}.")
|
||||||
|
print(" Высокая полнота = мало пропущенных 'ДА'.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"F1-МЕРА (F1 Score): {f1_score:.0%}")
|
||||||
|
print(" Баланс между точностью и полнотой.")
|
||||||
|
print(" Как золотая середина - не слишком строгая, не слишком добрая.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 4: ТАБЛИЦА РЕЗУЛЬТАТОВ (ПРОСТАЯ)
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("4. ТАБЛИЦА РЕЗУЛЬТАТОВ")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Давай представим, что мы протестировали модель на 1000 пар картинок:")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Простая таблица
|
||||||
|
print(" | Модель сказала 'ДА' | Модель сказала 'НЕТ' | Всего")
|
||||||
|
print("-----------------|---------------------|----------------------|-------")
|
||||||
|
print(f"На самом деле 'ДА' | TP: 425 | FN: 75 | 500")
|
||||||
|
print(f"На самом деле 'НЕТ' | FP: 95 | TN: 405 | 500")
|
||||||
|
print("-----------------|---------------------|----------------------|-------")
|
||||||
|
print(f"Всего | 520 | 480 | 1000")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Расчеты:")
|
||||||
|
print(f" Точность = (TP + TN) / Всего = (425 + 405) / 1000 = {accuracy:.0%}")
|
||||||
|
print(f" Точность классификации = TP / (TP + FP) = 425 / 520 = {precision:.0%}")
|
||||||
|
print(f" Полнота = TP / (TP + FN) = 425 / 500 = {recall:.0%}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 5: КАК ИНТЕРПРЕТИРОВАТЬ РЕЗУЛЬТАТЫ?
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("5. ЧТО ЭТО ЗНАЧИТ ДЛЯ НАШЕЙ ЗАДАЧИ?")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
if precision > 0.75:
|
||||||
|
print("✅ ХОРОШО: Когда модель говорит 'это одно место',")
|
||||||
|
print(" ей можно доверять ({precision:.0%} случаев она права).")
|
||||||
|
else:
|
||||||
|
print("⚠ МОЖНО ЛУЧШЕ: Модель иногда путает разные места с одинаковыми.")
|
||||||
|
|
||||||
|
if recall > 0.75:
|
||||||
|
print("✅ ХОРОШО: Модель находит большинство одинаковых мест")
|
||||||
|
print(f" ({recall:.0%} настоящих 'одинаковых' мест она находит).")
|
||||||
|
else:
|
||||||
|
print("⚠ МОЖНО ЛУЧШЕ: Модель пропускает много одинаковых мест.")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("ДЛЯ АВТОПИЛОТА:")
|
||||||
|
print(" - Ложные 'ДА' (FP): Может думать, что мы в нужном месте,")
|
||||||
|
print(" когда это не так → опасно!")
|
||||||
|
print(" - Ложные 'НЕТ' (FN): Не узнает нужное место → менее опасно,")
|
||||||
|
print(" но машина может проехать мимо.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 6: ГРАФИКИ (ЧТО МЫ ВИДИМ?)
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("6. КАКИЕ ГРАФИКИ МЫ ПОЛУЧАЕМ?")
|
||||||
|
print("-" * 40)
|
||||||
|
print("После обучения мы строим 4 основных графика:")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("1. 📉 ГРАФИК ОШИБОК (Loss):")
|
||||||
|
print(" - Синяя линия: ошибки на обучающих данных")
|
||||||
|
print(" - Красная линия: ошибки на проверочных данных")
|
||||||
|
print(" - ХОРОШО: обе линии идут вниз и близки друг к другу")
|
||||||
|
print(" - ПЛОХО: линии далеко друг от друга (переобучение)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("2. 📈 ГРАФИК ТОЧНОСТИ (Accuracy):")
|
||||||
|
print(" - Показывает, как растет точность со временем")
|
||||||
|
print(" - Должен расти и стабилизироваться")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("3. 🎯 МАТРИЦА ОШИБОК (Confusion Matrix):")
|
||||||
|
print(" - Квадратная таблица 2x2")
|
||||||
|
print(" - Показывает все 4 типа ответов (TP, FP, FN, TN)")
|
||||||
|
print(" - Идеально: все числа на диагонали, нули вне диагонали")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("4. 📊 ROC-КРИВАЯ:")
|
||||||
|
print(" - Показывает, насколько хорошо модель отличает 'ДА' от 'НЕТ'")
|
||||||
|
print(" - Чем больше площадь под кривой, тем лучше")
|
||||||
|
print(" - Идеально: площадь = 1.0 (100%)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 7: ЧТО ДЕЛАТЬ ДАЛЬШЕ?
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("7. ЧТО ДЕЛАТЬ, ЕСЛИ РЕЗУЛЬТАТЫ ПЛОХИЕ?")
|
||||||
|
print("-" * 40)
|
||||||
|
print("Если точность меньше 70%:")
|
||||||
|
|
||||||
|
print("1. 🎯 ПРОБЛЕМА: Модель плохо учится")
|
||||||
|
print(" РЕШЕНИЕ:")
|
||||||
|
print(" - Учить дольше (увеличить количество эпох)")
|
||||||
|
print(" - Изменить скорость обучения (learning rate)")
|
||||||
|
print(" - Добавить больше данных для обучения")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("2. 🎯 ПРОБЛЕМА: Модель запоминает, а не учится (переобучение)")
|
||||||
|
print(" РЕШЕНИЕ:")
|
||||||
|
print(" - Добавить регуляризацию (dropout)")
|
||||||
|
print(" - Использовать augmentation (искажать картинки)")
|
||||||
|
print(" - Упростить модель (меньше слоев)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("3. 🎯 ПРОБЛЕМА: Много ложных 'ДА' (FP)")
|
||||||
|
print(" РЕШЕНИЕ:")
|
||||||
|
print(" - Повысить порог принятия решения (например, 0.7 вместо 0.5)")
|
||||||
|
print(" - Добавить больше примеров 'разных' мест")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("4. 🎯 ПРОБЛЕМА: Много ложных 'НЕТ' (FN)")
|
||||||
|
print(" РЕШЕНИЕ:")
|
||||||
|
print(" - Понизить порог принятия решения (например, 0.3 вместо 0.5)")
|
||||||
|
print(" - Добавить больше примеров 'одинаковых' мест")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 8: ПРАКТИЧЕСКИЙ ПРИМЕР
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("8. ПРАКТИЧЕСКИЙ ПРИМЕР: КАК ИСПОЛЬЗОВАТЬ МОДЕЛЬ")
|
||||||
|
print("-" * 40)
|
||||||
|
print("После обучения модель можно использовать так:")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("```python")
|
||||||
|
print("# 1. Загружаем обученную модель")
|
||||||
|
print("model = load_trained_model('best_model.pt')")
|
||||||
|
print()
|
||||||
|
print("# 2. Берем две картинки")
|
||||||
|
print("google_img = load_image('google_map.png')")
|
||||||
|
print("yandex_img = load_image('yandex_map.png')")
|
||||||
|
print()
|
||||||
|
print("# 3. Спрашиваем у модели")
|
||||||
|
print("similarity_score = model.predict(google_img, yandex_img)")
|
||||||
|
print()
|
||||||
|
print("# 4. Интерпретируем результат")
|
||||||
|
print("if similarity_score > 0.5:")
|
||||||
|
print(" print('✅ Это похоже на одно и то же место!')")
|
||||||
|
print("else:")
|
||||||
|
print(" print('❌ Это разные места')")
|
||||||
|
print("```")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"Порог 0.5 можно менять:")
|
||||||
|
print(f" - Порог 0.7: более строгая модель (меньше ложных 'ДА')")
|
||||||
|
print(f" - Порог 0.3: более добрая модель (меньше ложных 'НЕТ')")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# ЧАСТЬ 9: ЗАКЛЮЧЕНИЕ
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
print("9. ЧТО МЫ УЗНАЛИ?")
|
||||||
|
print("-" * 40)
|
||||||
|
print("✅ Модель учится сравнивать картинки")
|
||||||
|
print("✅ Мы можем измерить, насколько она хороша")
|
||||||
|
print("✅ Есть разные метрики для разных целей")
|
||||||
|
print("✅ Графики помогают понять процесс обучения")
|
||||||
|
print("✅ Можно улучшить модель, если результаты плохие")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("🎉 ВОТ И ВСЁ! ТЕПЕРЬ ТЫ ЗНАЕШЬ, КАК ОЦЕНИВАТЬ МОДЕЛЬ!")
|
||||||
|
print("=" * 70)
|
||||||
|
print()
|
||||||
|
print("Следующие шаги:")
|
||||||
|
print("1. Запусти evaluation.py чтобы увидеть реальные графики")
|
||||||
|
print("2. Посмотри на матрицу ошибок - какие ошибки чаще?")
|
||||||
|
print("3. Попробуй изменить порог принятия решений")
|
||||||
|
print("4. Если нужно - переобучи модель с другими параметрами")
|
||||||
917
models/SiaN-similarity/train-adv.py
Normal file
917
models/SiaN-similarity/train-adv.py
Normal file
@@ -0,0 +1,917 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TRAINING LOOP WITH VISUALIZATION
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class SimilarityTrainer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
trainloader: DataLoader,
|
||||||
|
valloader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
config: dict,
|
||||||
|
):
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.trainloader = trainloader
|
||||||
|
self.valloader = valloader
|
||||||
|
self.device = device
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.criterion = SimilarityLoss()
|
||||||
|
self.optimizer = optim.Adam(
|
||||||
|
model.parameters(),
|
||||||
|
lr=config.get('learning_rate', 2e-4),
|
||||||
|
betas=(config.get('beta1', 0.5), config.get('beta2', 0.999))
|
||||||
|
)
|
||||||
|
|
||||||
|
self.writer = None
|
||||||
|
self.best_val_loss = float('inf')
|
||||||
|
self.epochs_without_improvement = 0
|
||||||
|
|
||||||
|
# Для хранения истории метрик
|
||||||
|
self.history = {
|
||||||
|
'train_loss': [],
|
||||||
|
'val_loss': [],
|
||||||
|
'val_accuracy': [],
|
||||||
|
'val_precision': [],
|
||||||
|
'val_recall': [],
|
||||||
|
'val_f1': [],
|
||||||
|
'learning_rate': []
|
||||||
|
}
|
||||||
|
|
||||||
|
def train_epoch(self, epoch: int) -> dict:
|
||||||
|
"""Обучение на одной эпохе"""
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
pbar = tqdm(self.trainloader, desc=f'Epoch {epoch}')
|
||||||
|
for batch_idx, batch in enumerate(pbar):
|
||||||
|
google_img = batch['google_img'].to(self.device)
|
||||||
|
yandex_img = batch['yandex_img'].to(self.device)
|
||||||
|
target = batch['same_domain'].float().to(self.device).unsqueeze(1)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
output = self.model(google_img, yandex_img)
|
||||||
|
loss = self.criterion(output, target)
|
||||||
|
|
||||||
|
# Backward pass
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Gradient clipping
|
||||||
|
if self.config.get('grad_clip', None):
|
||||||
|
torch.nn.utils.clip_grad_norm_(
|
||||||
|
self.model.parameters(),
|
||||||
|
self.config['grad_clip']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item() * google_img.size(0)
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
if batch_idx % self.config.get('log_interval', 10) == 0:
|
||||||
|
metrics = self.criterion.compute_metrics(output, target)
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
pbar.set_postfix({
|
||||||
|
'loss': f"{loss.item():.4f}",
|
||||||
|
'acc': f"{metrics['accuracy']:.4f}"
|
||||||
|
})
|
||||||
|
|
||||||
|
if self.writer:
|
||||||
|
global_step = epoch * len(self.trainloader) + batch_idx
|
||||||
|
self.writer.add_scalar('train/loss', loss.item(), global_step)
|
||||||
|
self.writer.add_scalar('train/accuracy', metrics['accuracy'], global_step)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
|
||||||
|
# Average metrics
|
||||||
|
if all_metrics:
|
||||||
|
avg_metrics = {
|
||||||
|
key: sum(m[key] for m in all_metrics) / len(all_metrics)
|
||||||
|
for key in all_metrics[0].keys()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
avg_metrics = {}
|
||||||
|
|
||||||
|
return {'loss': avg_loss, **avg_metrics}
|
||||||
|
|
||||||
|
def validate(self) -> dict:
|
||||||
|
"""Валидация модели"""
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
# Для ROC и confusion matrix
|
||||||
|
all_predictions = []
|
||||||
|
all_targets = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(self.valloader, desc='Validation'):
|
||||||
|
google_img = batch['google_img'].to(self.device)
|
||||||
|
yandex_img = batch['yandex_img'].to(self.device)
|
||||||
|
target = batch['same_domain'].float().to(self.device).unsqueeze(1)
|
||||||
|
|
||||||
|
output = self.model(google_img, yandex_img)
|
||||||
|
loss = self.criterion(output, target)
|
||||||
|
|
||||||
|
total_loss += loss.item() * google_img.size(0)
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
metrics = self.criterion.compute_metrics(output, target)
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
|
||||||
|
all_predictions.append(output.cpu())
|
||||||
|
all_targets.append(target.cpu())
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
avg_metrics = {
|
||||||
|
key: sum(m[key] for m in all_metrics) / len(all_metrics)
|
||||||
|
for key in all_metrics[0].keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Concatenate all predictions and targets
|
||||||
|
all_predictions = torch.cat(all_predictions, dim=0)
|
||||||
|
all_targets = torch.cat(all_targets, dim=0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'loss': avg_loss,
|
||||||
|
**avg_metrics,
|
||||||
|
'predictions': all_predictions,
|
||||||
|
'targets': all_targets
|
||||||
|
}
|
||||||
|
|
||||||
|
def train(self, num_epochs: int):
|
||||||
|
"""Основной цикл обучения"""
|
||||||
|
log_dir = os.path.join(self.config.get('output_dir', 'runs/similarity'))
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
self.writer = SummaryWriter(log_dir)
|
||||||
|
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print(f'Starting training for {num_epochs} epochs')
|
||||||
|
print(f'Logging to {log_dir}')
|
||||||
|
print(f'{"="*70}\n')
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(1, num_epochs + 1):
|
||||||
|
epoch_start = time.time()
|
||||||
|
print(f'\n--- Epoch {epoch}/{num_epochs} ---')
|
||||||
|
|
||||||
|
# Train
|
||||||
|
train_metrics = self.train_epoch(epoch)
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
val_metrics = self.validate()
|
||||||
|
|
||||||
|
# Store history
|
||||||
|
self.history['train_loss'].append(train_metrics['loss'])
|
||||||
|
self.history['val_loss'].append(val_metrics['loss'])
|
||||||
|
self.history['val_accuracy'].append(val_metrics['accuracy'])
|
||||||
|
self.history['val_precision'].append(val_metrics['precision'])
|
||||||
|
self.history['val_recall'].append(val_metrics['recall'])
|
||||||
|
self.history['val_f1'].append(val_metrics['f1'])
|
||||||
|
self.history['learning_rate'].append(
|
||||||
|
self.optimizer.param_groups[0]['lr']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print metrics
|
||||||
|
print(f'\nTrain Loss: {train_metrics["loss"]:.4f}')
|
||||||
|
print(f'Val Loss: {val_metrics["loss"]:.4f}')
|
||||||
|
print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')
|
||||||
|
print(f'Val Precision: {val_metrics["precision"]:.4f}')
|
||||||
|
print(f'Val Recall: {val_metrics["recall"]:.4f}')
|
||||||
|
print(f'Val F1: {val_metrics["f1"]:.4f}')
|
||||||
|
|
||||||
|
epoch_time = time.time() - epoch_start
|
||||||
|
print(f'Epoch time: {epoch_time:.2f}s')
|
||||||
|
|
||||||
|
# TensorBoard logging
|
||||||
|
if self.writer:
|
||||||
|
self.writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
|
||||||
|
self.writer.add_scalar('epoch/val_loss', val_metrics['loss'], epoch)
|
||||||
|
self.writer.add_scalar('epoch/val_accuracy', val_metrics['accuracy'], epoch)
|
||||||
|
self.writer.add_scalar('epoch/val_precision', val_metrics['precision'], epoch)
|
||||||
|
self.writer.add_scalar('epoch/val_recall', val_metrics['recall'], epoch)
|
||||||
|
self.writer.add_scalar('epoch/val_f1', val_metrics['f1'], epoch)
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
if val_metrics['loss'] < self.best_val_loss:
|
||||||
|
self.best_val_loss = val_metrics['loss']
|
||||||
|
self.epochs_without_improvement = 0
|
||||||
|
self.save_checkpoint(epoch, val_metrics['loss'], is_best=True)
|
||||||
|
print(f'✓ New best model saved with val loss: {val_metrics["loss"]:.4f}')
|
||||||
|
else:
|
||||||
|
self.epochs_without_improvement += 1
|
||||||
|
if epoch % self.config.get('save_interval', 5) == 0:
|
||||||
|
self.save_checkpoint(epoch, val_metrics['loss'], is_best=False)
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience = self.config.get('early_stopping_patience', 20)
|
||||||
|
if self.epochs_without_improvement >= patience:
|
||||||
|
print(f'\n⚠ Early stopping triggered after {patience} epochs without improvement')
|
||||||
|
break
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print(f'Training completed in {total_time/60:.2f} minutes')
|
||||||
|
print(f'Best validation loss: {self.best_val_loss:.4f}')
|
||||||
|
print(f'{"="*70}\n')
|
||||||
|
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
return self.history
|
||||||
|
|
||||||
|
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
|
||||||
|
"""Сохранение чекпоинта модели"""
|
||||||
|
checkpoint_dir = os.path.join(
|
||||||
|
self.config.get('output_dir', 'runs/similarity'),
|
||||||
|
'checkpoints'
|
||||||
|
)
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
'epoch': epoch,
|
||||||
|
'model_state_dict': self.model.state_dict(),
|
||||||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
|
'val_loss': val_loss,
|
||||||
|
'config': self.config,
|
||||||
|
'history': self.history
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pt')
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
if is_best:
|
||||||
|
best_path = os.path.join(checkpoint_dir, 'best_model.pt')
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str):
|
||||||
|
"""Загрузка чекпоинта"""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
if 'history' in checkpoint:
|
||||||
|
self.history = checkpoint['history']
|
||||||
|
return checkpoint['epoch'], checkpoint['val_loss']
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# VISUALIZATION FUNCTIONS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def plot_training_history(history: dict, save_path: str = None):
|
||||||
|
"""Построение графиков обучения"""
|
||||||
|
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
||||||
|
fig.suptitle('Training History - Siamese Network для корреляции снимков',
|
||||||
|
fontsize=16, fontweight='bold')
|
||||||
|
|
||||||
|
epochs = range(1, len(history['train_loss']) + 1)
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
|
||||||
|
axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
|
||||||
|
axes[0, 0].set_xlabel('Epoch')
|
||||||
|
axes[0, 0].set_ylabel('Loss')
|
||||||
|
axes[0, 0].set_title('Loss Curves')
|
||||||
|
axes[0, 0].legend()
|
||||||
|
axes[0, 0].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Accuracy
|
||||||
|
axes[0, 1].plot(epochs, history['val_accuracy'], 'g-', linewidth=2)
|
||||||
|
axes[0, 1].set_xlabel('Epoch')
|
||||||
|
axes[0, 1].set_ylabel('Accuracy')
|
||||||
|
axes[0, 1].set_title('Validation Accuracy')
|
||||||
|
axes[0, 1].grid(True, alpha=0.3)
|
||||||
|
axes[0, 1].set_ylim([0, 1])
|
||||||
|
|
||||||
|
# F1 Score
|
||||||
|
axes[0, 2].plot(epochs, history['val_f1'], 'm-', linewidth=2)
|
||||||
|
axes[0, 2].set_xlabel('Epoch')
|
||||||
|
axes[0, 2].set_ylabel('F1 Score')
|
||||||
|
axes[0, 2].set_title('Validation F1 Score')
|
||||||
|
axes[0, 2].grid(True, alpha=0.3)
|
||||||
|
axes[0, 2].set_ylim([0, 1])
|
||||||
|
|
||||||
|
# Precision
|
||||||
|
axes[1, 0].plot(epochs, history['val_precision'], 'c-', linewidth=2)
|
||||||
|
axes[1, 0].set_xlabel('Epoch')
|
||||||
|
axes[1, 0].set_ylabel('Precision')
|
||||||
|
axes[1, 0].set_title('Validation Precision')
|
||||||
|
axes[1, 0].grid(True, alpha=0.3)
|
||||||
|
axes[1, 0].set_ylim([0, 1])
|
||||||
|
|
||||||
|
# Recall
|
||||||
|
axes[1, 1].plot(epochs, history['val_recall'], 'y-', linewidth=2)
|
||||||
|
axes[1, 1].set_xlabel('Epoch')
|
||||||
|
axes[1, 1].set_ylabel('Recall')
|
||||||
|
axes[1, 1].set_title('Validation Recall')
|
||||||
|
axes[1, 1].grid(True, alpha=0.3)
|
||||||
|
axes[1, 1].set_ylim([0, 1])
|
||||||
|
|
||||||
|
# Learning Rate
|
||||||
|
axes[1, 2].plot(epochs, history['learning_rate'], 'k-', linewidth=2)
|
||||||
|
axes[1, 2].set_xlabel('Epoch')
|
||||||
|
axes[1, 2].set_ylabel('Learning Rate')
|
||||||
|
axes[1, 2].set_title('Learning Rate Schedule')
|
||||||
|
axes[1, 2].grid(True, alpha=0.3)
|
||||||
|
axes[1, 2].set_yscale('log')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f'Training history plot saved to {save_path}')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_roc_curve(predictions: torch.Tensor, targets: torch.Tensor, save_path: str = None):
|
||||||
|
"""Построение ROC кривой"""
|
||||||
|
from sklearn.metrics import roc_curve, auc
|
||||||
|
|
||||||
|
predictions_np = predictions.numpy().flatten()
|
||||||
|
targets_np = targets.numpy().flatten()
|
||||||
|
|
||||||
|
fpr, tpr, thresholds = roc_curve(targets_np, predictions_np)
|
||||||
|
roc_auc = auc(fpr, tpr)
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
plt.plot(fpr, tpr, color='darkorange', lw=2,
|
||||||
|
label=f'ROC curve (AUC = {roc_auc:.3f})')
|
||||||
|
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
|
||||||
|
plt.xlim([0.0, 1.0])
|
||||||
|
plt.ylim([0.0, 1.05])
|
||||||
|
plt.xlabel('False Positive Rate', fontsize=12)
|
||||||
|
plt.ylabel('True Positive Rate', fontsize=12)
|
||||||
|
plt.title('ROC Curve - Siamese Network', fontsize=14, fontweight='bold')
|
||||||
|
plt.legend(loc="lower right", fontsize=12)
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f'ROC curve saved to {save_path}')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return roc_auc
|
||||||
|
|
||||||
|
|
||||||
|
def plot_confusion_matrix(predictions: torch.Tensor, targets: torch.Tensor,
|
||||||
|
threshold: float = 0.5, save_path: str = None):
|
||||||
|
"""Построение матрицы ошибок"""
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
|
|
||||||
|
predictions_binary = (predictions.numpy().flatten() >= threshold).astype(int)
|
||||||
|
targets_binary = (targets.numpy().flatten() >= 0.5).astype(int)
|
||||||
|
|
||||||
|
cm = confusion_matrix(targets_binary, predictions_binary)
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
||||||
|
plt.title('Confusion Matrix - Correlation Detection', fontsize=14, fontweight='bold')
|
||||||
|
plt.colorbar()
|
||||||
|
|
||||||
|
classes = ['Different Domains', 'Same Domain']
|
||||||
|
tick_marks = np.arange(len(classes))
|
||||||
|
plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
|
||||||
|
plt.yticks(tick_marks, classes, fontsize=12)
|
||||||
|
|
||||||
|
# Add text annotations
|
||||||
|
thresh = cm.max() / 2.
|
||||||
|
for i in range(cm.shape[0]):
|
||||||
|
for j in range(cm.shape[1]):
|
||||||
|
plt.text(j, i, format(cm[i, j], 'd'),
|
||||||
|
ha="center", va="center",
|
||||||
|
color="white" if cm[i, j] > thresh else "black",
|
||||||
|
fontsize=16, fontweight='bold')
|
||||||
|
|
||||||
|
plt.ylabel('True Label', fontsize=12)
|
||||||
|
plt.xlabel('Predicted Label', fontsize=12)
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f'Confusion matrix saved to {save_path}')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_similarity_distribution(predictions: torch.Tensor, targets: torch.Tensor,
|
||||||
|
save_path: str = None):
|
||||||
|
"""Распределение предсказанных значений схожести"""
|
||||||
|
predictions_np = predictions.numpy().flatten()
|
||||||
|
targets_np = targets.numpy().flatten()
|
||||||
|
|
||||||
|
same_domain = predictions_np[targets_np >= 0.5]
|
||||||
|
diff_domain = predictions_np[targets_np < 0.5]
|
||||||
|
|
||||||
|
plt.figure(figsize=(12, 6))
|
||||||
|
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.hist(same_domain, bins=50, alpha=0.7, color='green', edgecolor='black', label='Same Domain')
|
||||||
|
plt.hist(diff_domain, bins=50, alpha=0.7, color='red', edgecolor='black', label='Different Domains')
|
||||||
|
plt.xlabel('Predicted Similarity Score', fontsize=12)
|
||||||
|
plt.ylabel('Frequency', fontsize=12)
|
||||||
|
plt.title('Distribution of Similarity Scores', fontsize=14, fontweight='bold')
|
||||||
|
plt.legend(fontsize=11)
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.boxplot([diff_domain, same_domain], labels=['Different', 'Same'])
|
||||||
|
plt.ylabel('Similarity Score', fontsize=12)
|
||||||
|
plt.xlabel('Domain Match', fontsize=12)
|
||||||
|
plt.title('Similarity Score by Domain Match', fontsize=14, fontweight='bold')
|
||||||
|
plt.grid(True, alpha=0.3, axis='y')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f'Similarity distribution plot saved to {save_path}')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Print statistics
|
||||||
|
print(f'\n--- Similarity Score Statistics ---')
|
||||||
|
print(f'Same Domain:')
|
||||||
|
print(f' Mean: {same_domain.mean():.4f}')
|
||||||
|
print(f' Std: {same_domain.std():.4f}')
|
||||||
|
print(f' Min: {same_domain.min():.4f}')
|
||||||
|
print(f' Max: {same_domain.max():.4f}')
|
||||||
|
print(f'\nDifferent Domains:')
|
||||||
|
print(f' Mean: {diff_domain.mean():.4f}')
|
||||||
|
print(f' Std: {diff_domain.std():.4f}')
|
||||||
|
print(f' Min: {diff_domain.min():.4f}')
|
||||||
|
print(f' Max: {diff_domain.max():.4f}')
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_sample_predictions(model: nn.Module, dataset, device: torch.device,
|
||||||
|
num_samples: int = 8, save_path: str = None):
|
||||||
|
"""Визуализация примеров предсказаний"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Get random samples
|
||||||
|
indices = np.random.choice(len(dataset), num_samples, replace=False)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))
|
||||||
|
if num_samples == 1:
|
||||||
|
axes = axes.reshape(1, -1)
|
||||||
|
|
||||||
|
fig.suptitle('Sample Predictions - Siamese Network для корреляции карт',
|
||||||
|
fontsize=16, fontweight='bold')
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for idx, sample_idx in enumerate(indices):
|
||||||
|
sample = dataset[sample_idx]
|
||||||
|
|
||||||
|
google_img = sample['google_img'].unsqueeze(0).to(device)
|
||||||
|
yandex_img = sample['yandex_img'].unsqueeze(0).to(device)
|
||||||
|
true_label = sample['same_domain'].item()
|
||||||
|
|
||||||
|
# Predict
|
||||||
|
pred_similarity = model(google_img, yandex_img).item()
|
||||||
|
pred_label = int(pred_similarity >= 0.5)
|
||||||
|
|
||||||
|
# Denormalize images for visualization
|
||||||
|
google_np = google_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
|
||||||
|
yandex_np = yandex_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
|
||||||
|
|
||||||
|
# Denormalize (assuming ImageNet normalization)
|
||||||
|
mean = np.array([0.485, 0.456, 0.406])
|
||||||
|
std = np.array([0.229, 0.224, 0.225])
|
||||||
|
google_np = std * google_np + mean
|
||||||
|
yandex_np = std * yandex_np + mean
|
||||||
|
google_np = np.clip(google_np, 0, 1)
|
||||||
|
yandex_np = np.clip(yandex_np, 0, 1)
|
||||||
|
|
||||||
|
# Plot Google image
|
||||||
|
axes[idx, 0].imshow(google_np)
|
||||||
|
axes[idx, 0].set_title('Google Map', fontsize=12, fontweight='bold')
|
||||||
|
axes[idx, 0].axis('off')
|
||||||
|
|
||||||
|
# Plot Yandex image
|
||||||
|
axes[idx, 1].imshow(yandex_np)
|
||||||
|
axes[idx, 1].set_title('Yandex Map', fontsize=12, fontweight='bold')
|
||||||
|
axes[idx, 1].axis('off')
|
||||||
|
|
||||||
|
# Plot prediction info
|
||||||
|
axes[idx, 2].axis('off')
|
||||||
|
|
||||||
|
# Determine color based on correctness
|
||||||
|
is_correct = (pred_label == true_label)
|
||||||
|
color = 'green' if is_correct else 'red'
|
||||||
|
result = '✓ Correct' if is_correct else '✗ Incorrect'
|
||||||
|
|
||||||
|
info_text = f"""
|
||||||
|
Prediction: {pred_similarity:.4f}
|
||||||
|
Predicted Label: {'Same' if pred_label == 1 else 'Different'}
|
||||||
|
True Label: {'Same' if true_label == 1 else 'Different'}
|
||||||
|
|
||||||
|
{result}
|
||||||
|
"""
|
||||||
|
|
||||||
|
axes[idx, 2].text(0.5, 0.5, info_text,
|
||||||
|
ha='center', va='center',
|
||||||
|
fontsize=12,
|
||||||
|
bbox=dict(boxstyle='round', facecolor=color, alpha=0.2),
|
||||||
|
transform=axes[idx, 2].transAxes)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f'Sample predictions saved to {save_path}')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_feature_space(model: nn.Module, dataloader: DataLoader,
|
||||||
|
device: torch.device, max_samples: int = 500,
|
||||||
|
save_path: str = None):
|
||||||
|
"""Визуализация пространства признаков с помощью t-SNE"""
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
all_features_google = []
|
||||||
|
all_features_yandex = []
|
||||||
|
all_labels = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, batch in enumerate(tqdm(dataloader, desc='Extracting features')):
|
||||||
|
if i * dataloader.batch_size >= max_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
google_img = batch['google_img'].to(device)
|
||||||
|
yandex_img = batch['yandex_img'].to(device)
|
||||||
|
labels = batch['same_domain'].cpu().numpy()
|
||||||
|
|
||||||
|
# Extract features
|
||||||
|
features_google = model.extract_features(google_img).cpu().numpy()
|
||||||
|
features_yandex = model.extract_features(yandex_img).cpu().numpy()
|
||||||
|
|
||||||
|
all_features_google.append(features_google)
|
||||||
|
all_features_yandex.append(features_yandex)
|
||||||
|
all_labels.append(labels)
|
||||||
|
|
||||||
|
all_features_google = np.concatenate(all_features_google, axis=0)
|
||||||
|
all_features_yandex = np.concatenate(all_features_yandex, axis=0)
|
||||||
|
all_labels = np.concatenate(all_labels, axis=0)
|
||||||
|
|
||||||
|
# Combine features
|
||||||
|
all_features = np.concatenate([all_features_google, all_features_yandex], axis=0)
|
||||||
|
all_labels = np.concatenate([all_labels, all_labels], axis=0)
|
||||||
|
|
||||||
|
print(f'\nApplying t-SNE to {all_features.shape[0]} samples...')
|
||||||
|
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
|
||||||
|
features_2d = tsne.fit_transform(all_features)
|
||||||
|
|
||||||
|
# Split back into Google and Yandex
|
||||||
|
n_samples = len(all_labels) // 2
|
||||||
|
features_google_2d = features_2d[:n_samples]
|
||||||
|
features_yandex_2d = features_2d[n_samples:]
|
||||||
|
labels = all_labels[:n_samples]
|
||||||
|
|
||||||
|
# Plot
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
|
||||||
|
|
||||||
|
# Google features
|
||||||
|
for label in [0, 1]:
|
||||||
|
mask = labels == label
|
||||||
|
axes[0].scatter(
|
||||||
|
features_google_2d[mask, 0],
|
||||||
|
features_google_2d[mask, 1],
|
||||||
|
c='green' if label == 1 else 'red',
|
||||||
|
label='Same Domain' if label == 1 else 'Different Domains',
|
||||||
|
alpha=0.6,
|
||||||
|
s=50
|
||||||
|
)
|
||||||
|
axes[0].set_title('Google Maps Features (t-SNE)', fontsize=14, fontweight='bold')
|
||||||
|
axes[0].set_xlabel('t-SNE Component 1', fontsize=12)
|
||||||
|
axes[0].set_ylabel('t-SNE Component 2', fontsize=12)
|
||||||
|
axes[0].legend(fontsize=11)
|
||||||
|
axes[0].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Yandex features
|
||||||
|
for label in [0, 1]:
|
||||||
|
mask = labels == label
|
||||||
|
axes[1].scatter(
|
||||||
|
features_yandex_2d[mask, 0],
|
||||||
|
features_yandex_2d[mask, 1],
|
||||||
|
c='green' if label == 1 else 'red',
|
||||||
|
label='Same Domain' if label == 1 else 'Different Domains',
|
||||||
|
alpha=0.6,
|
||||||
|
s=50
|
||||||
|
)
|
||||||
|
axes[1].set_title('Yandex Maps Features (t-SNE)', fontsize=14, fontweight='bold')
|
||||||
|
axes[1].set_xlabel('t-SNE Component 1', fontsize=12)
|
||||||
|
axes[1].set_ylabel('t-SNE Component 2', fontsize=12)
|
||||||
|
axes[1].legend(fontsize=11)
|
||||||
|
axes[1].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f'Feature space visualization saved to {save_path}')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_correlation_heatmap(model: nn.Module, dataloader: DataLoader,
|
||||||
|
device: torch.device, num_samples: int = 20,
|
||||||
|
save_path: str = None):
|
||||||
|
"""Создание тепловой карты корреляций между снимками"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Collect samples
|
||||||
|
google_images = []
|
||||||
|
yandex_images = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, batch in enumerate(dataloader):
|
||||||
|
if len(google_images) >= num_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
google_img = batch['google_img'].to(device)
|
||||||
|
yandex_img = batch['yandex_img'].to(device)
|
||||||
|
label = batch['same_domain']
|
||||||
|
|
||||||
|
google_images.append(google_img[:1])
|
||||||
|
yandex_images.append(yandex_img[:1])
|
||||||
|
labels.append(label[:1].item())
|
||||||
|
|
||||||
|
google_images = torch.cat(google_images[:num_samples], dim=0)
|
||||||
|
yandex_images = torch.cat(yandex_images[:num_samples], dim=0)
|
||||||
|
|
||||||
|
# Compute similarity matrix
|
||||||
|
similarity_matrix = np.zeros((num_samples, num_samples))
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in tqdm(range(num_samples), desc='Computing correlations'):
|
||||||
|
for j in range(num_samples):
|
||||||
|
google_i = google_images[i:i+1]
|
||||||
|
yandex_j = yandex_images[j:j+1]
|
||||||
|
similarity = model(google_i, yandex_j).item()
|
||||||
|
similarity_matrix[i, j] = similarity
|
||||||
|
|
||||||
|
# Plot heatmap
|
||||||
|
plt.figure(figsize=(14, 12))
|
||||||
|
im = plt.imshow(similarity_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
|
||||||
|
|
||||||
|
plt.colorbar(im, label='Similarity Score', fraction=0.046, pad=0.04)
|
||||||
|
plt.title('Correlation Heatmap: Google vs Yandex Maps\n(Матрица корреляций снимков)',
|
||||||
|
fontsize=16, fontweight='bold', pad=20)
|
||||||
|
plt.xlabel('Yandex Map Index', fontsize=12)
|
||||||
|
plt.ylabel('Google Map Index', fontsize=12)
|
||||||
|
|
||||||
|
# Add grid
|
||||||
|
plt.grid(True, which='both', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
|
||||||
|
|
||||||
|
# Add text annotations for diagonal (same pairs)
|
||||||
|
for i in range(min(num_samples, 10)): # Annotate first 10 for readability
|
||||||
|
if labels[i] == 1: # True match
|
||||||
|
plt.text(i, i, '✓', ha='center', va='center',
|
||||||
|
color='white', fontsize=12, fontweight='bold')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f'Correlation heatmap saved to {save_path}')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Print statistics
|
||||||
|
diagonal = np.diag(similarity_matrix)
|
||||||
|
off_diagonal = similarity_matrix[~np.eye(num_samples, dtype=bool)]
|
||||||
|
|
||||||
|
print(f'\n--- Correlation Statistics ---')
|
||||||
|
print(f'Diagonal (matched pairs):')
|
||||||
|
print(f' Mean: {diagonal.mean():.4f}')
|
||||||
|
print(f' Std: {diagonal.std():.4f}')
|
||||||
|
print(f'\nOff-diagonal (mismatched pairs):')
|
||||||
|
print(f' Mean: {off_diagonal.mean():.4f}')
|
||||||
|
print(f' Std: {off_diagonal.std():.4f}')
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MAIN TRAINING SCRIPT
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция обучения"""
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
config_dict = config.copy()
|
||||||
|
if isinstance(config_dict.get('image_size'), list):
|
||||||
|
config_dict['image_size'] = tuple(config_dict['image_size'])
|
||||||
|
|
||||||
|
# Device
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print(f'Siamese Network Training for Map Correlation')
|
||||||
|
print(f'Обучение сиамской сети для корреляции снимков')
|
||||||
|
print(f'{"="*70}')
|
||||||
|
print(f'Using device: {device}')
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f'GPU: {torch.cuda.get_device_name(0)}')
|
||||||
|
print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print('Creating data loaders...')
|
||||||
|
print(f'{"="*70}')
|
||||||
|
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=config_dict['data_dir'],
|
||||||
|
batch_size=config_dict['batch_size'],
|
||||||
|
train_split=config_dict['train_split'],
|
||||||
|
num_workers=config_dict['num_workers'],
|
||||||
|
image_size=config_dict['image_size'],
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f'Train batches: {len(train_loader)}')
|
||||||
|
print(f'Val batches: {len(val_loader)}')
|
||||||
|
print(f'Train samples: {len(train_loader.dataset)}')
|
||||||
|
print(f'Val samples: {len(val_loader.dataset)}')
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print('Creating model...')
|
||||||
|
print(f'{"="*70}')
|
||||||
|
|
||||||
|
model = create_similarity_model(
|
||||||
|
model_type='backbone',
|
||||||
|
input_size=config_dict['image_size'][0] if isinstance(config_dict['image_size'], (tuple, list)) else config_dict['image_size'],
|
||||||
|
input_channels=3,
|
||||||
|
backbone_name='resnet18',
|
||||||
|
pretrained=True,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True
|
||||||
|
)
|
||||||
|
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print(f'Total parameters: {total_params:,}')
|
||||||
|
print(f'Trainable parameters: {trainable_params:,}')
|
||||||
|
|
||||||
|
# Create trainer
|
||||||
|
trainer = SimilarityTrainer(
|
||||||
|
model=model,
|
||||||
|
trainloader=train_loader,
|
||||||
|
valloader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train model
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print('Starting training...')
|
||||||
|
print(f'{"="*70}')
|
||||||
|
|
||||||
|
history = trainer.train(config_dict['epochs'])
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# VISUALIZATION AND RESULTS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print('Generating visualizations...')
|
||||||
|
print(f'{"="*70}')
|
||||||
|
|
||||||
|
output_dir = config_dict.get('output_dir', 'runs/similarity')
|
||||||
|
vis_dir = os.path.join(output_dir, 'visualizations')
|
||||||
|
os.makedirs(vis_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. Training history
|
||||||
|
print('\n1. Plotting training history...')
|
||||||
|
plot_training_history(
|
||||||
|
history,
|
||||||
|
save_path=os.path.join(vis_dir, 'training_history.png')
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Validation metrics
|
||||||
|
print('\n2. Computing validation predictions...')
|
||||||
|
trainer.model.eval()
|
||||||
|
val_predictions = []
|
||||||
|
val_targets = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(val_loader, desc='Validation'):
|
||||||
|
google_img = batch['google_img'].to(device)
|
||||||
|
yandex_img = batch['yandex_img'].to(device)
|
||||||
|
target = batch['same_domain'].float().unsqueeze(1)
|
||||||
|
|
||||||
|
output = trainer.model(google_img, yandex_img)
|
||||||
|
val_predictions.append(output.cpu())
|
||||||
|
val_targets.append(target.cpu())
|
||||||
|
|
||||||
|
val_predictions = torch.cat(val_predictions, dim=0)
|
||||||
|
val_targets = torch.cat(val_targets, dim=0)
|
||||||
|
|
||||||
|
# 3. ROC curve
|
||||||
|
print('\n3. Plotting ROC curve...')
|
||||||
|
roc_auc = plot_roc_curve(
|
||||||
|
val_predictions,
|
||||||
|
val_targets,
|
||||||
|
save_path=os.path.join(vis_dir, 'roc_curve.png')
|
||||||
|
)
|
||||||
|
print(f'ROC AUC Score: {roc_auc:.4f}')
|
||||||
|
|
||||||
|
# 4. Confusion matrix
|
||||||
|
print('\n4. Plotting confusion matrix...')
|
||||||
|
plot_confusion_matrix(
|
||||||
|
val_predictions,
|
||||||
|
val_targets,
|
||||||
|
threshold=0.5,
|
||||||
|
save_path=os.path.join(vis_dir, 'confusion_matrix.png')
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Similarity distribution
|
||||||
|
print('\n5. Plotting similarity distribution...')
|
||||||
|
plot_similarity_distribution(
|
||||||
|
val_predictions,
|
||||||
|
val_targets,
|
||||||
|
save_path=os.path.join(vis_dir, 'similarity_distribution.png')
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Sample predictions
|
||||||
|
print('\n6. Visualizing sample predictions...')
|
||||||
|
visualize_sample_predictions(
|
||||||
|
trainer.model,
|
||||||
|
val_loader.dataset,
|
||||||
|
device,
|
||||||
|
num_samples=8,
|
||||||
|
save_path=os.path.join(vis_dir, 'sample_predictions.png')
|
||||||
|
)
|
||||||
|
|
||||||
|
# 7. Feature space visualization
|
||||||
|
print('\n7. Visualizing feature space (t-SNE)...')
|
||||||
|
visualize_feature_space(
|
||||||
|
trainer.model,
|
||||||
|
val_loader,
|
||||||
|
device,
|
||||||
|
max_samples=500,
|
||||||
|
save_path=os.path.join(vis_dir, 'feature_space_tsne.png')
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Correlation heatmap
|
||||||
|
print('\n8. Generating correlation heatmap...')
|
||||||
|
generate_correlation_heatmap(
|
||||||
|
trainer.model,
|
||||||
|
val_loader,
|
||||||
|
device,
|
||||||
|
num_samples=20,
|
||||||
|
save_path=os.path.join(vis_dir, 'correlation_heatmap.png')
|
||||||
|
)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# FINAL RESULTS SUMMARY
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print('FINAL RESULTS SUMMARY')
|
||||||
|
print('ИТОГОВЫЕ РЕЗУЛЬТАТЫ')
|
||||||
|
print(f'{"="*70}')
|
||||||
|
|
||||||
|
print(f'\nBest Validation Loss: {trainer.best_val_loss:.4f}')
|
||||||
|
print(f'Final Validation Accuracy: {history["val_accuracy"][-1]:.4f}')
|
||||||
|
print(f'Final Validation F1 Score: {history["val_f1"][-1]:.4f}')
|
||||||
|
print(f'Final Validation Precision: {history["val_precision"][-1]:.4f}')
|
||||||
|
print(f'Final Validation Recall: {history["val_recall"][-1]:.4f}')
|
||||||
|
print(f'ROC AUC Score: {roc_auc:.4f}')
|
||||||
|
|
||||||
|
print(f'\nCheckpoints saved to: {os.path.join(output_dir, "checkpoints")}')
|
||||||
|
print(f'Visualizations saved to: {vis_dir}')
|
||||||
|
|
||||||
|
print(f'\n{"="*70}')
|
||||||
|
print('Training and visualization completed successfully!')
|
||||||
|
print('Обучение и визуализация завершены успешно!')
|
||||||
|
print(f'{"="*70}\n')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
248
models/SiaN-similarity/train.py
Normal file
248
models/SiaN-similarity/train.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
"""
|
||||||
|
Training script for image similarity estimation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from dataloader import config, create_data_loaders
|
||||||
|
from model import SimilarityCNN, SimilarityLoss, create_similarity_model
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarityTrainer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
config: dict,
|
||||||
|
):
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.device = device
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.criterion = SimilarityLoss()
|
||||||
|
self.optimizer = optim.Adam(
|
||||||
|
model.parameters(),
|
||||||
|
lr=config.get("learning_rate", 2e-4),
|
||||||
|
betas=(config.get("beta1", 0.5), config.get("beta2", 0.999)),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.writer = None
|
||||||
|
self.best_val_loss = float("inf")
|
||||||
|
self.epochs_without_improvement = 0
|
||||||
|
|
||||||
|
def train_epoch(self, epoch: int) -> dict:
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
|
||||||
|
for batch_idx, batch in enumerate(pbar):
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target = batch["same_domain"].float().to(self.device).unsqueeze(1)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
output = self.model(google_img, yandex_img)
|
||||||
|
loss = self.criterion(output, target)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item() * google_img.size(0)
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
if batch_idx % self.config.get("log_interval", 10) == 0:
|
||||||
|
metrics = self.criterion.compute_metrics(output, target)
|
||||||
|
pbar.set_postfix(
|
||||||
|
{
|
||||||
|
"loss": loss.item(),
|
||||||
|
"acc": metrics["accuracy"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.writer:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/loss",
|
||||||
|
loss.item(),
|
||||||
|
epoch * len(self.train_loader) + batch_idx,
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/accuracy",
|
||||||
|
metrics["accuracy"],
|
||||||
|
epoch * len(self.train_loader) + batch_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
return {"loss": avg_loss}
|
||||||
|
|
||||||
|
def validate(self) -> dict:
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
all_metrics = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(self.val_loader, desc="Validation"):
|
||||||
|
google_img = batch["google_img"].to(self.device)
|
||||||
|
yandex_img = batch["yandex_img"].to(self.device)
|
||||||
|
target = batch["same_domain"].float().to(self.device).unsqueeze(1)
|
||||||
|
|
||||||
|
output = self.model(google_img, yandex_img)
|
||||||
|
loss = self.criterion(output, target)
|
||||||
|
|
||||||
|
total_loss += loss.item() * google_img.size(0)
|
||||||
|
total_samples += google_img.size(0)
|
||||||
|
|
||||||
|
metrics = self.criterion.compute_metrics(output, target)
|
||||||
|
all_metrics.append(metrics)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
|
||||||
|
avg_metrics = {}
|
||||||
|
for key in all_metrics[0].keys():
|
||||||
|
avg_metrics[key] = sum(m[key] for m in all_metrics) / len(all_metrics)
|
||||||
|
|
||||||
|
return {"loss": avg_loss, **avg_metrics}
|
||||||
|
|
||||||
|
def train(self, num_epochs: int):
|
||||||
|
log_dir = self.config.get("output_dir", "runs/similarity")
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
self.writer = SummaryWriter(log_dir)
|
||||||
|
|
||||||
|
print(f"Starting training for {num_epochs} epochs")
|
||||||
|
print(f"Logging to: {log_dir}")
|
||||||
|
|
||||||
|
for epoch in range(1, num_epochs + 1):
|
||||||
|
print(f"\nEpoch {epoch}/{num_epochs}")
|
||||||
|
|
||||||
|
train_metrics = self.train_epoch(epoch)
|
||||||
|
val_metrics = self.validate()
|
||||||
|
|
||||||
|
print(f"Train Loss: {train_metrics['loss']:.4f}")
|
||||||
|
print(f"Val Loss: {val_metrics['loss']:.4f}")
|
||||||
|
print(f"Val Accuracy: {val_metrics['accuracy']:.4f}")
|
||||||
|
print(f"Val F1: {val_metrics['f1']:.4f}")
|
||||||
|
|
||||||
|
if self.writer:
|
||||||
|
self.writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
|
||||||
|
self.writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"epoch/val_accuracy", val_metrics["accuracy"], epoch
|
||||||
|
)
|
||||||
|
|
||||||
|
if val_metrics["loss"] < self.best_val_loss:
|
||||||
|
self.best_val_loss = val_metrics["loss"]
|
||||||
|
self.epochs_without_improvement = 0
|
||||||
|
self.save_checkpoint(epoch, val_metrics["loss"], is_best=True)
|
||||||
|
print(f"New best model saved with val loss: {val_metrics['loss']:.4f}")
|
||||||
|
else:
|
||||||
|
self.epochs_without_improvement += 1
|
||||||
|
self.save_checkpoint(epoch, val_metrics["loss"], is_best=False)
|
||||||
|
|
||||||
|
patience = self.config.get("early_stopping_patience", 20)
|
||||||
|
if self.epochs_without_improvement >= patience:
|
||||||
|
print(
|
||||||
|
f"Early stopping triggered after {patience} epochs without improvement"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
|
||||||
|
checkpoint_dir = os.path.join(
|
||||||
|
self.config.get("output_dir", "runs/similarity"), "checkpoints"
|
||||||
|
)
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"model_state_dict": self.model.state_dict(),
|
||||||
|
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||||
|
"val_loss": val_loss,
|
||||||
|
"config": self.config,
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
if is_best:
|
||||||
|
best_path = os.path.join(checkpoint_dir, "best_model.pt")
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str):
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
return checkpoint["epoch"], checkpoint["val_loss"]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Use config from dataloader.py
|
||||||
|
config_dict = config.copy()
|
||||||
|
|
||||||
|
# Ensure image_size is tuple
|
||||||
|
if isinstance(config_dict.get("image_size"), list):
|
||||||
|
config_dict["image_size"] = tuple(config_dict["image_size"])
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
print("Creating data loaders...")
|
||||||
|
train_loader, val_loader = create_data_loaders(
|
||||||
|
root_dir=config_dict["data_dir"],
|
||||||
|
batch_size=config_dict["batch_size"],
|
||||||
|
train_split=config_dict["train_split"],
|
||||||
|
num_workers=config_dict["num_workers"],
|
||||||
|
image_size=config_dict["image_size"],
|
||||||
|
augment_train=True,
|
||||||
|
augment_val=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Train batches: {len(train_loader)}")
|
||||||
|
print(f"Val batches: {len(val_loader)}")
|
||||||
|
|
||||||
|
print("Creating model...")
|
||||||
|
model = create_similarity_model(
|
||||||
|
model_type="cnn",
|
||||||
|
input_size=config_dict["image_size"][0]
|
||||||
|
if isinstance(config_dict["image_size"], (tuple, list))
|
||||||
|
else config_dict["image_size"],
|
||||||
|
input_channels=3,
|
||||||
|
hidden_channels=64,
|
||||||
|
num_blocks=4,
|
||||||
|
dropout_rate=0.3,
|
||||||
|
use_batch_norm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
trainer = SimilarityTrainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
config=config_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Starting training...")
|
||||||
|
trainer.train(config_dict["epochs"])
|
||||||
|
|
||||||
|
print("Training completed!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user